Best Practice of MLA Kernel Optimization on Blackwell

王泽宇, NVIDIA GPU加速计算专家团队 高级工程师 | November 7, 2025

目录 (Agenda)

DSA 核优化 (DSA Kernel Optimization)

DeepSeek V3.2 介绍

DeepSeek V3.2 是一个新的“实验性”模型,旨在提升长上下文效率。其基准测试性能和 API 定价如下:

Page 5, DeepSeek V3.2 在各项基准测试中的表现,以及其API的定价信息。
Page 5, DeepSeek V3.2 在各项基准测试中的表现,以及其API的定价信息。

该模型的一个关键特性是稀疏注意力(Sparse Attention)。

DSA: DeepSeek 稀疏注意力机制 (DeepSeek Sparse Attention)

DSA 的核心思想是在推理过程中仅选择 TopK 个 KV Token 进行注意力计算,以降低延迟。

标准的注意力机制(如 Multi-Head Attention, MHA)需要新的查询(Q)Token 与所有的键值(KV)Token 进行计算。

Page 6, 标准注意力机制示意图,新的Q Token与所有KV Tokens进行计算。
Page 6, 标准注意力机制示意图,新的Q Token与所有KV Tokens进行计算。

DSA 引入了一个 TopK 索引器(TopK Indexer),它会从大量的 KV Token 中筛选出最相关的 TopK 个。

Page 8, DSA示意图,TopK索引器从KV Tokens中选择一部分。
Page 8, DSA示意图,TopK索引器从KV Tokens中选择一部分。

通过这种方式,注意力核(Attention Kernel)只需在选定的 TopK KV Token 子集上进行计算,从而显著减少计算量和延迟。与传统的 MLA(Multi-Layer Attention)相比,DSA 的延迟大幅降低。本次演讲的重点在于对此过程中的注意力核进行优化。

Page 10, DSA机制完整流程图,展示了TopK索引器选择KV Tokens子集,以及与MLA相比的延迟优势,并强调了对注意力核的优化。
Page 10, DSA机制完整流程图,展示了TopK索引器选择KV Tokens子集,以及与MLA相比的延迟优势,并强调了对注意力核的优化。

DSA 核的挑战 (Challenges of DSA Kernel)

DSA 本质是带稀疏性的 MLA

DSA 可以被理解为带有稀疏性的 MLA (Multi-Query Attention)。下图对比了 MHA (Multi-Head Attention) 和 MQA (Multi-Query Attention) 的张量结构。

Page 12, MHA 和 MQA 的张量结构对比图。
Page 12, MHA 和 MQA 的张量结构对比图。

回顾:MLA 注意力核 (基于 Hopper 的 FMHA_V2 优化)

在讨论 DSA 的挑战之前,先回顾一下之前在 Hopper 架构上使用 FMHA_V2 对 MLA 注意力核的优化经验。

下图展示了该优化方案的流水线操作:
Page 13, MLA 注意力核在 Hopper 上的优化流水线示意图。

该优化带来了显著的性能提升,如下方的 TFLOPS 性能图所示,FMHA Opt 的性能远超其他方法。
Page 14, 在 H100 80G HBM3 上的 TFLOPS 性能对比图,批大小为128。

FP8 KV 缓存布局

Page 16
Page 16

上图展示了DeepSeek稀疏注意力(DSA)中FP8 KV(键值)缓存的内存布局。每个令牌(token)占用656字节。

DSA 带来的额外挑战

在 MLA 优化的基础上,DSA 引入了新的挑战:

Page 15, DSA 额外挑战的总结,并附有一个FP8数据块的示意图。
Page 17

Blackwell 平台概述

Page 18
Page 18

本节介绍NVIDIA Blackwell平台的概览。

Blackwell平台的优势与机遇

初尝:DSA 稀疏预填充(Sparse Prefill)

Page 22
Page 22

本节将初步探讨DSA在预填充阶段的实现。

