Jianxing Xu, Yuanbo Wen, Jun Bi, Ruibai Xu, Guanglin Xu, Rui Zhang, Wei Li, Ling Li, Tianshi Chen, Qi Guo, Yunji Chen
University of Science and Technology of China; Institute of Computing Technology, Chinese Academy of Sciences; Cambricon Technologies.
本文针对现有的融合注意力机制(Fused Attention,如 FlashAttention-2/3)中存在的向量间隔(Vector Interval)瓶颈问题,提出了一种向全张量化注意力机制演进的方案 FlashAttention-T。主要贡献归纳如下:
融合注意力的算法流程
融合注意力通过将计算并行化到 GPU 内核中的 Warpgroup(或 Threadblock)来执行,每个 Warpgroup 计算输出矩阵 $O$ 的一个行块。如 Fig 2(a) 所示,算法流程包含:❶ 计算注意力 Logit 块 $S$;❷ 更新行最大值 $m$;❸❺❻ 基于 $m$ 和旧的行最大值 $m_{old}$ 对 $S$ 进行指数化生成注意力分数块 $\tilde{P}$;❹ 重新缩放注意力输出 $O$;❼ 更新注意力分数行和 $l$;❽ 将结果累加到 $O$。这种分块方案确保中间结果驻留在寄存器中,为利用张量 MMA 指令进行张量化提供了基础。
Softmax 时间占比分析
在当前的范式中,Softmax 操作(max, fma, mul, add, exp)使用向量单元,而 GEMM 使用张量单元。对于序列长度 $N$ 和头维度 $h$,Softmax 的总操作数为 $4N^2 + \frac{N-1}{N}N^2h$,而 GEMM 为 $2N^2h$ FMAs。随着硬件发展(如 Hopper H100 相比 Ampere A100 吞吐量翻倍),张量单元与向量单元的吞吐量比率 $P_T/P_V$ 增加,导致 Softmax 的时间占比显著上升,成为计算瓶颈。
向量间隔瓶颈的揭示
通过分析 FlashAttention-2(Ampere 优化)和 FlashAttention-3(Hopper 优化)的执行时间线(见 Fig 3),作者揭示了向量间隔瓶颈:
* FlashAttention-2 (Fig 3a):采用顺序调度,GEMM 计算后必须等待向量单元完成 Softmax 才能进行下一次 GEMM,导致张量单元空闲时间 $t_{vec}$。
* FlashAttention-3 (Fig 3b):利用 Hopper 的异步 WGMMA 实现流水线调度,部分重叠了 GEMM 和 Softmax。然而,未重叠的 Softmax 计算仍导致显著的向量间隔。
* Table 1 的 Profiling 结果显示,向量间隔在 FlashAttention-2 中平均占迭代周期的 29.8%,在 FlashAttention-3(FP8-FP32)中升至 36.3%。
重新利用张量 MMA 指令
操作数赋值原理与映射
现代 GPU 的张量 MMA 指令将逻辑操作数矩阵划分为分布在线程间的片段(fragments)。定义 $v[i, j]$ 为逻辑矩阵 $v$ 在 $i$ 行 $j$ 列的元素,$v(u, t)$ 为线程 $u$ 持有的片段中的第 $t$ 个元素。映射函数 $\phi_v: v[i, j] \to v(u, t)$ 是特定于架构和操作数类型的(如图 Fig 4 所示)。为了重新利用 MMA 指令执行 Softmax 原语,需要确定输入片段的赋值,使得 MMA 输出片段 $D(v, t)$ 等于输入片段 $A$ 的置换后的逐元素缩放。
逐元素缩放与融合乘加(Scaling & FMA)
对于缩放因子 $\alpha$,目标是确定片段 $B$ 的赋值,使得 MMA 计算出的输出片段 $D$ 满足 $D(v, t) = \alpha A(\sigma(v), t)$,其中 $\sigma$ 是由张量单元内部片段映射引起的不可避免的置换(Permutation)。
* 置换开销最小化:为了恢复非置换输出(即 $D(v, t) = \alpha A(v, t)$),需要在线程寄存器内进行交换(swap)操作,这会产生开销。作者将此问题建模为约束优化问题:最小化置换 $\sigma$ 的 Cayley 距离(即将其转换为恒等置换所需的最小交换次数)。
* 最优赋值:Fig 5(a) 展示了最优的片段 $B$ 赋值方法,对于 $\sigma = (1\ 2)$,只需一次交换即可恢复。
* 融合乘加扩展:基于此,通过利用张量单元的累加器添加偏移量 $C(v, t)$,该方法可扩展为实现 $D(v, t) = \alpha A(\sigma(v), t) + C(v, t)$ 的融合乘加操作。
行和归约(Row-Sum Reduction)
行和归约涉及对注意力分数矩阵 $\tilde{P}$ 的行进行求和。传统的向量化方法需要线程内加法和显式的线程间 All-Reduce,带来同步和通信开销。作者提出利用张量 MMA 指令更高效地执行此操作:
* 输入构造:设置输入片段 $A$ 来自矩阵 $\tilde{P}$,将片段 $B$ 设置为如 Fig 5(b) 所示的特定模式,并将累加器 $C$ 置零。
* 输出计算:MMA 输出片段 $D'$ 本质上包含了部分和。通过执行两次线程内加法 $s(0, t) = D'(0, t) + D'(2, t)$ 和 $s(1, t) = D'(1, t) + D'(3, t)$,即可得到最终的归约结果。
* 优势:这种方法避免了显式的线程同步和通信,且置换 $\sigma$ 被吸收到线程内加法中,不产生额外的交换开销。
讨论
这些方法不需要额外的数据拷贝开销,因为它们直接在寄存器中操作。主要限制是算法约束:缩放和 FMA 要求 MMA 指令的所有行共享统一的缩放因子 $\alpha$。这需要后续的张量化算法来配合。
张量化在线 Softmax 算法
统一缩放因子的挑战与代理最大值
原始融合注意力算法中的输出重缩放步骤 $O = \exp(m_{old} - m) \cdot O$ 需要每行的缩放因子 $\exp(m_{old} - m)$,这与张量 MMA 缩放指令要求的统一缩放因子冲突。
* 代理最大值 $\hat{m}$:为了解决此冲突,作者提出了代理最大值(Surrogate Maximum) $\hat{m}[i]$,定义为包含第 $i$ 行的 $X$ 行图块(tile)内的最大值:$\hat{m}[i] = \max_{k=\lfloor i/X \rfloor \cdot X}^{\min((\lfloor i/X \rfloor + 1) \cdot X, n)} m[k]$。其中 $X$ 是图块大小(如 16 或 64)。
* 动态适应性:与 FlashDecoding++ 的静态最大值不同,$\hat{m}[i]$ 根据局部行分布动态调整。
数值稳定性保证
引入代理最大值必须避免两种失败情况:
1. Any-overflow:由于 $\hat{m}[i] \ge m[i]$,保证了 $\exp(S[i, j] - \hat{m}[i]) \le \exp(S[i, j] - m[i]) \le 1$,因此严格防止了溢出(Infinity-over-Infinity)。
2. All-underflow:即某行的所有指数项都下溢为零。尽管较大的 $\hat{m}[i]$ 增加了下溢概率,但对于典型分布(如 Gaussian),所有项同时下溢的联合概率渐近地小。如果遇到极端分布,实现可选择性回退到向量化计算。
算法流程
Algorithm 1 展示了张量化在线 Softmax 算法:
* Step 1: 计算 $X$-row 图块最大值 $\hat{m} \leftarrow \text{tilemax}(S, X)$。
* Step 2-3: 使用重新利用的张量缩放指令重缩放注意力输出 $O \leftarrow \exp(\hat{m}_{old} - \hat{m}) \cdot O$(利用统一的 $\hat{m}$)。
* Step 4-5: 分配代理最大值并使用重新利用的张量 FMA 指令重缩放注意力 Logits $Z \leftarrow \log_2(e) \cdot S - (\log_2(e) \cdot m)$。
* Step 6: 使用向量单元计算指数 $\tilde{P} \leftarrow \exp_2(Z)$(因为 $\exp$ 难以张量化)。
* Step 7: 使用重新利用的张量行和归约指令更新行和 $l$。
张量-向量并行调度(Scheduling)
全张量化的局限与混合并行策略
初步结果表明,仅进行全张量化带来的收益有限,因为当前硬件(如 A100)的 MMA 吞吐量尚未达到碾压向量单元的临界点(Effective throughput 均为 16 elem/cycle)。因此,作者提出了架构感知的调度技术,利用指令级并行(ILP)和线程级并行(TLP)来并行化张量和向量单元的 Softmax 计算。
Ampere GPU 的 ILP 调度
在 Ampere GPU 上,FlashAttention-T 将 Softmax 计算拆分为张量化和向量化部分,并在每个 Warp 内进行交错调度。
* 张量-向量拆分策略:基于 16 行图块方案,首先使用 Warp All-Reduce 实现 16 行最大值代理 $\hat{m}$。然后选择以下策略之一拆分 $S$ 和 $O$ 的重缩放操作:
* 水平拆分(Horizontal split):如 Fig 6(a) 所示,在 Warp 的矩阵图块内水平拆分,比例约为 1:1。
* 垂直拆分(Vertical split):如 Fig 6(b) 所示,跨分配给同一 Warp 的多个矩阵图块进行拆分。
Hopper GPU 的 TLP 调度
Hopper GPU 支持异步 WGMMA 指令,允许非阻塞 GEMM 执行和跨 Warpgroup 并行。
* 张量-向量拆分策略:由于 nvcc 编译器的限制(异步 WGMMA 要求累加器寄存器无依赖访问,否则会串行化),无法张量化 $S$ 或 $O$ 的重缩放(因为存在跨阶段寄存器依赖)。因此,在 Hopper 上仅张量化 $\tilde{P}$ 的行和归约(Row-sum summation)。这是一个叶子阶段操作,最大限度减少了依赖。
* 线程级并行(TLP):如 Fig 7(b) 所示,将重新利用的 WGMMA 行和归约指令调度到 Warpgroup 下一次迭代的 $QK^T$ 和 $PV$ WGMMA 批次中。这使得张量化计算能与其他 Warpgroup 的向量化 $S$ 和 $O$ 重缩放操作以 TLP 方式并行运行,显著减少向量间隔时间。
硬件配置:
软件配置与基线:
数据集与参数:
Attention 吞吐量性能
Fig 8 展示了不同配置下的吞吐量对比。FlashAttention-T 在绝大多数配置下均优于所有基线。
向量间隔比率评估
为了量化 Tensor 单元空闲时间的减少,计算了向量间隔比率(Vector Interval Ratio)。Fig 9 显示:
* Ampere GPU:FlashAttention-T 的 ILP 调度使得向量间隔比率比基线低 1.17–2.18 倍。
* Hopper GPU:TLP 调度效果更显著,将向量间隔比率降至 2.7%(相比基线大幅降低),这得益于更灵活的动态 Tensor-Vector 重叠。
消融实验
Fig 10 展示了各技术对性能的贡献(以 A100 为例):
* FA2+Max16:仅集成 16 行代理最大值 $\hat{m}$,获得 1%-3% 的提升(归功于 Warp All-Reduce 优于 Shuffle)。
* AllTensor:全张量化变体反而导致轻微性能下降,证实了单纯张量化因 Swap 开销和吞吐量瓶颈并非最优。
* FA-T (ILP):完整的 ILP 调度实现了最高 18.4% 的性能增益,证明了混合并行调度的有效性。
数值稳定性与精度
合成数据测试:Fig 11 显示,在不同离群值方差 $\sigma^2$ 下,FlashAttention-T 的 RMSE 均小于 $10^{-3}$,与 FA2 处于同一数量级,且未观测到数值失败(Numerical Failure)。在 $\sigma^2$ 较小时,误差主要来自 TF32 计算精度;在 $\sigma^2$ 较大时,误差主要由数值范围主导。
* 生成式基准测试:在 Llama3, Mistral, Qwen3 等模型上的测试表明(Table 2*),FlashAttention-T 的功能正确性指标(如 Pass@10, MMLU Score)与基线几乎完全一致,证实了其在真实应用中的可靠性。
本文提出了 FlashAttention-T,通过重新利用 GPU 张量 MMA 指令并配合张量化在线 Softmax 算法,成功将 Softmax 的关键计算转移至张量单元。结合架构感知的 ILP(Ampere)和 TLP(Hopper)调度策略,FlashAttention-T 有效解决了向量间隔瓶颈,在保证数值稳定性和精度的前提下,在多种 GPU 平台上显著提升了 Attention 的计算吞吐量。这项工作是向全张量化 Attention 执行迈出的重要一步。