FlashAttention-T: Towards Fully Tensorized Attention by Exploiting Tensor-Vector Parallelism

Jianxing Xu, Yuanbo Wen, Jun Bi, Ruibai Xu, Guanglin Xu, Rui Zhang, Wei Li, Ling Li, Tianshi Chen, Qi Guo, Yunji Chen
University of Science and Technology of China; Institute of Computing Technology, Chinese Academy of Sciences; Cambricon Technologies.

A1 主要贡献

本文针对现有的融合注意力机制(Fused Attention,如 FlashAttention-2/3)中存在的向量间隔(Vector Interval)瓶颈问题,提出了一种向全张量化注意力机制演进的方案 FlashAttention-T。主要贡献归纳如下:

  1. 核心问题识别与解决:现有的融合注意力实现将矩阵乘法(GEMM)和 Softmax 操作在计算上解耦,前者使用高性能张量单元(Tensor Cores),后者使用较慢的向量单元(CUDA Cores)。这种性能异构性导致了严重的向量间隔,即张量单元在等待向量单元完成 Softmax 时处于空闲状态(如图 1(a) 所示)。FlashAttention-T 通过将关键的 Softmax 原语卸载到空闲的张量单元上来解决这一瓶颈。
  2. 基于张量 MMA 的 Softmax 原语:提出了一系列操作数赋值方法,通过重新利用(Repurpose)GPU 的张量矩阵乘累加(MMA)指令来执行 Softmax 的关键原语(如逐元素缩放、融合乘加和行和归约),最大限度地利用硬件并减少开销。
  3. 张量化在线 Softmax 算法:设计了一种具有数值稳定性保证的张量化在线 Softmax 算法。该算法引入了代理最大值(Surrogate Maximum)的概念,以适应被重新利用的张量 MMA 指令对统一缩放因子的约束,从而使得在 FlashAttention 流程中可以张量化 Softmax 计算。
  4. 张量-向量并行调度:提出了架构感知的调度技术,通过利用指令级并行(ILP,针对 Ampere 架构)和线程级并行(TLP,针对 Hopper 架构),将张量化 Softmax 算法并行映射到张量和向量单元上,充分利用异构并行性。
图 1. (a) FlashAttention-3 和 (b) FlashAttention-T 在 Hopper GPU 上的执行时间线。FlashAttention-T 通过张量化 Softmax 计算并利用向量和张量化 Softmax 计算之间的线程级并行性来解决向量间隔瓶颈。
图 1. (a) FlashAttention-3 和 (b) FlashAttention-T 在 Hopper GPU 上的执行时间线。FlashAttention-T 通过张量化 Softmax 计算并利用向量和张量化 Softmax 计算之间的线程级并行性来解决向量间隔瓶颈。

A3 背景知识/关键Observation/设计原则

融合注意力的算法流程
融合注意力通过将计算并行化到 GPU 内核中的 Warpgroup(或 Threadblock)来执行,每个 Warpgroup 计算输出矩阵 $O$ 的一个行块。如 Fig 2(a) 所示,算法流程包含:❶ 计算注意力 Logit 块 $S$;❷ 更新行最大值 $m$;❸❺❻ 基于 $m$ 和旧的行最大值 $m_{old}$ 对 $S$ 进行指数化生成注意力分数块 $\tilde{P}$;❹ 重新缩放注意力输出 $O$;❼ 更新注意力分数行和 $l$;❽ 将结果累加到 $O$。这种分块方案确保中间结果驻留在寄存器中,为利用张量 MMA 指令进行张量化提供了基础。

图 2. (a) Warpgroup 内融合注意力迭代的算法流程。(b) 注意力计算空间中的算法流程迭代可视化。
图 2. (a) Warpgroup 内融合注意力迭代的算法流程。(b) 注意力计算空间中的算法流程迭代可视化。

Softmax 时间占比分析
在当前的范式中,Softmax 操作(max, fma, mul, add, exp)使用向量单元,而 GEMM 使用张量单元。对于序列长度 $N$ 和头维度 $h$,Softmax 的总操作数为 $4N^2 + \frac{N-1}{N}N^2h$,而 GEMM 为 $2N^2h$ FMAs。随着硬件发展(如 Hopper H100 相比 Ampere A100 吞吐量翻倍),张量单元与向量单元的吞吐量比率 $P_T/P_V$ 增加,导致 Softmax 的时间占比显著上升,成为计算瓶颈。