DSA稀疏预填充概述

Page 23
Page 23

DSA稀疏预填充内存布局

Page 24
Page 24

DSA稀疏预填充流水线与基准测试

Page 25
Page 25

第二项任务:DSA 稀疏解码(Sparse Decoding)

Page 26
Page 26

本节介绍DSA在解码阶段的实现。

DSA稀疏解码的挑战:FP8布局与反量化

Page 27
Page 27

稀疏解码中的2-CTA反量化

Page 28
Page 28

该方案通过2-CTA集群和分布式共享内存(DSMEM)来解决反量化带来的高CUDA核心压力。
* 工作原理: 每个CTA负责一半的反量化工作,并通过DSMEM进行多播(Multi-cast)。
* 流程: 512字节的e4m3数据被分成两部分,分别加载到两个CTA的本地共享内存(SHM0的FP8-CTA0和SHM1的FP8-CTA1)中。经过反量化后,结果通过DSMEM共享到2-CTA集群的共享内存中,形成BF16-CTA0和BF16-CTA1。
* 优势:
* 两个CTA共享相同的TopK KV令牌。
* 将CUDA核心压力减半,降至 1664 CLK/CTA

DSA稀疏解码内存布局

Page 29
Page 29

DSA稀疏解码 - 进行中:流水线与当前进展

Page 30
Page 30

假设:FP8->BF16 是一条单一指令

下图展示了一种假设情况的流水线,即如果从 FP8 到 BF16 的转换可以由一条单一指令完成。在这种优化下,原本在 CUDA Core 上执行的多个反量化(Dequantization)、乘法融合(MUFU)和缩放(Scale)操作可以被整合,从而简化执行流程,提高效率。TMA(Tensor Memory Accelerator)负责加载数据,Tensor Core 执行核心的矩阵运算,而 CUDA Core 的负担减轻。

Page 31 展示了将 FP8 到 BF16 转换作为单一指令的潜在流水线优化
Page 31 展示了将 FP8 到 BF16 转换作为单一指令的潜在流水线优化

MLA 反向核优化 (MLA Backward Kernel Optimization)

Page 32 标题页:MLA Backward Kernel Implementation
Page 32 标题页:MLA Backward Kernel Implementation

背景:注意力的前向与反向传播

前向传播

标准注意力机制的前向传播过程如下图所示。它主要由两个通用矩阵乘法(GEMM)操作和一个 Softmax 操作组成:

  1. 查询(Q)与键的转置(Kᵀ)进行矩阵相乘,得到分数矩阵 P。
  2. 对 P 应用 Softmax 函数,得到注意力权重矩阵 S。
  3. S 与值(V)进行矩阵相乘,得到最终输出 O。
Page 33 注意力机制前向传播流程图
Page 33 注意力机制前向传播流程图

反向传播

反向传播过程计算输出 O 对输入 Q、K、V 的梯度(分别为 dO、dQ、dK、dV)。梯度流与前向传播的计算图方向相反。

Page 34 注意力机制反向传播的梯度流图
Page 34 注意力机制反向传播的梯度流图

反向传播的具体计算公式如下:
* $P = Q * K^T$
* $S = Softmax(P) = exp(P - lse)$
* $dV = S^T * dO$
* $dS = dO * V^T$
* $dP = S \circ (dS - sum(O \circ dO))$ (其中 $\circ$ 表示逐元素相乘)
* $dQ = dP * K$
* $dK = dP^T * Q$

Page 35 注意力机制反向传播的计算公式
Page 35 注意力机制反向传播的计算公式

从计算角度看,反向传播过程主要包含 5 个 GEMM 操作和 2 个由 CUDA Core 执行的操作。

Page 36 注意力机制反向传播的计算开销分析
Page 36 注意力机制反向传播的计算开销分析

注意力反向核函数流程

注意力反向传播的计算流程涉及以下几个关键步骤。

