王泽宇, NVIDIA GPU加速计算专家团队 高级工程师 | November 7, 2025
DSA 核优化 (DSA Kernel Optimization)
MLA 反向核优化 (MLA Backward Kernel Optimization)
未来工作 (Future Works)
DeepSeek V3.2 是一个新的“实验性”模型,旨在提升长上下文效率。其基准测试性能和 API 定价如下:
该模型的一个关键特性是稀疏注意力(Sparse Attention)。
DSA 的核心思想是在推理过程中仅选择 TopK 个 KV Token 进行注意力计算,以降低延迟。
标准的注意力机制(如 Multi-Head Attention, MHA)需要新的查询(Q)Token 与所有的键值(KV)Token 进行计算。
DSA 引入了一个 TopK 索引器(TopK Indexer),它会从大量的 KV Token 中筛选出最相关的 TopK 个。
通过这种方式,注意力核(Attention Kernel)只需在选定的 TopK KV Token 子集上进行计算,从而显著减少计算量和延迟。与传统的 MLA(Multi-Layer Attention)相比,DSA 的延迟大幅降低。本次演讲的重点在于对此过程中的注意力核进行优化。
DSA 可以被理解为带有稀疏性的 MLA (Multi-Query Attention)。下图对比了 MHA (Multi-Head Attention) 和 MQA (Multi-Query Attention) 的张量结构。
MHA (Prefill 阶段)
MQA (Decoding 阶段)
在讨论 DSA 的挑战之前,先回顾一下之前在 Hopper 架构上使用 FMHA_V2 对 MLA 注意力核的优化经验。
挑战:
64 x 512 / 128 = 256 regs/thread。有利条件:
解决方案:
V = K[:, :512]。下图展示了该优化方案的流水线操作:
该优化带来了显著的性能提升,如下方的 TFLOPS 性能图所示,FMHA Opt 的性能远超其他方法。
上图展示了DeepSeek稀疏注意力(DSA)中FP8 KV(键值)缓存的内存布局。每个令牌(token)占用656字节。
512 e4m3(512字节)用于存储没有位置编码(NoPE)的KV缓存,其构成是 512xFP8 + 4xFP32。这里的 4xFP32 是缩放因子(Scaling factor)。64xbf16。在 MLA 优化的基础上,DSA 引入了新的挑战:
稀疏性处理 (Sparsity):
预填充 (Prefill) 阶段的挑战 (在长上下文场景中):
解码 (Decoding) 阶段的挑战:
本节介绍NVIDIA Blackwell平台的概览。
张量内存(Tensor Memory)
matrix_A和matrix_O。tcgen05.ld/sttcgen05.cp
上图展示了张量内存的布局和寻址方式,它由128个Lane组成,每个Lane为2KB。
2-CTA Tensorcore GEMM
上图展示了2-CTA GEMM的概念,其中A和B矩阵被分配到两个不同的CTA(CTA0和CTA1)中进行计算。
TMA.Gather4
int4 idx输入,用于指定需要收集的数据索引。
上图(图10)展示了tiled::scatter4/tiled::gather4模式下边界框(bounding box)的一个示例,说明了如何根据不同的起始坐标从全局内存中收集(gather)数据到共享内存。
本节将初步探讨DSA在预填充阶段的实现。
QKV 配置:
核函数设置 (Kernel Setup):
2-CTA集群 (2-CTA Cluster): 采用两个协作线程块(CTA)组成的集群。
图中提出了一个关键问题:“如何在2-SM GEMM中处理V?”
架构图:
共享内存 (Shared Memory):
张量内存 (Tensor Memory):
Q Offloading to TMEM: Q矩阵被卸载到张量内存中。
流水线 (Pipeline):
上图展示了TMA、Tensorcore和CUDA Core之间的并行执行流水线。TMA负责加载数据,Tensorcore执行矩阵乘法(QK, SV),CUDA Core执行数学函数(MUFU)和缩放(Scale)。
线程配置:
基准测试 (Benchmark):
本节介绍DSA在解码阶段的实现。
Q_LEN = 1,如果启用了多令牌并行(MTP),则可能大于1。使用splitKV来平衡SM的工作负载。反量化 (Dequantization) of FP8:
GEMM:
解决方案: 使用带有分布式共享内存(DSMEM)的2-CTA对 (2-CTA pair with DSMEM)。
该方案通过2-CTA集群和分布式共享内存(DSMEM)来解决反量化带来的高CUDA核心压力。
* 工作原理: 每个CTA负责一半的反量化工作,并通过DSMEM进行多播(Multi-cast)。
* 流程: 512字节的e4m3数据被分成两部分,分别加载到两个CTA的本地共享内存(SHM0的FP8-CTA0和SHM1的FP8-CTA1)中。经过反量化后,结果通过DSMEM共享到2-CTA集群的共享内存中,形成BF16-CTA0和BF16-CTA1。
* 优势:
* 两个CTA共享相同的TopK KV令牌。
* 将CUDA核心压力减半,降至 1664 CLK/CTA。
共享内存 (Shared Memory) (双缓冲):
张量内存 (Tensor Memory):
理想流水线 (Ideal Pipeline):
上图展示了稀疏解码阶段的理想化流水线,显示了TMA、Tensorcore和CUDA Core(执行反量化、MUFU、Scale)之间的并行工作流。
CTA维度: [2, q_len, sm_parts]
线程配置:
当前进展 (Current Progress):
下图展示了一种假设情况的流水线,即如果从 FP8 到 BF16 的转换可以由一条单一指令完成。在这种优化下,原本在 CUDA Core 上执行的多个反量化(Dequantization)、乘法融合(MUFU)和缩放(Scale)操作可以被整合,从而简化执行流程,提高效率。TMA(Tensor Memory Accelerator)负责加载数据,Tensor Core 执行核心的矩阵运算,而 CUDA Core 的负担减轻。
标准注意力机制的前向传播过程如下图所示。它主要由两个通用矩阵乘法(GEMM)操作和一个 Softmax 操作组成:
反向传播过程计算输出 O 对输入 Q、K、V 的梯度(分别为 dO、dQ、dK、dV)。梯度流与前向传播的计算图方向相反。
反向传播的具体计算公式如下:
* $P = Q * K^T$
* $S = Softmax(P) = exp(P - lse)$
* $dV = S^T * dO$
* $dS = dO * V^T$
* $dP = S \circ (dS - sum(O \circ dO))$ (其中 $\circ$ 表示逐元素相乘)
* $dQ = dP * K$
* $dK = dP^T * Q$
从计算角度看,反向传播过程主要包含 5 个 GEMM 操作和 2 个由 CUDA Core 执行的操作。
注意力反向传播的计算流程涉及以下几个关键步骤。
为了优化计算,一些中间值可以预先计算或在前向传播时计算并保存下来:
* lse (log-sum-exp) 在前向传播时计算。
* sum(O ◦ dO) 可以在反向传播主循环开始前预先计算。
在实现核函数时,循环的顺序是一个关键的设计选择。两种常见的策略是:
1. 外层 KV,内层 QO: 外层循环遍历 KV 的分块(tile),内层循环遍历 Q 的分块。在内层核函数中累加 dQ。
2. 外层 QO,内层 KV: 外层循环遍历 Q 的分块,内层循环遍历 KV 的分块。在内层核函数中累加 dK 和 dV。
下图展示了注意力反向传播的数据流。首先计算 sumOdO,然后将其与 dO, Q, K, V 一同输入到主计算模块 Backward Attn 中,得到 dK 和 dV,并累加生成最终的 dQ。
本节将讨论在 Blackwell 架构上实现反向注意力的具体细节。
传统的注意力反向核函数实现中,内存布局和流水线设计如下:
配置参数:
资源分配:
下图展示了在共享内存(Shared Memory)和张量内存(Tensor Memory)中的数据布局。
其计算流水线大致如下,TMA 负责加载数据,Tensor Core 和 CUDA Core 交替执行计算。注意,此图可能不完全代表真实的流水线。
对于 MLA,实现上存在一些差异和挑战。
QK_head_dim 从 128 增加到 192。Q_STEP 和 KV_STEP 仍为 128 时,由于 head_dim 增大,会导致内存溢出。Q_Tile_STEP 从 128 减小到 64。这一调整改变了共享内存和张量内存中与 Q 相关的张量(如 Q, dO, S, dQ)的分块大小,从而适应了更大的 head_dim,避免了内存问题。下图展示了调整后的内存布局。
性能: 达到默认反向核(d_qk=d_vo=128)性能的 85%。
基准测试结果
当数据分布不均衡时,会出现性能下降问题。例如,一个批次(Batch 0)包含大量数据,而其他批次(Batch 1-7)数据量很小。这种不均衡会导致性能下降至 300 TFLOPS。
在 ComputeSumOdO 计算中,对于非均衡数据 Data = [10000] + 99*[1],延迟高达 10.3ms,而对于均衡数据 Data = [100]*100,延迟仅为 0.382ms。
GridDim: [ceil(q_max_len / Q_BLOCK), head_num, batch_size]BlockDim: [8, 16, 1]threadIdx.y 的 threadIdx.x 计算一个 Sum(OdO)。GridDim: [ceil(total_q_len / Q_BLOCK), head_num]q_idx 在总长度中的位置来确定其 batchID,使用二分搜索(Binary Search)在 cu_q_len[] 中查找。优化后的内核在循环的第一次迭代中使用二分搜索确定 bs_id,在后续迭代中如果 q_idx 超出了当前批次的长度,则查找新的批次。
优化后,ComputeSumOdO 在处理非均衡数据 Data: [10000] + 99*[1] 时的延迟从基线的 10.3ms 显著降低到 0.26ms。相关的 PR 将于本月提交至 FlashMLA。