向量间隔瓶颈的揭示
通过分析 FlashAttention-2(Ampere 优化)和 FlashAttention-3(Hopper 优化)的执行时间线(见 Fig 3),作者揭示了向量间隔瓶颈:
* FlashAttention-2 (Fig 3a):采用顺序调度,GEMM 计算后必须等待向量单元完成 Softmax 才能进行下一次 GEMM,导致张量单元空闲时间 $t_{vec}$。
* FlashAttention-3 (Fig 3b):利用 Hopper 的异步 WGMMA 实现流水线调度,部分重叠了 GEMM 和 Softmax。然而,未重叠的 Softmax 计算仍导致显著的向量间隔。
* Table 1 的 Profiling 结果显示,向量间隔在 FlashAttention-2 中平均占迭代周期的 29.8%,在 FlashAttention-3(FP8-FP32)中升至 36.3%。

图 3. (a) Ampere GPU 上的 FlashAttention-2 和 (b) Hopper GPU 上的 FlashAttention-3 的执行时间线。由于完全依赖向量单元进行 Softmax 计算,两种实现都遭受向量间隔瓶颈的影响。
图 3. (a) Ampere GPU 上的 FlashAttention-2 和 (b) Hopper GPU 上的 FlashAttention-3 的执行时间线。由于完全依赖向量单元进行 Softmax 计算,两种实现都遭受向量间隔瓶颈的影响。

A2 方法细节

重新利用张量 MMA 指令

操作数赋值原理与映射
现代 GPU 的张量 MMA 指令将逻辑操作数矩阵划分为分布在线程间的片段(fragments)。定义 $v[i, j]$ 为逻辑矩阵 $v$ 在 $i$ 行 $j$ 列的元素,$v(u, t)$ 为线程 $u$ 持有的片段中的第 $t$ 个元素。映射函数 $\phi_v: v[i, j] \to v(u, t)$ 是特定于架构和操作数类型的(如图 Fig 4 所示)。为了重新利用 MMA 指令执行 Softmax 原语,需要确定输入片段的赋值,使得 MMA 输出片段 $D(v, t)$ 等于输入片段 $A$ 的置换后的逐元素缩放。

图 4. 现代 NVIDIA GPU 中张量 MMA 指令 HMMA.1688.F32.TF32 的逻辑矩阵元素到片段的映射。图中的颜色仅用于说明目的。
图 4. 现代 NVIDIA GPU 中张量 MMA 指令 HMMA.1688.F32.TF32 的逻辑矩阵元素到片段的映射。图中的颜色仅用于说明目的。

逐元素缩放与融合乘加(Scaling & FMA)
对于缩放因子 $\alpha$,目标是确定片段 $B$ 的赋值,使得 MMA 计算出的输出片段 $D$ 满足 $D(v, t) = \alpha A(\sigma(v), t)$,其中 $\sigma$ 是由张量单元内部片段映射引起的不可避免的置换(Permutation)。
* 置换开销最小化:为了恢复非置换输出(即 $D(v, t) = \alpha A(v, t)$),需要在线程寄存器内进行交换(swap)操作,这会产生开销。作者将此问题建模为约束优化问题:最小化置换 $\sigma$ 的 Cayley 距离(即将其转换为恒等置换所需的最小交换次数)。
* 最优赋值Fig 5(a) 展示了最优的片段 $B$ 赋值方法,对于 $\sigma = (1\ 2)$,只需一次交换即可恢复。
* 融合乘加扩展:基于此,通过利用张量单元的累加器添加偏移量 $C(v, t)$,该方法可扩展为实现 $D(v, t) = \alpha A(\sigma(v), t) + C(v, t)$ 的融合乘加操作。

图 5. 重新利用张量 MMA 指令 HMMA.1688.F32.TF32 以实现 (a) 因子为 $\alpha$ 的逐元素缩放或融合乘加,以及 (b) 行和归约的最佳片段 $B$ 值赋值方法。
图 5. 重新利用张量 MMA 指令 HMMA.1688.F32.TF32 以实现 (a) 因子为 $\alpha$ 的逐元素缩放或融合乘加,以及 (b) 行和归约的最佳片段 $B$ 值赋值方法。

行和归约(Row-Sum Reduction)
行和归约涉及对注意力分数矩阵 $\tilde{P}$ 的行进行求和。传统的向量化方法需要线程内加法和显式的线程间 All-Reduce,带来同步和通信开销。作者提出利用张量 MMA 指令更高效地执行此操作:
* 输入构造:设置输入片段 $A$ 来自矩阵 $\tilde{P}$,将片段 $B$ 设置为如 Fig 5(b) 所示的特定模式,并将累加器 $C$ 置零。
* 输出计算:MMA 输出片段 $D'$ 本质上包含了部分和。通过执行两次线程内加法 $s(0, t) = D'(0, t) + D'(2, t)$ 和 $s(1, t) = D'(1, t) + D'(3, t)$,即可得到最终的归约结果。
* 优势:这种方法避免了显式的线程同步和通信,且置换 $\sigma$ 被吸收到线程内加法中,不产生额外的交换开销。

