作者/机构: Jintao Zhang∗, Jia Wei∗, Pengle Zhang, Xiaoming Xu, Haofeng Huang, Haoxu Wang, Kai Jiang, Jun Zhu, Jianfei Chen, 清华大学 {zhang-jt24@mails., jianfeic@, dcszj@}tsinghua.edu.cn
研究动机: Attention机制的效率对于生成模型至关重要,特别是其二次时间复杂度在处理长序列时成为瓶颈。量化是利用GPU中低比特张量核心(Tensor Cores)加速推理的有效方法。NVIDIA Blackwell GPU中新的FP4 Tensor Cores相比FP16提供了显著的性能提升。因此,本文旨在提出一种新颖的FP4 Attention实现,为推理加速提供即插即用的兼容性。此外,训练效率同样重要,但以往工作未曾探索过用于训练大型模型的低比特Attention。为填补这一空白,本文设计了一种可训练的8-bit Attention,以探究其在训练任务中的可行性。
核心挑战:
1. (C1) FP4值域限制: FP4量化面临严重的值域限制(仅15个可表示值),使得传统的逐张量(per-tensor)和逐令牌(per-token)量化方法均不足以保持模型精度。
2. (C2) Attention图P的量化: Attention图P主要由[0, 1]范围内的微小值构成。直接量化到FP4会迫使缩放因子处于极窄的动态范围内,而硬件要求这些因子为FP8数据类型,将这些缩放因子表示为FP8时会导致显著的精度损失。
3. (C3) 8-bit Attention训练: 在训练中使用8-bit Attention时,发现Attention图的梯度对量化误差尤其敏感,导致误差在输入梯度中累积。
本文方法与贡献:
1. 针对FP4推理:
* 方法: 为解决(C1),本文提出对Attention中的两个矩阵乘法(QKᵀ和PV)使用FP4微缩放(microscaling)量化。通过将量化组大小限制为1x16,该方法有效控制了离群值的影响。为解决(C2),本文提出对矩阵P采用一种两级量化方法,首先通过逐令牌量化将每个令牌的范围归一化到[0, 448 * 6],然后应用FP4微缩放量化以提高精度,从而充分利用FP8缩放因子的表示范围。
* 贡献一: 设计了首个用于加速推理的FP4 Attention(SageAttention3),在RTX5090上实现了超过1000 TOPS的性能,相比FlashAttention提速5倍。
图2: 微缩放FP4 Attention的工作流程。
FlashAttention: Attention计算包含两个矩阵乘法和一个softmax计算:$S = QK^⊤$, $P = \text{Softmax}(S)$, $O = PV$。其中$Q, K, V$的形状为$N \times D$,$N$是序列长度,$D$是注意力头的维度。$P, S$的形状为$N \times N$。FlashAttention将$Q$划分为形状为$B_q \times D$的块$\{Q_i\}$,将$K, V$划分为形状为$B_{kv} \times D$的块$\{K_i\}, \{V_i\}$。然后它使用在线softmax来避免为$S$和$P$产生大的内存IO:$S_{ij} = Q_iK^⊤_j$, $P_{ij} = \text{OnlineSoftmax}(S_{ij})$, $O_{ij} = P_{ij}V_j$。
符号表示: 为简化起见,我们省略下标,使用$Q, K, V, S, P, O$表示FlashAttention中的矩阵块,但在算法1、2和3中保留完整的下标符号。
量化: 量化通过将两个矩阵从高比特转换为带缩放因子的低比特来加速矩阵乘法。以矩阵乘法$AB$的INT8量化为例,其中$A$和$B$为FP16数据类型。其公式可表示为:$s_A = \max(|A|)/127$, $\hat{A} = \lceil A/s_A \rfloor$, $s_B = \max(|B|)/127$, $\hat{B} = \lceil B/s_B \rfloor$,其中$\hat{A}, \hat{B}$是INT8类型,其余为FP32类型。然后,$AB \approx \hat{A}\hat{B} \times s_A \times s_B$,这可以通过INT8 Tensor Core加速。量化的粒度由max操作所规约的维度决定。例如,在逐令牌(per-token)量化中,max是沿着矩阵的每一行计算的。在逐块(per-block)量化中,max是在矩阵的一个块上计算的,在本文中指一个FlashAttention块。
本节通过三个关键部分介绍我们的微缩放FP4 Attention:(1) 第3.1节中将微缩放FP4量化应用于Attention的基本工作流程,(2) 第3.2节中针对Attention图的两级量化方法,以及(3) 第3.3节中关键的硬件实现优化。
图3: 两级量化优势分析。(a) 展示了$P_e$的分布。(b) 和 (c) 分别展示了使用直接量化和两级量化时$s_P$的分布。(d) 和 (e) 分别展示了使用直接量化和两级量化时$s_P$和$P_e$的误差。
FP4微缩放量化: 给定一个矩阵$X \in R^{N \times d}$,我们将其量化为FP4数据类型的$\hat{X}$和一个FP8数据类型的缩放因子矩阵$s_X$。具体来说,$X$被划分为多个$X_{ij} \in R^{1 \times n}$的块,其中每个$1 \times n$的块对应一个缩放因子$s_{ij}$。FP4微缩放量化($[\hat{X}, s_X] = \phi(X)$)和反量化($X' = \phi^{-1}(\hat{X}, s_X)$)可以公式化如下。
其中$\lceil \cdot \rfloor$表示FP4舍入。
FP4微缩放量化矩阵乘法: 考虑一个矩阵乘法$AB$,其中$A$和$B$为FP16精度。在RTX5090上,该矩阵乘法的速度约为200 TOPS。相比之下,FP4微缩放矩阵乘法的速度约为1600 TOPS,实现了8倍的加速。FP4微缩放矩阵乘法指令(FP4MM)接受四个输入,即$\hat{A}, s_A, \hat{B}, s_B$,其输出$C$等于$\phi^{-1}(\hat{A}, s_A)$和$\phi^{-1}(\hat{B}, s_B)$之间的矩阵乘法结果:
Attention计算: 我们通过将FP4微缩放量化应用于两个矩阵乘法:$QK^⊤$和$PV$,来加速Attention计算。
需要注意的是,我们的硬件实现基于FlashAttention,因此我们公式中的矩阵$Q, K, P, V$对应于第2节中描述的FlashAttention的分块$Q, K, P, V$。此外,为提高Attention的精度,我们采用了【索引编号5,Sageattention2: Efficient attention with thorough outlier smoothing and per-thread int4 quantization,2025,International Conference on Machine Learning (ICML)】中的平滑$Q$和$K$。完整算法见算法1。
数据类型确定: FP4数据类型有两种选择【索引编号6,Microscaling data formats for deep learning,2023,arXiv preprint arXiv:2310.10537】。第一种是NVFP4,为E2M1数据类型,其量化块大小为$1 \times 16$,缩放因子为E4M3数据类型。第二种是MXFP4,也为E2M1数据类型,但其量化块大小为$1 \times 32$,缩放因子为E8M0数据类型。我们选择NVFP4,因为在Attention量化中,NVFP4的精度远高于MXFP4。经验结果:表1(a)展示了在CogVideoX所有层中使用真实$Q, K, V$时MXFP4和NVFP4的精度。结果表明,NVFP4的精度优于MXFP4。
算法1:微缩放FP4 Attention的实现
1: 输入:矩阵$Q(\text{FP16}), K(\text{FP16}), V(\text{FP16}) \in R^{N \times d}$,块大小$B_q, B_{kv}$。
2: 预处理: $K = K - \text{mean}(K)$ // SageAttention的平滑K。
3: 将$Q$划分为$T_m = N/B_q$个块$\{Q_i\}$;将$K$和$V$划分为$T_n = N/B_{kv}$个块$\{K_i\}, \{V_i\}$;
4: for i = 1 to $T_m$ do
5: $\bar{q}_i = \text{mean}(Q_i)$, $(s_Q, \hat{Q}_i) = \phi(Q_i - \bar{q}_i)$ ; // SageAttention2的平滑Q。
6: for j in [1, $T_n$] do
7: $(s_K, \hat{K}_j) = \phi(K_j^⊤)$, $(s_V, \hat{V}_j) = \phi(V_j)$ ;
8: $S_{ij} = \text{FP4MM}(\hat{Q}_i, s_Q, \hat{K}_j, s_K) + \text{GEMV}(\bar{q}_i, K_j^⊤)$ ; // 平滑Q。
9: $m_{ij} = \max(m_{i,j-1}, \text{rowmax}(S_{ij}))$, $P^e_{ij} = \exp(S_{ij} - m_{ij})$, $l_{ij} = e^{m_{i,j-1}-m_{ij}} + \text{rowsum}(P^e_{ij})$ ;
10: $s_{P1} = \text{rowmax}(P^e_{ij})/(448 \times 6)$, $P^e_{ij} = P^e_{ij}/s_{P1}$, $(s_{P2}, \hat{P}_{ij}) = \phi(P^e_{ij})$; // 两级量化
11: $O_{ij} = \text{diag}(e^{m_{i,j-1}-m_{ij}})^{-1}O_{i,j-1} + \text{FP4MM}(\hat{P}_{ij}, s_{P2}, \hat{V}_j, s_V) \times s_{P1}$
12: end for
13: $O_i = \text{diag}(l_{i,T_n})^{-1}O_{i,T_n}$ ;
14: end for
15: return $O = \{O_i\}$
$P^e$的量化挑战: 对$P^e$应用微缩放FP4量化对Attention精度构成挑战。例如,图12(c)显示直接量化严重降低了输出质量,其结果与全精度输出存在显著差异。我们的分析揭示,问题在于微缩放NVFP4量化要求缩放因子以E4M3 FP8格式表示【索引编号7,Fp8 formats for deep learning,2022,arXiv preprint arXiv:2209.05433】,而非通常用于缩放因子的FP32数据类型。当缩放因子直接转换为E4M3格式时,会造成精度损失。为了更好地理解这种精度损失,我们在图3中分析了$P^e$及其缩放因子的数据分布。由于$P^e$是使用在线softmax【索引编号8,Online normalizer calculation for softmax,2018,arXiv preprint arXiv:1805.02867】计算的,每个微缩放块$P^e_{ij}$中的值落在[0, 1]区间内。因此,缩放因子(缩放因子 = $\max(P^e_{ij})/6$)的范围在0到0.167之间。这个狭窄的范围导致E4M3的可表示范围利用效率低下,增加了精度损失。
两级量化方法: 为了通过充分利用E4M3的范围来减少精度损失,我们提出了一种针对$P^e$矩阵的两级量化方法。具体来说,我们首先将$P^e$的每一行量化到$[0, 448 \times 6]$。然后,我们对量化后的$P^e$应用标准的FP4量化$\phi$。两级量化可以公式化如下:
其中,$P^e$, $P^e_2$, 和$s_{P1}$为FP32数据类型。$s_{P2}$和$s_V$为FP8数据类型。$\hat{P}_2$和$\hat{V}$为FP4数据类型。
经验结果: 如图3所示,我们的两级量化最大化了$s_P$对E4M3范围的利用率,从而减少了$s_P$的数值表示误差和$P^e$的量化误差。更正式的理论分析见附录。表1(b)展示了在CogVideoX各层使用真实$Q, K, V$时,两级量化与朴素直接量化在精度上的对比。结果表明,两级量化提升了精度。
K的置换: 与FP16不同,FP4 MatMul【索引编号9,Parallel Thread Execution ISA Version 8.7,2025,https://docs.nvidia.com/cuda/pdf/ptx_isa_8.4.pdf】中的FP32累加器的内存布局与其操作数A的寄存器布局不同(如图20和19所示)。执行线程洗牌(thread shuffles)以匹配操作数A的布局会降低核函数性能。我们的解决方案是通过置换P块的列来转换累加器的布局(图21)。为保持正确的矩阵乘法,我们相应地重排K的列,这一操作可以与量化核函数融合。
重用shuffle: $P^e$的核内微缩放量化需要在连续的16个行元素中找到最大值。然而,如图21所示,这16个元素分布在四个线程中,需要先进行线程内最大值规约,再进行线程间洗牌,这会显著减慢核函数速度。我们通过将量化与在线softmax融合来优化此过程,因为在线softmax也计算行最大值。首先,我们计算$S$中16个元素的最大值,并在后续的softmax最大值规约中重用它。这种融合将冗余的洗牌和最大值操作减少了50%,带来了约10%的整体核函数加速。
生产者warp尾声: 在传统的warp专用核函数中,消费者warp通常处理矩阵乘法和存储操作,而生产者仅加载输入,并通过消费者之间的乒乓调度实现阶段重叠【索引编号10,Efficient gemm in cuda,2025,https://docs.nvidia.com/cutlass/media/docs/cpp/efficient_gemm.html】。然而,寄存器限制使得这种方法在我们的FP4 Attention核函数中不可行。作为替代,我们在生产者warp之间实现乒乓调度:当一个生产者为下一次矩阵乘法操作加载输入时,另一个生产者同时将输出存储到全局内存,而消费者warp仅负责将矩阵乘法结果从寄存器传输到共享内存。这种新颖的设计在寄存器限制下实现了矩阵乘法和全局内存存储的重叠,从而提高了吞吐量。
诸如FlashAttention3和SageAttention等低比特量化Attention工作仅用于推理。本节中,我们提出了一种用于训练的INT8 Attention,名为SageBwd,它将Attention中的七个矩阵乘法中的六个量化为INT8,在微调任务中实现了无性能下降。
算法2:8-bit Attention的前向传播
1: 输入:FP16矩阵$Q, K, V \in R^{N \times d}$,以及块大小$B_q, B_{kv}$。
2: 将$Q$划分为$T_m = N/B_q$个块$\{Q_i\}$;将$K$和$V$划分为$T_n = N/B_{kv}$个块$\{K_i\}, \{V_i\}$;
3: 量化:$\{s_Q, \hat{Q}_i\} = \{\psi(Q_i)\}, \{s_K, \hat{K}_i\} = \{\psi(K_i^⊤)\}, \{s_V, \hat{V}_i\} = \{\psi(V_i)\}$;// 逐块量化。
4: for i = 1 to $T_m$ do
5: $O_i \in R^{B_q \times D} = (0), L_i \in R^{B_q} = (0), m_i \in R^{B_{kv}} = (0)$;
6: for j in [1, $T_n$] do
7: $S_{ij} = \text{MM}(\hat{Q}_i, \hat{K}_j) \times s_Q \times s_K$;
8: $m_{ij} = \max(m_{i,j-1}, \text{rowmax}(S_{ij})), P^e_{ij} = \exp(S_{ij} - m_{ij}), l_{ij} = e^{m_{i,j-1}-m_{ij}} + \text{rowsum}(P^e_{ij})$;
9: $s_P = \exp(\text{rowmax}(S_{ij}) - m_{ij})/127, \hat{P}_{ij} = P_{ij}/s_P$;// 逐令牌量化。
10: $O_{ij} = \text{diag}(e^{m_{i,j-1}-m_{ij}})^{-1}O_{i,j-1} + \text{MM}(\hat{P}_{ij}, \hat{V}_j) \times s_P \times s_V$
11: end for
12: $O_i = \text{diag}(l_{i,T_n})^{-1}O_{i,T_n}$;
13: $L_i = m_{i,T_n} + \log(l_{i,T_n})$;
14: end for
15: return $O = \{O_i\}, L = \{L_i\}$;
前向传播中的矩阵乘法: Attention的前向传播中有两个矩阵乘法:
P的逐令牌量化: 遵循【索引编号11,Sageattention: Accurate 8-bit attention for plug-and-play inference acceleration,2025,The International Conference on Learning Representations】,我们对$QK^⊤$应用平滑K和逐块INT8量化。然而,对于$P^e V$,使用静态缩放因子1/127的静态逐块INT8量化对$P^e$是不准确的【索引编号11,Sageattention: Accurate 8-bit attention for plug-and-play inference acceleration,2025,The International Conference on Learning Representations】。幸运的是,我们发现对$P^e$应用逐令牌INT8量化,并对$V$应用逐块INT8量化可以提高Attention的精度。此外,我们通过重用在线softmax计算中的全局和局部最大值,消除了对$P$进行显式max操作的需求(算法2第9行)。前向传播的算法如算法2所示。
INT8逐块量化: 鉴于我们在可训练Attention中广泛使用INT8逐块量化,我们将其过程形式化如下。对于每个FlashAttention块$X$,量化过程$[s_X, \hat{X}] = \psi(X)$可以公式化为:
反向传播中的矩阵乘法: Attention的反向传播中有五个矩阵乘法:
dOVᵀ的精度保持: 我们观察到,是否对$dOV^⊤$进行量化对$Q, K$的梯度精度有显著影响。这是因为$dOV^⊤$的精度直接决定了$dP$和$dS$的精度(见算法3中的计算依赖关系)。在FlashAttention的反向传播中,$dS$的精度损失会沿着序列长度在循环过程中不断累积误差到$dQ$和$dK$中,意味着序列越长,误差累积越大。因此,我们将$dOV^⊤$保持在FP16,同时使用INT8逐块量化加速其他四个矩阵乘法。反向传播的算法如算法3所示。
经验结果: 表1(c)显示了在对$dOV^⊤$进行和不进行量化时$dQ$的精度。我们发现,当保持$dOV^⊤$为FP16时,$dQ$的精度显著提高。
算法3:8-bit Attention的反向传播
1: 输入:来自前向传播的$\{s_Q, \hat{Q}_i\}, \{s_K, \hat{K}_i\}, \{s_V, \hat{V}_i\}, O, \{L_i\}$,$dO \in R^{N \times d}$,以及块大小$B_q, B_{kv}$;
2: $D = \text{rowsum}(dO \circ O)$,将$D$划分为$T_m = N/B_q$个块$\{D_i\}$;
3: for j = 1 to $T_n$ do
4: for i in [1, $T_m$] do
5: $S_{ij} = \text{MM}(\hat{Q}_i, \hat{K}_j) \times s_Q \times s_K$;$P_{ij} = \exp(S_{ij} - L_i)$;
6: $[s_P, \hat{P}_{ij}] = \psi(P_{ij})$, $[s_{dO}, d\hat{O}_i] = \psi(dO_i)$;// INT8逐块量化。
7: $dV_j \leftarrow dV_j + \text{MM}(\hat{P}_{ij}^⊤, d\hat{O}_i) \times s_P \times s_{dO}$;
8: $dP_{ij} = \text{MM}(dO, V_j^⊤)$;// 保持在FP16。
9: $dS_{ij} = P_{ij} \circ (dP_{ij} - D_i)$;$[s_{dS}, d\hat{S}_{ij}] = \psi(dS_{ij})$;// INT8逐块量化。
10: $dQ_i \leftarrow dQ_i + \text{MM}(d\hat{S}_{ij}, \hat{K}_j) \times s_{dS} \times s_K$;
11: $dK_j \leftarrow dK_j + \text{MM}(d\hat{S}_{ij}^⊤, \hat{Q}_i) \times s_{dS} \times s_Q$;
12: end for
13: end for
14: return $dQ, dK, dV$;
表1: 不同量化策略的精度消融实验。
主要结果概述: SageAttention3在RTX5090上的速度比FlashAttention快5倍,比xformers快11倍,并在各种模型上保持了端到端的指标。SageBwd在RTX4090上的速度比FlashAttention快1.67倍,比xformers快3倍,并在微调任务中实现了无损性能。
核函数速度:
* SageAttention3: 在RTX5090上,相较于FlashAttention2实现了4-5倍的加速,相较于xformers实现了8-11倍的加速 (图4, 5)。
* SageBwd: 在RTX4090上,其前向+反向传播速度最高比FlashAttention2快1.67倍,并且比Triton实现的FlashAttention2和xformers有更高的加速比 (图6, 7)。
图4: SageAttention3与基线方法的速度比较 (RTX5090, headim=128)。
图5: SageAttention3与基线方法的速度比较 (RTX5090, headim=64)。
图6: SageBwd与基线方法的速度比较 (RTX4090, headim=128)。
图7: SageBwd与基线方法的速度比较 (RTX4090, headim=64)。
端到端指标损失:
* SageAttention3: 在各种模型上几乎没有引入端到端的质量损失 (表2)。
* SageBwd:
* 微调任务: 在Qwen2.5 (3B) 和 Llama3.2 (1B) 上的微调损失曲线与BF16完美对齐 (图8 b-e)。在多个测试集上的评估结果表明,SageBwd实现了与BF16相同的性能 (表3)。
* 预训练任务: 在使用Llama (400M) 模型进行的预训练任务中,SageBwd虽然能够收敛,但收敛速度相对较慢,限制了其在预训练任务中的应用 (图8 a)。
表2: 在各种模型上的端到端指标比较。
表3: 在Qwen2.5和Llama3.2模型上使用8-bit Attention的微调结果。
图8: BF16和8-bit Attention的预训练和微调损失曲线。
可视化示例:
* 图9展示了在HunyuanVideo和Stable-Diffusion3.5上使用SageAttention3的生成效果对比,结果表明SageAttention3保持了完整的生成质量。更多示例见附录中的图10, 11, 13, 14。
图9: 在HunyuanVideo(左)上的视频生成和在Stable-Diffusion3.5(右)上的图像生成的可视化示例。
端到端加速:
* 推理 (SageAttention3): 在RTX5090上,HunyuanVideo和CogVideoX的端到端推理生成速度分别提升了约3倍和2.4倍 (表4a)。
* 训练 (SageBwd): 在RTX4090上,使用8K/16K token的微批次,Llama (1B) 的训练速度提升了约1.15倍 (表4b)。
(a) 使用SageAttention3的推理延迟。
(b) 使用SageBwd的一次迭代训练延迟。
表4: 使用SageAttention3和SageBwd的端到端加速性能。
本文做出了两大关键贡献。首先,我们设计了SageAttention3,这是首个用于推理加速的微缩放FP4 Attention,在RTX5090上实现了1038 TOPS的性能,比RTX5090上最快的FlashAttention快5倍。实验表明,SageAttention3可以在不降低端到端质量指标的情况下加速各种模型。其次,我们引入了首个可训练的8-bit Attention(SageBwd)用于训练加速,并探讨了其在训练任务中的可行性。我们发现,8-bit Attention在微调任务中可以达到无损性能,但目前在预训练任务中存在一些局限性。
未来工作:
1. 尽管SageBwd比FP16实现更快,但其当前速度与理论上限之间仍有差距。这可能是由次优的Triton核函数实现造成的,我们计划进一步优化。
2. 更重要的是,研究低比特Attention在预训练任务中的应用是一个有前景的研究方向,值得探索。
图10和图11展示了图像生成任务的额外可视化对比示例。图13和图14展示了视频生成任务的更多可视化对比示例。
图10: 在Stable-Diffusion3.5上的图像生成可视化示例。
图11: 在Flux上的图像生成可视化示例。
图12: 来自CogVideoX的不同$P^e$缩放策略的可视化比较。
图13: 在CogVideoX上的视频生成可视化示例。
图14: 在HunyuanVideo上的视频生成可视化示例。
图15和图16显示了SageBwd的前向核函数速度。图17和图18显示了SageBwd的反向核函数速度。SageBwd在前向传播中比FlashAttention快2倍。SageBwd在反向传播中比FlashAttention快1.2~1.6倍。
图15: SageBwd与基线方法的前向速度比较 (RTX4090, headim=128)。
图16: SageBwd与基线方法的前向速度比较 (RTX4090, headim=64)。
图17: SageBwd与基线方法的反向速度比较 (RTX4090, headim=128)。
图18: SageBwd与基线方法的反向速度比较 (RTX4090, headim=64)。
数据集: 文生视频模型使用open-sora [49]提示集进行评估。文生图模型在COCO [50]标注上进行评估。语言模型在GSM8K [23]、DROP [24]、MMLU [25]和HELLASWAG [26]数据集上进行评估。
端到端指标: 对于文生文模型,我们使用准确率(Acc.)和F1分数(F1)。对于文生视频模型,我们在五个指标上评估生成视频的质量:CLIPSIM和CLIP-Temp (CLIP-T) [51]用于衡量文本-视频对齐度;(VQA-a)和(VQA-t)分别用于评估视频的美学和技术质量;Flow-score (FScore)用于时间一致性[52]。对于文生图模型,从三个方面评估生成的图像:FID [53]和sFID [54]用于保真度评估,Clipscore (CLIP) [55]用于文本-图像对齐,以及ImageReward (IR) [56]用于人类偏好。
准确率指标: 我们使用三个指标来评估量化Attention输出$O'$与全精度Attention输出$O$的准确性:首先,我们将$O'$和$O$展平为形状为$1 \times n$的向量。然后,余弦相似度:$CosSim = \sum OO'/\sqrt{\sum O^2}\sqrt{\sum O'^2}$,相对L1距离:$L1 = \sum |O - O'|/\sum |O|$,均方根误差:$RMSE = \sqrt{(1/n)\sum(O - O')^2}$。
超参数: 对于预训练任务,我们使用一个400M的模型,隐藏层大小为1024,20层,中间层大小为3072,16个注意力头。训练使用1e-3的学习率,在1000个预热步骤后线性衰减,每步处理2M个token。对于微调任务,我们在GSM8K数据集上以32的批量大小训练700步,在MMLU、DROP和HELLASWAG数据集上以128的批量大小训练,学习率为3e-5,有100个预热步骤并线性衰减。
图19: FP4操作数A寄存器布局 - 第0和8行,线程0-3,条目0-15。
图20: FP32累加器寄存器布局 - 第0和8行,线程0-3,条目0-15。
图21: 置换后的FP32累加器寄存器布局 - 第0和8行,线程0-3,条目0-15。
表5–10显示了Qwen2.5 (1.5B), Qwen2.5 (3B) 和 Llama3.2 (3B) 在四个数据集上使用五个不同随机种子的微调结果。平均值和标准差表明,SageBwd在各种随机种子下与BF16高度一致。
表5: SageBwd与BF16在Qwen2.5 (1.5B)上不同种子在GSM8K和DROP上的性能比较。
表6: SageBwd与BF16在Qwen2.5 (1.5B)上不同种子在MMLU和HellaSwag上的性能比较。
表7: SageBwd与BF16在Qwen2.5 (3B)上不同种子在GSM8K和DROP上的性能比较。
表8: SageBwd与BF16在Qwen2.5 (3B)上不同种子在MMLU和HellaSwag上的性能比较。
表9: SageBwd与BF16在Llama3.2 (1B)上不同种子在GSM8K和DROP上的性能比较。
表10: SageBwd与BF16在Llama3.2 (3B)上不同种子在MMLU和HellaSwag上的性能比较。
引用段落: 第3.1节,Attention计算
...为提高Attention的精度,我们采用了SageAttention2 [5]中的平滑Q和K...引用段落: 第3.1节,数据类型确定
FP4数据类型有两种选择[6]。引用段落: 第3.2节,$P^e$的量化挑战
...微缩放NVFP4量化要求缩放因子以E4M3 FP8格式表示[7],而非通常用于缩放因子的FP32数据类型。引用段落: 第3.2节,$P^e$的量化挑战
引用段落: 第3.3节,K的置换
与FP16不同,FP4 MatMul [9]中的FP32累加器的内存布局与其操作数A的寄存器布局不同...引用段落: 第3.3节,生产者warp尾声
...并通过消费者之间的乒乓调度实现阶段重叠[10]。引用段落: 第4.1节,P的逐令牌量化
遵循SageAttention [11],我们对QK^⊤应用平滑K和逐块INT8量化。引用段落: 第4.1节,P的逐令牌量化
然而,对于P^e V,使用静态缩放因子1/127的静态逐块INT8量化对P^e是不准确的[11]。