Page 37 注意力反向传播计算公式列表
Page 37 注意力反向传播计算公式列表

为了优化计算,一些中间值可以预先计算或在前向传播时计算并保存下来:
* lse (log-sum-exp) 在前向传播时计算。
* sum(O ◦ dO) 可以在反向传播主循环开始前预先计算。

Page 38 反向传播中的预计算与前向计算值
Page 38 反向传播中的预计算与前向计算值

在实现核函数时,循环的顺序是一个关键的设计选择。两种常见的策略是:
1. 外层 KV,内层 QO: 外层循环遍历 KV 的分块(tile),内层循环遍历 Q 的分块。在内层核函数中累加 dQ
2. 外层 QO,内层 KV: 外层循环遍历 Q 的分块,内层循环遍历 KV 的分块。在内层核函数中累加 dKdV

Page 39 注意力反向传播的两种循环策略
Page 39 注意力反向传播的两种循环策略

下图展示了注意力反向传播的数据流。首先计算 sumOdO,然后将其与 dO, Q, K, V 一同输入到主计算模块 Backward Attn 中,得到 dKdV,并累加生成最终的 dQ

Page 40 注意力反向传播的数据流图
Page 40 注意力反向传播的数据流图

Blackwell 架构上的反向注意力

本节将讨论在 Blackwell 架构上实现反向注意力的具体细节。

Page 41 标题页:Backward Attention in Blackwell
Page 41 标题页:Backward Attention in Blackwell

传统注意力反向核函数

传统的注意力反向核函数实现中,内存布局和流水线设计如下:

下图展示了在共享内存(Shared Memory)和张量内存(Tensor Memory)中的数据布局。

Page 42 传统注意力反向核函数的内存布局
Page 42 传统注意力反向核函数的内存布局

其计算流水线大致如下,TMA 负责加载数据,Tensor Core 和 CUDA Core 交替执行计算。注意,此图可能不完全代表真实的流水线。

Page 43 传统注意力反向核函数的计算流水线示意图
Page 43 传统注意力反向核函数的计算流水线示意图

MLA 注意力反向核函数

对于 MLA,实现上存在一些差异和挑战。

Page 44 MLA 反向核函数面临的挑战及内存布局
Page 44 MLA 反向核函数面临的挑战及内存布局

这一调整改变了共享内存和张量内存中与 Q 相关的张量(如 Q, dO, S, dQ)的分块大小,从而适应了更大的 head_dim,避免了内存问题。下图展示了调整后的内存布局。

Page 45 针对 MLA 调整 Q_Tile_STEP 后的解决方案及内存布局
Page 45 针对 MLA 调整 Q_Tile_STEP 后的解决方案及内存布局

针对 MLA 的优化

Attention 反向核 (MLA) 基准测试

Page 46
Page 46

针对反向传播中非均衡数据的优化

反向 Attention 中的非均衡数据

当数据分布不均衡时,会出现性能下降问题。例如,一个批次(Batch 0)包含大量数据,而其他批次(Batch 1-7)数据量很小。这种不均衡会导致性能下降至 300 TFLOPS。

ComputeSumOdO 计算中,对于非均衡数据 Data = [10000] + 99*[1],延迟高达 10.3ms,而对于均衡数据 Data = [100]*100,延迟仅为 0.382ms。

Page 48
Page 48

ComputeSumOdO 核布局

原始布局
目标布局(优化后)
Page 49
Page 49

优化后的 ComputeSumOdO 核布局

优化后的内核在循环的第一次迭代中使用二分搜索确定 bs_id,在后续迭代中如果 q_idx 超出了当前批次的长度,则查找新的批次。

Page 50
Page 50

性能

优化后,ComputeSumOdO 在处理非均衡数据 Data: [10000] + 99*[1] 时的延迟从基线的 10.3ms 显著降低到 0.26ms。相关的 PR 将于本月提交至 FlashMLA。

Page 51
Page 51

未来工作 (Future Works)

谢谢!