讨论
这些方法不需要额外的数据拷贝开销,因为它们直接在寄存器中操作。主要限制是算法约束:缩放和 FMA 要求 MMA 指令的所有行共享统一的缩放因子 $\alpha$。这需要后续的张量化算法来配合。

张量化在线 Softmax 算法

统一缩放因子的挑战与代理最大值
原始融合注意力算法中的输出重缩放步骤 $O = \exp(m_{old} - m) \cdot O$ 需要每行的缩放因子 $\exp(m_{old} - m)$,这与张量 MMA 缩放指令要求的统一缩放因子冲突。
* 代理最大值 $\hat{m}$:为了解决此冲突,作者提出了代理最大值(Surrogate Maximum) $\hat{m}[i]$,定义为包含第 $i$ 行的 $X$ 行图块(tile)内的最大值:$\hat{m}[i] = \max_{k=\lfloor i/X \rfloor \cdot X}^{\min((\lfloor i/X \rfloor + 1) \cdot X, n)} m[k]$。其中 $X$ 是图块大小(如 16 或 64)。
* 动态适应性:与 FlashDecoding++ 的静态最大值不同,$\hat{m}[i]$ 根据局部行分布动态调整。

数值稳定性保证
引入代理最大值必须避免两种失败情况:
1. Any-overflow:由于 $\hat{m}[i] \ge m[i]$,保证了 $\exp(S[i, j] - \hat{m}[i]) \le \exp(S[i, j] - m[i]) \le 1$,因此严格防止了溢出(Infinity-over-Infinity)。
2. All-underflow:即某行的所有指数项都下溢为零。尽管较大的 $\hat{m}[i]$ 增加了下溢概率,但对于典型分布(如 Gaussian),所有项同时下溢的联合概率渐近地小。如果遇到极端分布,实现可选择性回退到向量化计算。

算法流程
Algorithm 1 展示了张量化在线 Softmax 算法:
* Step 1: 计算 $X$-row 图块最大值 $\hat{m} \leftarrow \text{tilemax}(S, X)$。
* Step 2-3: 使用重新利用的张量缩放指令重缩放注意力输出 $O \leftarrow \exp(\hat{m}_{old} - \hat{m}) \cdot O$(利用统一的 $\hat{m}$)。
* Step 4-5: 分配代理最大值并使用重新利用的张量 FMA 指令重缩放注意力 Logits $Z \leftarrow \log_2(e) \cdot S - (\log_2(e) \cdot m)$。
* Step 6: 使用向量单元计算指数 $\tilde{P} \leftarrow \exp_2(Z)$(因为 $\exp$ 难以张量化)。
* Step 7: 使用重新利用的张量行和归约指令更新行和 $l$。

Algorithm 1 Tensorized Online Softmax for Fused Attention
Algorithm 1 Tensorized Online Softmax for Fused Attention

张量-向量并行调度(Scheduling)

全张量化的局限与混合并行策略
初步结果表明,仅进行全张量化带来的收益有限,因为当前硬件(如 A100)的 MMA 吞吐量尚未达到碾压向量单元的临界点(Effective throughput 均为 16 elem/cycle)。因此,作者提出了架构感知的调度技术,利用指令级并行(ILP)和线程级并行(TLP)来并行化张量和向量单元的 Softmax 计算。

Ampere GPU 的 ILP 调度
在 Ampere GPU 上,FlashAttention-T 将 Softmax 计算拆分为张量化和向量化部分,并在每个 Warp 内进行交错调度。
* 张量-向量拆分策略:基于 16 行图块方案,首先使用 Warp All-Reduce 实现 16 行最大值代理 $\hat{m}$。然后选择以下策略之一拆分 $S$ 和 $O$ 的重缩放操作:
* 水平拆分(Horizontal split):如 Fig 6(a) 所示,在 Warp 的矩阵图块内水平拆分,比例约为 1:1。
* 垂直拆分(Vertical split):如 Fig 6(b) 所示,跨分配给同一 Warp 的多个矩阵图块进行拆分。

图 6. Ampere GPU 上 FlashAttention-T 的水平和垂直拆分策略,用于 ILP 调度。
图 6. Ampere GPU 上 FlashAttention-T 的水平和垂直拆分策略,用于 ILP 调度。

