FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling

作者/机构: Ted Zadouri1,6, Markus Hoehnerbach2, Jay Shah*3, Timmy Liu4, Vijay Thakkar2,5, Tri Dao1,6 (1普林斯顿大学 2Meta 3Colfax Research 4NVIDIA 5佐治亚理工学院 6Together AI)

A1 主要贡献

本文旨在解决大型语言模型和长上下文应用中,作为Transformer架构核心层的Attention机制的性能瓶颈问题。随着硬件从Hopper架构(如H100)向Blackwell架构(如B200, GB200)演进,硬件性能呈现出非对称扩展的特点:张量核心(Tensor Core)的吞吐量翻倍,而其他功能单元(如共享内存带宽、指数运算单元)的扩展速度较慢或保持不变。这种不平衡的硬件扩展使得非矩阵乘法(non-MMA)资源成为新的性能瓶颈。

为了应对Blackwell GPU上这些变化的瓶颈,本文提出了FlashAttention-4,通过算法与核函数(kernel)实现的协同设计来解决问题。本文的核心贡献如下:

  1. 为实现最大化重叠而重新设计的流水线:为前向和后向传播过程开发了新的软件流水线。这些流水线利用了Blackwell架构中完全异步的矩阵乘法累加(MMA)操作和更大的瓦片(tile)尺寸,以最大化张量核心、softmax计算和内存操作之间的重叠。

  2. 缓解指数运算单元瓶颈:在前向传播中,通过在FMA单元上使用多项式逼近来实现软件模拟的指数函数,从而提高了指数运算的吞吐量。同时,引入了条件性softmax重缩放(conditional softmax rescaling)技术,以跳过不必要的重缩放操作。

  3. 减少共享内存(Shared Memory)流量:在后向传播中,利用新增的张量内存(Tensor Memory)来存储更多的中间结果,减少了对共享内存的访问。此外,还利用Blackwell的2-CTA MMA模式,让每个CTA(Cooperative Thread Array)分阶段加载一半的B操作数,进一步减少共享内存流量,并基于此重构了dQ步骤,将原子归约(atomic reductions)的数量减半。本文还实现了一个性能开销极小的确定性执行模式,以支持强化学习等应用的可复现训练。

  4. 改进的调度和资源分配:针对Blackwell的资源限制和更大的瓦片尺寸,开发了新的CTA调度策略和寄存器分配方案。

除了算法创新,FlashAttention-4完全在嵌入Python的CuTe-DSL中实现,与传统的基于C++模板的方法相比,编译时间加快了20-30倍,同时保持了完整的表达能力。这个框架显著提高了开发效率,降低了技术门槛。

实验证明,在B200 GPU上,FlashAttention-4(BF16精度)相比cuDNN 9.13实现了高达1.3倍的加速,相比Triton实现了高达2.7倍的加速,达到了1613 TFLOPs/s(71%的理论利用率),并在长序列场景下优于其他实现。

A3 背景知识

2.1 多头注意力(Multi-Head Attention)

设Q、K、V ∈ R^(N×d) 分别是单个头的查询、键和值输入序列,其中N是序列长度,d是头维度。注意力输出 O ∈ R^(N×d) 的计算方式如下:

S = αQKᵀ ∈ R^(N×N), P = softmax(S) ∈ R^(N×N), O = PV ∈ R^(N×d)
S = αQKᵀ ∈ R^(N×N), P = softmax(S) ∈ R^(N×N), O = PV ∈ R^(N×d)

其中,softmax是按行应用的,缩放因子 α = 1/√d。在实践中,为了数值稳定性,会从S中减去每行的最大值(rowmax(S))。对于多头注意力(MHA),每个头都有自己的投影矩阵,并且计算在多个头和批次(batch)之间是并行的。

给定输出梯度 dO ∈ R^(N×d),反向传播计算如下:

$$\begin{aligned} \begin{aligned} \mathrm{d}\mathbf{V} &= \mathbf{P}^{\top}\mathrm{d}\mathbf{O}, \quad \mathrm{d}\mathbf{P} = \mathrm{d}\mathbf{O}\mathbf{V}^{\top}, \\ \mathrm{d}\mathbf{S} &= \mathrm{dsoftmax}(\mathrm{d}\mathbf{P}), \\ \mathrm{d}\mathbf{Q} &= \alpha\mathrm{d}\mathbf{S}\mathbf{K}, \quad \mathrm{d}\mathbf{K} = \alpha\mathrm{d}\mathbf{S}^{\top}\mathbf{Q}, \end{aligned} \end{aligned}$$

其中,dsoftmax(dP) 表示逐行的softmax梯度,对于 p = softmax(s),其梯度为 ds = (diag(p) − pp⊤)dp。

2.2 GPU硬件特性和执行模型

本节描述了与FlashAttention-4相关的GPU执行模型,重点关注NVIDIA Blackwell架构(B200 & GB200),并强调了其与前代Hopper架构的关键差异,这些差异是FlashAttention-4优化的动因。

内存层级。GPU的内存组织成一个层次结构,容量与带宽成反比。全局内存(GMEM),也称为HBM,是所有流式多处理器(SM)都可以访问的片外DRAM。来自GMEM的数据会透明地缓存在片上的L2缓存中。接着,每个SM都包含一个由程序员管理、高度分区的片上小缓存,称为共享内存(SMEM)。最后是每个SM内的寄存器文件。Blackwell架构引入了一个新的内存级别,称为张量内存(TMEM),这是每个SM上一个256 KB的片上内存,专门用于存储张量核心操作的中间结果。与共享内存不同,TMEM是warp同步的,并与张量核心紧密耦合,使得矩阵乘法累加(MMA)单元能够直接将输出写入TMEM而无需消耗寄存器。这缓解了困扰Hopper核函数的极端寄存器压力,并允许使用更大的瓦片(tile)尺寸。TMEM以32列(16 KB)的粒度进行分配,并需要程序员进行显式的分配、释放和数据移动管理。