Hopper GPU 的 TLP 调度
Hopper GPU 支持异步 WGMMA 指令,允许非阻塞 GEMM 执行和跨 Warpgroup 并行。
* 张量-向量拆分策略:由于 nvcc 编译器的限制(异步 WGMMA 要求累加器寄存器无依赖访问,否则会串行化),无法张量化 $S$ 或 $O$ 的重缩放(因为存在跨阶段寄存器依赖)。因此,在 Hopper 上仅张量化 $\tilde{P}$ 的行和归约(Row-sum summation)。这是一个叶子阶段操作,最大限度减少了依赖。
* 线程级并行(TLP):如 Fig 7(b) 所示,将重新利用的 WGMMA 行和归约指令调度到 Warpgroup 下一次迭代的 $QK^T$ 和 $PV$ WGMMA 批次中。这使得张量化计算能与其他 Warpgroup 的向量化 $S$ 和 $O$ 重缩放操作以 TLP 方式并行运行,显著减少向量间隔时间。

图 7. FlashAttention-T 在 (a) Ampere 和 (b) Hopper GPU 上的执行时间线。
图 7. FlashAttention-T 在 (a) Ampere 和 (b) Hopper GPU 上的执行时间线。

A4 实验环境

A4 实验结果

Attention 吞吐量性能
Fig 8 展示了不同配置下的吞吐量对比。FlashAttention-T 在绝大多数配置下均优于所有基线。

图 8. FlashAttention-T 和基线在 Ampere 和 Hopper GPU 上不同配置下的 Attention 吞吐量。
图 8. FlashAttention-T 和基线在 Ampere 和 Hopper GPU 上不同配置下的 Attention 吞吐量。

向量间隔比率评估
为了量化 Tensor 单元空闲时间的减少,计算了向量间隔比率(Vector Interval Ratio)。Fig 9 显示:
* Ampere GPU:FlashAttention-T 的 ILP 调度使得向量间隔比率比基线低 1.17–2.18 倍
* Hopper GPU:TLP 调度效果更显著,将向量间隔比率降至 2.7%(相比基线大幅降低),这得益于更灵活的动态 Tensor-Vector 重叠。

图 9. 基线和 FlashAttention-T 之间的向量间隔比率比较。注意,Ampere GPU 上的结果是基于测量的时钟值估算的。
图 9. 基线和 FlashAttention-T 之间的向量间隔比率比较。注意,Ampere GPU 上的结果是基于测量的时钟值估算的。

消融实验
Fig 10 展示了各技术对性能的贡献(以 A100 为例):
* FA2+Max16:仅集成 16 行代理最大值 $\hat{m}$,获得 1%-3% 的提升(归功于 Warp All-Reduce 优于 Shuffle)。
* AllTensor:全张量化变体反而导致轻微性能下降,证实了单纯张量化因 Swap 开销和吞吐量瓶颈并非最优。
* FA-T (ILP):完整的 ILP 调度实现了最高 18.4% 的性能增益,证明了混合并行调度的有效性。

图 10. FlashAttention-T 及其消融变体的 Attention 吞吐量。结果归一化为 FlashAttention-2。
图 10. FlashAttention-T 及其消融变体的 Attention 吞吐量。结果归一化为 FlashAttention-2。

数值稳定性与精度
合成数据测试Fig 11 显示,在不同离群值方差 $\sigma^2$ 下,FlashAttention-T 的 RMSE 均小于 $10^{-3}$,与 FA2 处于同一数量级,且未观测到数值失败(Numerical Failure)。在 $\sigma^2$ 较小时,误差主要来自 TF32 计算精度;在 $\sigma^2$ 较大时,误差主要由数值范围主导。
*
生成式基准测试:在 Llama3, Mistral, Qwen3 等模型上的测试表明(Table 2*),FlashAttention-T 的功能正确性指标(如 Pass@10, MMLU Score)与基线几乎完全一致,证实了其在真实应用中的可靠性。

图 11. 与 A100 GPU 上的参考 FP64 Attention 实现相比,各种消融变体的 RMSE。
图 11. 与 A100 GPU 上的参考 FP64 Attention 实现相比,各种消融变体的 RMSE。

A5 结论

本文提出了 FlashAttention-T,通过重新利用 GPU 张量 MMA 指令并配合张量化在线 Softmax 算法,成功将 Softmax 的关键计算转移至张量单元。结合架构感知的 ILP(Ampere)和 TLP(Hopper)调度策略,FlashAttention-T 有效解决了向量间隔瓶颈,在保证数值稳定性和精度的前提下,在多种 GPU 平台上显著提升了 Attention 的计算吞吐量。这项工作是向全张量化 Attention 执行迈出的重要一步。