线程层级。GPU的编程模型围绕称为线程的逻辑执行单元分组进行组织。从最细到最粗的粒度,线程层级包括线程(threads)、线程束(warps,32个线程)、线程束组(warpgroups,4个连续的warps)、线程块(threadblocks,即合作线程阵列或CTAs)、线程块集群(threadblock clusters)和网格(grids)。同一CTA中的线程被共同调度到同一个SM上,同一集群中的CTA被共同调度到同一个GPC上。SMEM可由CTA内的所有线程直接寻址,而每个线程最多拥有256个私有寄存器(RMEM)。

张量核心与增强的异步性。Blackwell配备了第五代张量核心,其操作的瓦片尺寸远大于以往架构。每个MMA张量核心指令处理128 × N的瓦片(通常N=128或256),而Hopper上是64 × N。至关重要的是,Blackwell的MMA指令将其输出异步地直接写入TMEM,而Hoomer的MMA则写入寄存器。这种完全的异步性使得计算与其他操作之间能有更好的重叠,因为MMA单元不再因等待寄存器写回而阻塞。硬件对异步性的支持使得warp专业化核函数成为可能,即一个CTA的warps被划分为生产者或消费者角色,分别只负责发出数据移动或计算指令【索引1,CudaDMA: Optimizing GPU Memory Bandwidth via Warp Specialization+2011+SC '11+https://doi.org/10.1145/2063384.2063400】 。

2-CTA张量核心。Blackwell支持一种2-CTA张量核心MMA模式,其中同一线程块集群内的两个CTA(CTA对)协同执行单个MMA操作,允许该操作从两个CTA读取和写入张量内存。由CTA对中的一个线程发起MMA,但其对等CTA必须已启动并在操作进行期间保持活动状态。与单CTA MMA将M维度限制为128相比,配对模式通过在M维度上将A瓦片和累加器划分到CTA对中,并在N维度上将B瓦片划分到两个CTA中,支持M=128或256。这样,每个CTA只需在自己的共享内存中暂存一半的B,而硬件在乘法过程中会使用合并后的B瓦片。这减少了冗余的共享内存容量和带宽需求,但由于这些操作会跨CTA对访问张量内存,核函数必须以固定的CTA对形式启动,并在整个核函数中对张量内存和张量核心操作使用一致的2-CTA模式。

变化的瓶颈。Blackwell体现的一个关键趋势是张量核心的吞吐量扩展速度快于其他功能单元。与Hopper相比,Blackwell的FP16/BF16张量核心吞吐量翻了一番(每个GPU从1 PFLOPS【索引17,Nvidia H100 tensor core GPU architecture+2022+】增加到2.25 PFLOPS【索引19,Nvidia Blackwell architecture technical brief+2024+https://nvdam.widen.net/s/ xqt56dflgh/nvidia-blackwell-architecture-technical-brief】),但共享内存带宽和指数单元吞吐量保持不变或扩展得更慢。这种不平衡将性能瓶颈从矩阵乘法转移到了共享内存流量和像softmax这样的非矩阵乘法操作上。正如我们在3.1和3.2节的roofline分析所示,这要求精心的核函数设计,以最大化MMA操作与这些瓶颈资源之间的重叠。

B200(和GB200)上几个硬件组件的吞吐量如下
1. 张量核心:BF16 MMA的吞吐量为8192 ops/clock/SM,是Hopper(4096 ops/clock/SM)的两倍。这可以从理论最大FLOPS推导出来:2.25 PFLOPS / 1850 Mhz时钟速度 / 148 SMs = 8192 ops/clock/SM。
2. 指数单元:B200和GB200上的多功能单元(MUFU)可以执行16 ops/clock/SM,与Hopper相同【索引18,CUDA Programming Guide Version 12.4+2024+https://docs.nvidia.com/ cuda/cuda-c-programming-guide/index.html】。我们注意到,B300和GB300 GPU的指数吞吐量已翻倍至32 ops/clock/SM,但在撰写本文时这些GPU尚未广泛可用。
3. SMEM:根据微基准测试【索引15,Dissecting the nvidia hopper architecture through microbenchmarking and multiple level analysis+2025+】,读取吞吐量为128 bytes/clock/SM,与Hopper相同。

硬件趋势分析。我们可以看到,Blackwell上的MMA吞吐量相比Hopper翻了一番,但其他硬件单元的提速并非同步。这反映了加速器设计的一个更广泛趋势:在相似的功耗/硅片面积约束下,通过提高最重要组件(通常是矩阵乘法单元)的吞吐量来获得更高的性能。

A2 方法细节

3.1 注意力前向传播

我们首先进行roofline分析,以揭示注意力前向传播的瓶颈,这为我们新的流水线设计,以及为提高指数单元吞吐量和避免大部分softmax重缩放步骤而对FlashAttention算法进行的修改提供了动机。

3.1.1 性能分析(Feeds and Speeds)

我们首先通过分析基于矩阵乘法单元(张量核心)、共享内存(smem)和指数单元吞吐量的roofline来为我们的核函数设计和优化提供直觉。 我们注意到这是一个简化的分析,并未考虑GPU中的所有资源(例如,浮点数学、寄存器带宽、L2带宽)。尽管如此,它仍然可以识别出瓶颈。

设Q和K序列长度维度的瓦片形状为M × N,头维度为d。 我们分析计算和内存流量需求以确定性能瓶颈。

MMA计算。前向传播每次迭代执行两次矩阵乘法累加(MMA)操作:$QK^⊤$(从M × d和d × N的输入计算M × N的输出)和PV(从M × N和N × d的输入计算M × d的输出)。每次MMA需要2MNd个浮点操作。张量核心吞吐量为每周期8192 FLOPs,总计算时间为:

$$T_{\mathrm{MMA}} = \frac{4MNd}{8192} \text{ cycles.}$$

共享内存流量。在两次MMA中,一次是shared-shared (SS) 类型,其中两个操作数都从共享内存读取 ($QK^⊤$);另一次是tensor-shared (TS) 类型,其中操作数A从张量内存读取,操作数B从共享内存读取 (PV)。由于每个MMA指令操作的瓦片大小为128 × 128,计算一个M × N的输出需要⌈M/128⌉ × ⌈N/128⌉次MMA指令。关键在于,当需要多次MMA指令时,共享内存中的操作数会被多次读取。

对于$QK^⊤$ (SS),计算M × N的输出需要⌈M/128⌉ × ⌈N/128⌉次MMA指令,每次从共享内存中读取一个128 × d的Q块和一个d × 128的$K^⊤$块。总共享内存读取量为⌈M/128⌉ × ⌈N/128⌉ × (128d + 128d) = ⌈M/128⌉⌈N/128⌉ × 256d个元素。对于PV (TS),计算M × d的输出需要⌈M/128⌉ × ⌈d/128⌉次MMA指令,每次从共享内存中读取一个N × 128的V块,总共⌈M/128⌉ × ⌈d/128⌉ × 128N个元素。假设每个元素2字节(bf16),带宽为每周期128字节,共享内存(Tsmem)的读取时间为:

$$= 2\left\lceil\frac{M}{128}\right\rceil\left\lceil\frac{N}{128}\right\rceil 256 d+2\left\lceil\frac{M}{128}\right\rceil\left\lceil\frac{d}{128}\right\rceil 128 N = \frac{3 M N d}{8192} \text { cycles }$$

(假设M, N, d是128的倍数)。

指数单元。指数单元计算softmax所需的逐元素操作。前向传播需要在M × N个值上进行指数运算(对应于注意力矩阵S)。其吞吐量为每周期16次操作,因此指数单元需要的时间为:

$$T_{\exp }=\frac{M N}{16} \text { cycles. }$$

表1总结了两种典型瓦片配置的分析。对于M = N = d = 128,资源是相对平衡的,共享内存(768周期)略低于MMA计算和指数单元(均为1024周期)。对于更大的瓦片尺寸M = 256, N = d = 128,由于MMA操作数被多次读取,共享内存流量增加到1536周期,而MMA计算和指数单元则翻倍至2048周期。这一分析促使我们的核函数设计旨在:(1)使用大瓦片尺寸并最大化MMA操作与softmax计算之间的重叠;(2)通过使用其他硬件单元来提高指数运算的吞吐量;(3)减少不必要的非矩阵乘法操作的时间。

FlashAttention-4 前向流水线。上标H表示对应于“高”Q瓦片的矩阵,上标L表示对应于“低”Q瓦片的矩阵。每个Q瓦片对应128个查询令牌。
FlashAttention-4 前向流水线。上标H表示对应于“高”Q瓦片的矩阵,上标L表示对应于“低”Q瓦片的矩阵。每个Q瓦片对应128个查询令牌。
表1:注意力前向传播的roofline分析(周期)。对于两种瓦片尺寸,MMA计算和指数单元都是主要瓶颈。
表1:注意力前向传播的roofline分析(周期)。对于两种瓦片尺寸,MMA计算和指数单元都是主要瓶颈。

3.1.2 用于重叠矩阵乘法和softmax的新流水线

由于Blackwell架构再次将张量核心的浮点运算能力翻倍,因此精心设计以重叠softmax和张量核心操作比在Hopper上更为关键。 我们沿用类似于FA-3的乒乓调度(ping-pong schedule),即每个线程块计算两个输出瓦片。当一个瓦片的张量核心操作在执行时,另一个瓦片则进行softmax计算。Hopper张量核心将累加器保存在寄存器中,每个行有四个线程以交错模式工作;而Blackwell张量核心则将其累加器保存在张量内存中。此外,Blackwell上的单个累加器瓦片大小为128x128元素,而Hopper的瓦片大小为64x128。

在这种瓦片上分配工作的自然方式是使用两个各含128个线程的线程束组(warpgroups),每个线程处理一整行。 这样就无需进行warp间的shuffle操作来规约行最大值,也无需每个线程拥有多个统计寄存器。与FA-3一样,我们显式地同步两个softmax线程束组,使其在临界区(即指数计算部分)不发生重叠。每个softmax线程束组首先将整行加载到寄存器中,然后计算最大值,接着计算softmax(即减去最大值、重缩放、求指数、转换为输入精度),最后计算行和。

与FA-3的另一个区别是,由于我们通过张量内存而非寄存器文件传输P,我们可以将输出的重缩放操作解耦到一个单独的“校正”线程束组中,从而将其移出关键路径。

为了实现这种流水线重叠,有多种张量内存分区方案是可行的。 所有方案都必须为两个输出瓦片分配空间,这样(在头维度为128时)就剩下了一半的张量内存用于存储S和P。这部分内存可以存储两份S或四份P(假设输入是FP16或BF16的张量核心)。这给我们留下了大致两种剩余张量内存的分区选项:一个S瓦片和两个P瓦片,或者两个与P重叠的S瓦片。我们选择后者,因为它允许我们通过立即计算两个S瓦片来启动我们的软件流水线。它还留下了一些张量内存,用于向校正线程束组传达重缩放的统计信息。

Blackwell更大的瓦片尺寸和所选的线程分配方式带来一个问题: 除非我们从张量内存重新加载,否则必须在寄存器中保存一整行128个元素。考虑到我们使用了两个softmax线程束组、一个校正线程束组和一个用于驱动张量核心和TMA单元的线程束组,为softmax分配足够的寄存器并防止寄存器溢出至关重要。对于BF16输入数据类型,我们需要128个寄存器来保存输入,可能还需要64个寄存器来保存输出(外加杂项和临时寄存器)。为了减轻寄存器压力,我们分阶段存储P:前四分之三一次性存储(并触发相应的MMA操作),最后四分之一则分开存储。

3.1.3 指数函数的模拟

指数吞吐量瓶颈。在现代GPU上,指数函数由多功能单元(MUFU)计算,其吞吐量远低于用于矩阵乘法的张量核心。在B200和GB200 GPU上,MUFU每时钟周期每个SM提供16次操作,而矩阵乘法为8192次。由于softmax计算需要大量指数运算,这种差异使指数函数成为注意力核函数中的一个关键瓶颈。

通过多项式逼近进行软件模拟。为了提高指数吞吐量,我们使用浮点FMA单元实现了$2^x$的软件模拟,这些FMA单元可以与MUFU并行工作。我们使用了经典的范围缩减技术(Cody-Waite),然后进行多项式逼近【索引16,Handbook of Floating-Point Arithmetic+2018+Birkhäuser Boston】。关键思想是将指数计算分解为:

$$2^{x}=2^{\lfloor x\rfloor} 2^{x-\lfloor x\rfloor}$$

其中 ⌊x⌋ 是整数部分,而 x − ⌊x⌋ ∈ [0, 1) 是小数部分。

整数部分$2^⌊x⌋$可以通过对IEEE 754浮点表示进行位操作来高效计算。 由于指数域直接表示2的幂,计算$2^⌊x⌋$相当于对指数位进行移位和加法操作,这可以使用整数ALU指令完成。

对于小数部分,我们使用多项式逼近$2^{x_{frac}}$,其中$x_{frac} ∈ [0, 1)$:

$$2^{x_{\text{frac}}} \approx \sum_{i=0}^{n} p_i x_{\text{frac}}^i$$

其中$p_0 = 1.0$,其余系数通过Sollya软件包【索引4,Sollya: An environment for the development of numerical codes+2010+ICMS 2010】计算得出,以最小化在[0, 1)上的相对逼近误差。多项式求值使用霍纳(Horner)法则和FMA指令,以实现高吞吐量。

完整的算法流程如下
1. 将x裁剪至不小于-127,以避免下溢。
2. 使用向下取整模式计算⌊x⌋:将$2^{23} + 2^{22}$加到x上(强制将小数位移入尾数),然后在使用向下取整模式减去它。
3. 计算小数部分:$x_{frac} = x − ⌊x⌋$。
4. 求值多项式得到$2^{x_{frac}}$。
5. 合并整数和小数部分:将⌊x⌋移入指数域,并加上$2^{x_{frac}}$的尾数位。

通过将指数计算分布到MUFU和FMA单元上,这种方法有效地提高了指数吞-吐量,缓解了注意力计算中的一个关键瓶颈。

部分模拟。尽管多项式模拟提高了指数吞吐量,但它也带来了成本:需要额外的寄存器(用于保存中间值和系数)、更高的寄存器带宽消耗以及比MUFU指令更长的延迟。对所有指数求值都使用模拟会增加寄存器压力,并可能导致寄存器溢出,从而抵消吞吐量优势。因此,我们只对每行softmax中的一部分元素(10-25%)应用模拟,其余元素通过硬件MUFU.EX2指令计算。具体比例根据给定瓦片配置的MMA和指数吞吐量比率进行经验性调整。

数值精度。表2比较了不同阶数多项式逼近与硬件MUFU.EX2指令的精度,测试基于400万个在[0, 1)范围内的随机输入。我们报告了两个指标:FP32级别的误差(任何量化之前)和BF16级别的误差(将FP32输出四舍五入到BF16后),两者都与FP64参考值进行比较。在FP32级别,3阶多项式的最大相对误差为8.8 × 10⁻⁵,比硬件高约600倍。然而,在四舍五入到BF16后,误差几乎无法区分:对于所有阶数≥3的多项式,BF16的量化误差(~3.9 × 10⁻³)主导了多项式逼近误差。在99%的输入上,3阶多项式与硬件的差异在1个BF16 ULP(Unit in the Last Place)之内,这对于softmax输出以BF16精度被消费的注意力计算来说是足够的。更高阶的多项式缩小了FP32的差距:5阶多项式在最大相对误差上与硬件的差距在2倍以内,但每次求值需要额外两条FMA指令。

表2:在[0, 1)上对400万随机输入的2x多项式模拟的精度,与FP64参考值比较。FP32列测量原始多项式输出;BF16列测量四舍五入到BF16后的结果。对于所有阶数≥3的多项式,BF16的量化误差占主导。
表2:在[0, 1)上对400万随机输入的2x多项式模拟的精度,与FP64参考值比较。FP32列测量原始多项式输出;BF16列测量四舍五入到BF16后的结果。对于所有阶数≥3的多项式,BF16的量化误差占主导。

3.1.4 跳过在线softmax重缩放

FlashAttention在线softmax。FlashAttention以块(block)为单位计算注意力$softmax(QK^⊤)V$,以最小化内存流量。为了保证数值稳定性,算法在处理块时维护着运行统计量。在计算块j时,令$S_j = QK_j^⊤$为该块的注意力分数。在线softmax算法跟踪以下统计量:

$$\begin{aligned} \begin{aligned} m_j &= \max(m_{j-1}, \text{rowmax}(S_j)) \\ \ell_j &= e^{m_{j-1}-m_j} \ell_{j-1} + \text{rowsum}(e^{S_j-m_j}) \end{aligned} \end{aligned}$$

其中$m_j$是运行最大值,ℓj是指数的运行和(归一化因子)。中间输出$O_j$的更新方式为:$O_j = e^{m_{j-1}−m_j} O_{j-1} + e^{S_j−m_j} V_j$。重缩放因子$e^{m_{j-1}−m_j}$通过在遇到更大值时对先前结果进行重新归一化来确保数值稳定性。

条件性重缩放。$e^{m_{j-1}−m_j} O_{j-1}$这一步需要进行一次向量乘法。我们提出了两个简单的观察:
1. 只有当$m_j > m_{j-1}$时,即当发现新的更大值时,重缩放才是必要的。
2. 我们可以容忍重缩放中的一些“松弛”:仅当$m_j - m_{j-1} > \tau$时才进行重缩放,其中$\tau$是一个阈值(通常设为$log_2(256) = 8.0$,对应于256.0的重缩放因子)。只要我们持续跟踪统计数据(我们已经完成的总缩放量),我们最终仍然可以得到真正的分母,从而得到正确的最终输出。

在FlashAttention-4中,我们修改算法如下:

$$\begin{aligned} O_j = \begin{cases} e^{m_{j-1}-m_j} O_{j-1} + e^{S_j-m_j} V_j & \text{if } m_j - m_{j-1} > \tau \\ O_{j-1} + e^{S_j-m_{j-1}} V_j & \text{otherwise} \end{cases} \end{aligned}$$

当$m_j - m_{j-1} ≤ \tau$时,我们跳过更新m,并继续使用$m_{j-1}$。这保持了计算的正确性,因为在计算结束时,所有累积的值都会通过真正的最大值$m_{final}$和最终的归一化因子$ℓ_{final}$进行重新归一化:

$$\text{Output} = \frac{1}{\ell_{\text{final}}} O_{\text{final}}$$

这种修改显著减少了重缩放操作的次数,同时保持了数值精度,因为最终的归一化步骤会纠正因跳过中间重缩放而引入的任何微小偏差。

在实践中,为了避免warp发散,当warp中的任何一个线程需要重缩放时,我们就会对整个warp进行重缩放。

3.2 注意力反向传播

3.2.1 性能分析(Feeds and Speeds)

与前向传播类似,我们首先通过分析基于矩阵乘法单元(张量核心)、共享内存(smem)和指数单元吞吐量的roofline来为我们的核函数设计和优化提供直觉。

设Q和K序列长度维度的瓦片形状为M × N,头维度为d。 我们分析计算和内存流量需求以确定性能瓶颈。与前向传播不同,我们假设M = N = d = 128以简化smem周期计数的公式,但为了清晰起见保留了变量名。

MMA计算。反向传播每次迭代执行五次矩阵乘法累加(MMA)操作。每次MMA涉及一个M × N矩阵、一个M × d矩阵和一个d × N矩阵(输出矩阵各不相同),需要2MNd个浮点操作。张量核心吞吐量为每周期8192 FLOPs,总计算时间为:

FlashAttention-4反向传播计算图(5次MMA操作+2次逐元素操作),展示了在prologue、主循环和tail阶段中1-CTA MMA模式的软件流水线顺序。
FlashAttention-4反向传播计算图(5次MMA操作+2次逐元素操作),展示了在prologue、主循环和tail阶段中1-CTA MMA模式的软件流水线顺序。

$$T_{\mathrm{MMA}}=\frac{10MNd}{8192}\text{ cycles.}$$

共享内存流量。在五次MMA中,有三次——$S^⊤ = KQ^⊤$, $dP^⊤ = VdO^⊤$, 和 $dQ = dSK$——是shared-shared (SS)操作,其中两个操作数都从共享内存读取;另外两次——$dV = P^⊤dO$ 和 $dK = dS^⊤Q$——是tensor-shared (TS)操作,其中操作数A从张量内存读取,操作数B从共享内存读取。SS MMA总共从共享内存读取2Md + 3Nd + MN个元素,而TS MMA总共从共享内存读取2Md个元素。在共享内存带宽为每周期128字节且每个元素为2字节(bf16)的情况下,这部分贡献的时间为:

$$T_{\text{smem,MMA}} = \frac{4Md + 3Nd + MN}{64} \text{ cycles.}$$

此外,算法将大小为M × N的中间梯度dS以bf16格式写入共享内存,需要2MN字节或MN/64个周期。大小为M × d的梯度dQ以fp32格式(每个元素4字节)写入共享内存,然后通过TMA读回进行归约,总共需要8Md字节的共享内存流量或Md/16个周期。

因此,总的共享内存访问时间(Tsmem)为:

$$\frac{4Md+3Nd+MN}{64}+\frac{MN}{64}+\frac{Md}{16}\text{ cycles.}$$

指数单元。指数单元计算softmax及其梯度所需的逐元素操作(指数、对数及相关的非线性函数)。反向传播需要在M × N个值上进行指数运算(对应于注意力矩阵S及相关项)。其吞吐量为每周期16次操作,因此指数单元需要的时间为:

$$T_{\exp} = \frac{MN}{16} \text{ cycles.}$$

表3总结了典型瓦片配置M = N = d = 128的分析。共享内存流量时间为3328周期,超过了MMA计算时间(2560周期)和指数单元时间(1024周期),表明共享内存带宽是主要瓶颈,尽管其严重程度不如全局内存流量。这促使我们的核函数设计旨在最大化MMA操作与其他计算之间的重叠,以隐藏共享内存延迟。

3.2.2 用于重叠矩阵乘法和softmax的新流水线

Flash Attention的反向传播执行五次MMA操作,分别对应于S的重计算,以及由QK(产生dQ和dK)和PV(产生dP和dV)引起的两次梯度计算。 在FA-3中,累加器存储在寄存器中,而寄存器是有限资源。这施加了显著的顺序约束,实际上使计算图串行化,即按S, dP, dV, dQ, dK的顺序计算,只有TMA加载操作能显著地乱序执行。除此之外,算法是相似的:它沿着KV序列长度维度迭代,并计算相对于前向传播转置的值,因为这是dV和dK梯度计算需要从张量内存读取其操作数之一的布局。dQ通过原子操作累积。

在FA-4中,TMEM相比FA-3支持了更多的调度方案,这些方案在MMA和非MMA操作之间提供了显著的重叠。 具体来说,就像前向传播一样,我们试图隐藏softmax计算的延迟。在FA-3中,softmax计算与dP的MMA重叠。从上一节我们知道,在Blackwell上,我们至少需要两个MMA操作并行运行。

我们通过使用前一次迭代的dQ和dK MMA来实现这一点。 这需要在加载、MMA、计算和归约操作之间仔细管理共享内存和张量内存资源。特别要注意的是,我们没有足够的张量内存来容纳五个累加器瓦片。最多可以容纳四个128x128元素的瓦片,并且dV和dK是累积的,因此不能共享它们的空间。在我们的实现中,我们让S和P共享一个tmem块(偏移量为0),让dP、dS和dQ共享另一个。我们在图2中展示了FA-4反向传播的计算图。

表3:注意力反向传播的roofline分析,M = N = d = 128。共享内存流量是瓶颈,比MMA计算时间多出约30%。在2-CTA设置下,M = 256, N = d = 128(dQ mma除外,其M = N = 128, d = 256),共享内存流量比MMA计算时间多出约5%。
表3:注意力反向传播的roofline分析,M = N = d = 128。共享内存流量是瓶颈,比MMA计算时间多出约30%。在2-CTA设置下,M = 256, N = d = 128(dQ mma除外,其M = N = 128, d = 256),共享内存流量比MMA计算时间多出约5%。

3.2.3 2-CTA反向传播:减少共享内存流量和全局原子加法

即使通过改进的流水线设计,并且有十分之二的GEMM操作数驻留在张量内存中,共享内存带宽仍然主导着反向传播过程。 在五个GEMM中,剩下的八个BF16操作数从共享内存加载以供给张量核心,而这部分共享内存流量所消耗的周期比张量核心计算多出约30%。为了进一步缓解这一瓶颈,我们使用了Blackwell引入的2-CTA MMA模式,其中输出累加器在M维度上进行分区。使用M=256和N=K=128的MMA瓦片形状,两个CTA作为一个更大的瓦片协同工作:每个CTA加载并暂存操作数B的一半,并只保留自己的累加器切片。

共享内存流量。在反向传播的五个GEMM中,我们使用M=256和N=K=128的MMA瓦片形状,这大致将操作数B的共享内存流量减半。在FlashAttention反向传播中,每个CTA持有一个固定的KV瓦片(在外层循环中跨N个CTA并行),并在内层循环中流式处理M个瓦片。dQ的累积是在外层循环中对KV序列的归约,但2-CTA MMA只分割输出瓦片,不分割归约轴,而dQ MMA的归约维度是N,这自然地在CTA对之间被分割。因此,每个CTA仍然需要对其拥有的行进行完整的归约。为了解决归约轴上的这个冲突,我们使用分布式共享内存(DSMEM)在两个CTA之间交换一半的dS,因为它们在同一个集群中。这种方法重新打包dS,使其沿着非归约轴进行分区,每个CTA拥有其M/2行并持有完整的2N归约。结果,每个CTA的dQ MMA瓦片形状为(M/2, 2N)(2N, d),并在张量内存中累积一个(M/2, d)的瓦片。在2-CTA MMA模式下,S, dP, dV和dK的MMA以M=256的瓦片运行,而dQ使用M=128但双倍归约2N=256。然后我们相对于1-CTA变体重新排序了软件流水线,以隐藏DSMEM的延迟。我们在计算当前瓦片的dP之前,先计算前一迭代瓦片的dQ。dQ瓦片足够小,可以与P一起放入TMEM,重用与S相同的TMEM区域,因此我们不再像1-CTA模式中那样为dP和dQ重用同一个TMEM区域。通过这种新的流水线顺序,我们可以将当前瓦片的逐元素dS计算与前一迭代瓦片的dQ MMA并行进行。图3展示了dQ步骤的分解方式。

在2-CTA反向传播的dQ步骤中,CTA对使用DSMEM交换一半的dS瓦片,因此每个CTA形成一个(M/2 × 2N)的操作数,并可以运行一个具有双倍归约的CTA对UMMA。
在2-CTA反向传播的dQ步骤中,CTA对使用DSMEM交换一半的dS瓦片,因此每个CTA形成一个(M/2 × 2N)的操作数,并可以运行一个具有双倍归约的CTA对UMMA。

dQ原子加法。这种dQ分解的一个补充好处是它将全局原子归约的数量减半。原子更新会引入不确定性,并且由于它们在内层循环的每次迭代中都会发生,因此开销很大。因此,每个CTA只写入dQ瓦片的一半,并执行比1-CTA对应方案少一半的全局原子归约。

3.2.4 确定性反向传播

我们的反向传播核函数由于在全局内存中进行CTA间的归约(通常影响dQ,在GQA情况下影响dK/dV),为梯度计算引入了不确定性。 为了确保可复现性并便于在训练期间进行可靠的调试,我们还提供了一种确定性执行模式。我们采用的标准解决方案是使用信号量锁(semaphore lock)来串行化全局归约。具体来说,每个写入同一个dQ瓦片的CTA必须按照预定义的顺序获取锁,执行其归约,然后通过增加信号量计数器来释放锁。

这种基于锁的方法会影响性能,主要有两个原因:(1)发出内存栅栏(memory fence)以确保信号量写入在设备范围内的可见性(这是正确实现获取-释放语义所必需的),以及(2)当每个CTA等待先前在同一个dQ瓦片上进行归约的CTA完成时引入的停顿。在负载不均衡的情况下,一个简单的CTA顺序选择可能会严重降低性能。通常,我们在头和批次维度上进行CTA混洗(swizzling)以减少停顿(直到L2缓存容量,参见3.3节)。对于因果掩码,我们另外按降序启动KV块,从对角线开始按升序遍历查询块,并按查询块索引的降序来排序dQ归约。这种“最短处理时间优先”(SPT)调度确保没有CTA在其第一次dQ写入时被停顿。

3.3 调度

在许多情况下,例如使用因果掩码或可变序列长度(varlen)时,注意力核函数天然存在负载不均衡——SMs被分配的工作瓦片的主循环长度不同,因为一些工作瓦片需要更多的加载和MMA操作。 此外,我们可以选择SMs处理瓦片的顺序,例如通过定义网格坐标的首选线性化方式。撇开注意力的任何特定特征,我们可以将关于相同并行处理器上最小化制造周期(makespan)的一般性结论应用到我们的情境中。特别地,在FlashAttention-4中,我们使用了经典的“最长处理时间优先”(LPT)调度思想【索引9,Bounds on multiprocessing timing anomalies+1969+SIAM Journal on Applied Mathematics】。我们强调,我们应用此思想的方式适用于所有GPU架构,并且在Hopper GPU上也被验证为对FlashAttention-3的改进。

因果掩码的LPT。标准的注意力网格由(mblocks, heads, batches)给出,并按从左到右的递增顺序计算。但对角线以上的分数被掩码掉了,因此对于固定的头和批次,SMs最终会以从最短到最长的低效方式处理工作瓦片。另一方面,一个简单的LPT顺序也是次优的,因为对于不同的批次,主循环的KV加载不会在L2缓存中命中,而如果所有KV头的大小超过L2缓存容量,首先加载它们可能会导致L2缓存抖动。因此,我们总是将批次作为最外层维度来处理,并对头进行混洗。这意味着我们将头分成不会溢出L2缓存的部分;然后瓦片调度器按每个部分的头、mblocks的逆序、各个部分、最后是批次的顺序遍历网格。特别地,对于MQA或GQA,我们总是在改变mblocks之前遍历每个KV头的所有查询头。经验上,我们验证了这种LPT顺序非常有效;例如,在H200 GPU上测得,对于BF16和头维度128,MHA的FLOPS增益为4-8%,MQA 8的增益为7-14%。

可变序列长度的LPT。对于varlen,我们还必须应对由于批次间差异导致的负载不均衡。例如,在解码工作负载中,不同的批次可能关注不同数量的上下文,而在混合或连续批处理中,一些批次可能是预填充(prefill),而另一些是解码。每个批次的查询和KV序列长度列表通常作为注意力元数据存储在设备上,标准的varlen注意力核函数在运行时按递增顺序处理批次时读取这些整数。然而,给定的批次顺序在负载均衡方面可能是任意次优的——例如,我们可能有较短的方形预填充后面跟着长上下文解码。为了改善这一点,我们可以强制执行LPT顺序,通过启动一个预处理核函数,根据每个工作瓦片的最大执行时间对批次进行排序,并写出额外的元数据,即一个虚拟到实际批次索引的映射,随后该映射将被注意力核函数读回,以便按排序后的顺序遍历批次。这个元数据可以被缓存,因此排序不会带来性能损失。

4 语言和框架

我们将FlashAttention-4完全用嵌入在Python中的CuTe-DSL [21]编写,没有任何CUDA C++组件。 CuTe-DSL编译器将Python源代码转换为PTX,然后使用PTX编译器(ptxas)最终生成汇编代码(SASS)。

通过清晰的抽象实现完全的表达能力。 CuTe-DSL编程模型与CUTLASS C++是同构的,确保了FlashAttention-4在保留低级GPU编程的全部表达能力的同时,也受益于在Python中进行元编程而非C++所带来的生产力提升和快速的JIT编译时间。CuTe-DSL提供了直接访问PTX作为“逃生舱口”,允许开发者实现他们需要的任何功能,而不受框架限制。例如,我们利用自定义的PTX序列来实现一些尚未在CuTe-DSL API中完全暴露的操作(尽管这些将在未来版本中集成),这表明我们的框架不会将开发者限制在一个有限的GPU能力子集内。

通过JIT实现快速编译。 由于复杂的C++模板元编程,编译时间一直是过去FlashAttention实现中的一个瓶颈。通过将CuTe-DSL嵌入Python并采用即时(JIT)编译,FlashAttention-4相比传统的基于C++模板的方法实现了更快的构建时间。如表4所示,与FlashAttention-3相比,FlashAttention-4的编译时间减少了20-30倍。这种快速的迭代周期显著提高了开发者的生产力,使得在核函数开发过程中能够更快地进行实验和调试。

表4:单个核函数的编译时间:FA3(C++模板)和FA4(CuTe-DSL)。通常FA2和FA3需要为不同的注意力变体预编译数百个核函数。
表4:单个核函数的编译时间:FA3(C++模板)和FA4(CuTe-DSL)。通常FA2和FA3需要为不同的注意力变体预编译数百个核函数。

灵活性和可及性。 基于Python的框架在实践中已经展示了其灵活性:开发者已经成功地在FlashAttention-4之上构建了FlexAttention和块稀疏注意力变体,而无需修改核心框架。通过降低入门门槛,我们的方法使得仅有几个月GPU编程经验的研究人员和工程师也能够贡献有意义的扩展,而无需深入掌握C++模板元编程的专业知识。这种可及性加速了创新,并使注意力机制研究社区能够更快地探索新的算法变体。

我们的愿景是提供一个全面的框架,用于构建各种具有顶级性能的注意力变体。 FlashAttention-4不是从零开始实现每一种注意力变体,而是将通用功能分解为独立的、可组合的原语。块稀疏模式、掩码策略、可变序列长度处理和工作调度等操作都被暴露为可以自由组合的正交原语。这种模块化设计确保了优化和新功能能够惠及所有基于该框架构建的注意力实现,同时通过编译成高效的GPU核函数来达到最高性能。

A4 实验环境

  • 硬件配置:

    • GPU: B100 180GB SXM6 (1000W),基准测试在B200 GPU上进行。
  • 软件配置:

    • CUDA 13.1
    • FlashAttention 2.8.3
    • Triton 3.6
    • PyTorch 2.10.0
    • CuTe-DSL 4.4.1
    • cuDNN: 比较了cuDNN 9.13和cuDNN 9.19.1.2。
  • 数据集/工作负载:

    • 基准测试中,序列长度从1k到32k不等,批大小(batch size)设定为使总token数为32k。
    • 模型架构参数:隐藏层维度设为2048,头维度为64或128。还测试了DeepSeek V3架构中使用的(192, 128)配置(16个头,查询维度192,键/值维度128)。
    • 输入数据类型为BF16。
    • 测试了有/无因果掩码(causal mask)的场景。
    • 基准测试流程:进行5次预热运行,然后重复10次基准测试并取平均时间。
    • FLOPS计算方式:前向传播为 4 * seqlen² * head_dimension * number_of_heads(因果掩码时除以2);反向传播为前向的2.5倍。

A5 实验结果

我们评估了FlashAttention-4相对于各种开源和闭源基线的效率。

基准测试总结:我们测量了FlashAttention-4在不同序列长度和头维度下的运行时间,并与PyTorch、FlashAttention-2、Triton、Gluon和cuDNN的实现进行了比较。结果证实,FlashAttention-4比cuDNN 9.13快1.3倍,比Triton快2.7倍。FlashAttention-4的性能高达1613 TFLOPs/s,约占B200 GPU理论最大TFLOPs/s的71%。

5.1 前向传播

前向传播结果如图4和图5所示。FlashAttention-4比cuDNN 9.13快1.1-1.3倍,比Triton快2.1-2.7倍。对于中等到较长的序列(4k及以上),FlashAttention-4在不同的头维度和因果掩码设置下都持续优于所有基线。在因果掩码的情况下,性能增益更大,我们认为这归功于“最长处理时间优先”(LPT)调度器。

图4:B200上(FP16/BF16)头维度为128的前向传播TFLOPS。左:非因果注意力。右:因果注意力。FA4在不同序列长度上比cuDNN 9.13.0快1.1-1.3倍,比Triton快2.1-2.7倍。自我们实现发布以来,新版cuDNN已集成本文所述的许多技术,性能与FA4相似。
图4:B200上(FP16/BF16)头维度为128的前向传播TFLOPS。左:非因果注意力。右:因果注意力。FA4在不同序列长度上比cuDNN 9.13.0快1.1-1.3倍,比Triton快2.1-2.7倍。自我们实现发布以来,新版cuDNN已集成本文所述的许多技术,性能与FA4相似。
图5:在B200上(FP16/BF16),头维度为(192, 128)的因果注意力(常用于DeepSeek V3架构)下,cuDNN与FA4的前向传播TFLOPS比较。
图5:在B200上(FP16/BF16),头维度为(192, 128)的因果注意力(常用于DeepSeek V3架构)下,cuDNN与FA4的前向传播TFLOPS比较。

5.2 反向传播

反向传播结果如图6所示。FlashAttention-4在长序列和因果掩码设置下均实现了稳定的加速,证明了我们2-CTA反向传播的有效性。

我们还在图7中展示了确定性反向传播的性能。我们精心的混洗(swizzling)和调度策略使得确定性反向传播的速度大幅提升,达到了非确定性1-CTA反向传播速度的75%。

图6:B200上(FP16/BF16)头维度为128的反向传播TFLOPS。左:非因果注意力。右:因果注意力。
图6:B200上(FP16/BF16)头维度为128的反向传播TFLOPS。左:非因果注意力。右:因果注意力。
图7:B200上(FP16/BF16)头维度为128的确定性反向传播消融实验。因果注意力设置包括:SPT、LPT(mblock逆序)、LPT和朴素实现(无批次/头混洗)。
图7:B200上(FP16/BF16)头维度为128的确定性反向传播消融实验。因果注意力设置包括:SPT、LPT(mblock逆序)、LPT和朴素实现(无批次/头混洗)。

A5 结论

FlashAttention-4解决了硬件非对称扩展的问题,即张量核心速度过快,导致主要瓶颈转移到共享内存流量和指数运算吞吐量上。这促使我们进行算法与核函数的协同设计来缓解这些限制。我们围绕完全异步的MMA重新设计了流水线,以将softmax计算与更大瓦片的矩阵乘法重叠;引入了软件模拟的指数函数和条件性softmax重缩放来减少非矩阵乘法操作。我们利用张量内存和2-CTA MMA模式来减少共享内存流量。此外,2-CTA模式使得重构全局原子累积成为可能,将全局原子加法的数量减半。

FlashAttention-4完全在嵌入Python的CuTe-DSL中实现,保留了底层控制能力,同时比基于C++模板的核函数编译速度快20-30倍。尽管为Blackwell GPU优化,但随着计算能力继续超越非矩阵乘法单元,其中一些算法可以扩展到其他加速器。

A6 附录

A.1 系统和库的额外细节

我们在B100 180GB SXM6 (1000W)上进行速度基准测试。 我们先进行5次预热运行,然后重复基准测试10次,并取平均时间。

我们通常使用撰写本文时(2025年3月)的最新版本的库。 具体来说,我们使用:

  • CUDA 13.1
  • FlashAttention 2.8.3
  • Triton 3.6
  • PyTorch 2.10.0
  • CuTe-DSL 4.4.1

对于cuDNN,在主论文中,我们与cuDNN 9.13和最新版本cuDNN 9.19.1.2进行了比较。 从版本9.13和9.14开始【索引20,cudnn release notes+2025+https://docs.nvidia.com/deeplearning/cudnn/backend/ latest/release-notes.html】,我们与cuDNN团队合作,将FlashAttention-4中的一些技术融入到cuDNN中,以便我们的工作能惠及尽可能多的从业者。

A.2 反向传播确定性非因果情况

为了完整性,我们还在图8中包含了无因果掩码的确定性反向传播核函数的性能数据,并与有因果掩码的情况并列展示。

图8:B200上头维度为128的确定性反向传播消融实验。左:非因果注意力,带批次/头混洗与朴素实现对比。右:因果注意力——SPT、LPT(mblock逆序)、LPT和朴素实现(无批次/头混洗)。
图8:B200上头维度为128的确定性反向传播消融实验。左:非因果注意力,带批次/头混洗与朴素实现对比。右:因果注意力——SPT、LPT(mblock逆序)、LPT和朴素实现(无批次/头混洗)。