MX+: Pushing the Limits of Microscaling Formats for Efficient Large Language Model Serving
MX+: Pushing the Limits of Microscaling Formats for Efficient Large Language Model Serving
作者: Jungi Lee, Junyong Park, Soohyun Cha, Jaehoon Cho, Jaewoong Sim
机构: Seoul National University
A1 主要贡献
核心问题: 高效服务大型语言模型(LLM)因其巨大的计算和内存资源需求而面临重大挑战。尽管低比特量化方案和降精度数据格式被提出来缓解这些开销,但它们通常需要对软件代码库进行大量修改,或与传统表示相差太远,难以在不同计算平台上广泛采用。近期,业界合作推出了微缩放(Microscaling, MX)数据格式,它基于块浮点(BFP)表示,但其在LLM服务中的有效性,特别是超低比特(如4-bit MX格式,MXFP4)版本,因无法有效处理LLM激活值中的离群值(outliers)而导致模型性能显著下降。
研究目标: 本文旨在通过全面分析,突破行业驱动的BFP格式变体的极限,以实现高效的LLM服务。具体目标是解决超低比特BFP变体(特别是MXFP4)在处理LLM激活值中离群值时遇到的性能瓶颈问题。
创新点 (MX+): 为了有效解决BFP中的离群值问题,本文提出了MX+格式,这是一种为无缝集成到MX格式而设计的、成本效益高且非侵入式的扩展。MX+的设计基于两个关键洞察:
1. 离群值的自然识别: 在MX格式中,块内最大幅值元素的指数被用来确定共享缩放因子,这使得我们可以在没有额外计算或硬件逻辑的情况下,自然地识别出离群值元素及其在块内的位置。
2. 冗余指数的再利用: 对于MX块中的离群值元素,我们不需要存储其自身的指数,因为它总是被设置为元素数据类型可表示的最大指数。这使我们能够将其指数字段重新用作扩展的尾数(mantissa),从而在低比特MX格式(如MXFP4,只有一两位尾数)中极大地提高离群值元素的精度。
主要贡献总结:
* 提出MX+: 一种对MX格式的非侵入式扩展。MX+在无需额外用户努力或复杂性的情况下,改进了对离群值的表示,即使在低比特量化下也能实现较高的模型性能。
* 软件集成与评估: 将MX+集成到现有软件库中,并证明其在LLM推理中仅引入边际的减速,无需硬件修改,尤其是在token生成主导推理时间的情况下。
* 硬件设计: 提出了一种硬件设计,能够在Tensor Cores内直接进行MX+计算,而无需对点积流水线进行侵入式修改,从而在实现更高模型精度的同时,提供接近MX的性能。
A3 背景知识
通过将权重和激活从较高精度映射到较低精度的粗粒度表示,可以使用更简单但吞吐量更高的计算单元来加速计算,同时也能更有效地利用内存带宽和容量。例如,广泛使用的均匀对称整数 quantization 方案使用缩放因子 $s$ 将一组 $k$ 个浮点数 $x_f$ 映射到 $b$ 位整数 $x_q$,如下所示:
对于整数 quantization,反量化涉及将缩放因子与整数值相乘,$s \cdot x_q$。
块浮点(BFP)格式,例如MSFP [6],与传统的整数 quantization 有些相似,但缩放因子($S$)被限制为2的幂。这种限制使得硬件能够在更细的块($k$)粒度上高效地管理缩放或重新缩放,从而能够更准确地表示原始的权重和激活张量。图1比较了几种行业提出的基于块的数据格式,我们将在下面简要解释。
微软浮点(Microsoft Floating Point)。微软浮点(MSFP)是BFP格式的一种变体,已在Project Brainwave [14] 中部署。一个MSFP块包含 $k$ 个元素,每个元素都有自己的符号位和尾数位,以及一个由块中所有元素共享的指数。在MSFP的一个典型用例中 [6],例如,一个浮点张量中的16个元素被分组到一个块中,共享一个8位的指数,该指数被设置为块内绝对值最大元素的指数。每个元素的尾数是通过将原始浮点值右移共享指数与其原始指数之差得到的;因此,MSFP尾数中没有隐含的前导位。请注意,MSFP格式根据其总比特宽度命名;例如,MSFP12只有四位用于符号和尾数,导致每个元素的平均比特宽度为4.5位。
共享微指数(Shared Microexponents)。共享微指数(SMX)数据格式 [7] 是最近的一项提议,与MSFP类似,其缩放因子由一组元素共享,并限制为2的幂。然而,SMX采用多级缩放方法,与MSFP中的单级缩放形成对比。在其典型的两级缩放用例中 [7],一组16个元素($k_1=16$)共享一个一级缩放因子 $S$,这是一个8位的共享指数,而该组内的元素对($k_2=2$)形成一个子组,共享一个二级缩放(子缩放)因子 $S_i$,由每个子组的一个1位共享微指数表示。
微缩放格式(Microscaling Formats)。微缩放(MX)格式 [47] 是另一项由多家行业公司合作开发的最新提议,旨在建立开放和可互操作的数据格式。一个MX块由32个元素($k=32$)和一个共享的缩放因子 $S$ 组成,该缩放因子是一个8位的共享指数,类似于MSFP或SMX中使用的指数。然而,与MSFP或SMX不同的是,元素数据类型可以从五种浮点和一种整数编码中选择,如表1所示。
表1: 具体的MX兼容格式。
整数数据类型(即MXINT8)使用二进制补码编码,并有一个隐含的缩放因子$2^{-6}$。对于MXFP格式,MX块中的每个私有元素都有自己的指数位,这使得每个元素实际上都是一个浮点数。在MX中,共享指数和相应的缩放因子 $X$ 可以计算如下:
其中$e_{\text{max}}$是元素数据类型可表示的最大指数值。例如,在MXFP4中,元素数据类型是FP4,每个元素都有一个单独的2位指数,指数偏差为1。因此,$e_{\text{max}}$变为2(即$11_2-1$)。
A3 关键Observation/设计原则
本节我们首先比较了在使用第2节中讨论的BFP变体时各种LLM的模型性能。然后,我们研究了使用极低比特格式导致性能下降的根本原因,并讨论了如何更好地利用低比特BFP进行LLM服务。
3.1 BFP变体的模型性能
图2显示了使用WikiText-2数据集,序列长度为2048时,各种LLM在不同行业驱动的BFP格式下的困惑度。基线(B)使用Bfloat16(BF16)作为默认数据格式,并以BF16执行矩阵乘法和元素级操作,除了softmax使用FP32。对于BFP变体的评估,我们遵循先前工作 [【7,With Shared Microexponents, A Little Shifting Goes a Long Way,2023,ISCA】,【52,Microscaling Data Formats for Deep Learning,2023,arXiv】] 中概述的计算流程;BF16张量被转换为MSFP、SMX和MX进行矩阵乘法,而元素级操作使用与基线相同的精度(即BF16或FP32)。我们选择了平均每元素比特数与MXFP4(L)、MXFP6(M)和MXFP8(H)相似的MSFP和SMX格式。这些BFP变体的平均比特宽度分别在 $4 \le L \le 4.5$,$6 \le M \le 6.5$ 和 $8.25 \le H \le 9$ 的范围内。
总的来说,MX在相似比特宽度下优于或与其它BFP变体相当。对于高比特(H)格式,所有BFP变体的性能都接近基线;虽然MXFP8的困惑度略高于SMX9或MSFP16,这是由于其平均每元素比特数较低(8.25,而其他格式为9和8.5)以及使用了SMX或MSFP不支持的保留NaN表示。然而,在中等比特(M)格式中,SMX6和MSFP14开始偏离基线,使其在LLM服务场景中效果较差,而MXFP6仍然接近基线。这是因为,与SMX或MSFP中指数由块中部分或所有元素共享不同,MXFP中的每个元素除了共享指数外,还有自己的指数,从而可以实现更细粒度的值表示。此外,MXFP对正常数采用隐含的前导1,并以类似于IEEE-754浮点格式的方式定义了次正常数,从而与其他格式相比具有更大的有效比特宽度。然而,当使用低比特MX格式(即MXFP4)时,即使是MX,困惑度也开始偏离基线。尽管它仍然明显优于SMX4和MSFP12,但这使得低比特MX格式在LLM服务中的实际应用性降低,尽管它具有节省大量带宽和提高计算效率的潜力。
3.2 低比特MX格式分析
为了理解低比特MX格式导致模型性能下降的根本原因,我们对MXFP4进行了进一步分析。我们首先评估了仅将激活张量(A)或权重张量(W)量化为MXFP4,而另一方使用BF16时的困惑度。图3显示,量化权重(A-BF16, W-MXFP4)导致的困惑度增加可以忽略不计,而量化激活(A-MXFP4, W-BF16)则会显著降低模型性能。这表明,尽管MX采用细粒度缩放(即每块32个元素)来减少张量中离群值的影响,但低比特MX对于激活张量并不能有效缓解这个问题。
为了更深入地了解根本原因,我们检查了激活张量中的MX块。图4(a)显示了Llama-3.1-8B激活幅度的热图,而图4(b)则展示了两个样本块的原始BF16值及其MXFP4和MXFP6表示。在这里,我们将MX块中的绝对最大值称为块最大值(Block Max, BM),而其他值则称为非块最大值(Non-Block Max, NBM)。
我们观察到,BM显著大于NBM的块(例如,上方的样本块)——主要是由于MX块中存在离群值——往往会因两个原因表现出高量化误差。首先,由于MXFP4只为尾数分配了1位,因此大幅度的BM在量化时容易与其原始值产生巨大偏差(➊)。其次,BM的指数决定了整个块的缩放因子,如公式1所示。因此,当BM很大时,共享缩放因子也会变大,这迫使其他较小的元素(即NBM)由于共享缩放而表示得不那么精确。例如,大多数NBM在被共享缩放因子除后被量化为零(➋)。相比之下,没有离群值的块(例如,下方的样本块)自然具有相对较低的量化误差。
图5显示了由量化误差最大的元素或BM元素对MSE(均方误差)的贡献。我们可以看到,在每个MX块中更准确地表示BM元素可以显著减少一部分量化误差。虽然理论上MX块中的32个元素中的任何一个都可能成为该块误差最大的元素,但在每个块中识别误差最大的元素会增加计算复杂性,而带来的好处不大,因为BM元素通常是最大的贡献者。
总而言之,虽然理想情况是精确表示块中的所有32个元素,但由于尾数位的数量有限,这在MXFP4中可能不可行。相反,我们的分析表明,在使用低比特MX格式时,仅专注于更好地表示BM元素就可以显著帮助提高模型性能。
A2 方法细节
4 MX+: 增强MX格式
正如第3.2节所讨论的,MX块中的最大幅值在块内元素中通常经历最高的量化误差,这在MX块包含离群值时会严重损害模型性能。在本节中,我们提出了MX+,一种为无缝集成到MX格式中而设计的成本效益高的扩展,用于低比特LLM服务。
4.1 MX+设计
我们的MX+设计围绕三个关键考虑。首先,它不应干扰MX的设计目标;该扩展需要完全在转换内核或硬件单元内部进行管理,以便与各种框架无缝集成,而无需终端用户的额外努力。其次,该扩展需要有效处理块中的离群值以减轻量化误差,同时与MX规范 [47] 保持一致。第三,该扩展不应在存储和运行时延迟方面引入大的开销。
MX+建立在两个关键洞察之上。首先,每个MX块中的BM元素不需要存储其自身的指数,因为只要将低于阈值的极小幅值数刷新为零,它的指数总是被设置为元素数据类型可表示的最大指数。这使我们能够安全地将指数字段重新用作扩展的尾数,以更精确地表示BM。其次,在从更高精度转换为MX以计算共享缩放因子的过程中,BM元素也自然地被识别出来。因此,识别每个MX块中的BM元素不需要额外的计算。
图6展示了使用图4中呈现的样本块的MXFP4和MXFP4+的二进制编码示例。如图所示,在MXFP4中,BM元素的指数字段总是设置为可表示的最大值(即$11_2$)。MXFP4+将这些位重新用于存储额外的尾数,从而提供对原始BM值更准确的表示。请注意,MX+不改变共享缩放因子。
与先前的工作 [52] 类似,我们将极小幅度的值刷新为零,以简化转换并启用MX+扩展。具体来说,如果BM的指数($\lfloor \log_2(\text{BM}) \rfloor$)小于或等于$-127 + e_{\text{max}}$,我们将块中的所有元素设置为零。这是因为,在这种情况下,共享指数会被钳位在其下限-127,这导致元素数据类型的指数字段被设置为小于$e_{\text{max}}$。为了表示这种情况,我们通过保留一个特殊值来扩展共享指数编码:一个偏置为零的共享指数表示块中的所有元素都为零。
4.2 MX+的数据布局
图7展示了MX+的三种可能类型的数据布局:MXFP4+、MXFP6+和MXFP8+,它们分别是MXFP4、MXFP6(E2M3)和MXFP8(E4M3)的扩展。每个MX块被分配额外的8位,其中5位用于存储块内BM元素的索引。剩下的3位被保留,可用于进一步的优化或支持未来定义块大小不为32个元素的MX规范。
NBM值被转换为传统的MX元素数据类型,如E2M1、E2M3和E4M3,而BM值则用更多的尾数位存储,如E0M3、E0M5和E0M7。我们不显式存储BM的指数,因为它将永远是给定元素数据类型的最大值(即对于E2M1和E2M3是2;对于E4M3是8),如公式1所示。因此,虽然使用与NBM相同的比特宽度,但BM实际上被表示为MXFP4+、MXFP6+和MXFP8+的E2M3、E2M5和E4M7。
请注意,由于所有元素使用相同的比特宽度,MX+不会导致非对齐的内存访问。这种设计在计算和内存成本方面也只产生微不足道的开销,因为BM在转换为MX格式的过程中已经被识别出来。额外的比特仅使平均比特宽度增加了0.25(例如,对于MXFP4,从4.25增加到4.5)。与共享缩放因子 [47] 类似,索引元数据不需要与元素数据或共享缩放因子连续存储。它也可以为重复值进行压缩或剪枝。
4.3 保留位的潜在用途
虽然MX+通过更精确地表示BM元素大大减少了块级量化误差,但NBM元素也对量化误差有贡献。如前所述,由于共享缩放因子由BM决定,NBM的表示可能比它们使用自己的缩放因子时更不精确。为了展示利用保留位的一个例子,我们还考虑了MX+的一个变体,称为MX++,并评估了其准确性。
MX++通过利用保留位将NBM的共享缩放因子与BM的共享缩放因子解耦,通常使NBM能够映射到比MX+更精细的量化网格。具体来说,NBM使用一个小于或等于BM共享缩放因子的共享缩放因子,它们共享指数之间的差异被编码在图7中的三个保留位中。然而,直接将公式1中的共享指数计算应用于NBM可能会增加量化误差,因为NBM元素在缩放后可能会饱和到元素数据类型的最大幅值。因此,我们定义了NBM的最小可行共享指数 $e$ 以避免饱和,如下所示:
其中$max_2$识别给定MX块中的第二大指数。如果没有1的偏移量,前两项将表示没有BM的MX块的共享指数(公式1),这可能会引入额外的误差。考虑图6中的示例元素。如果没有偏移量,$e$ 等于-3($= -1 - 2$,其中-1是0.99的指数),并且值0.99被缩放到7.92($= 0.99 \div 2^{-3}$)并饱和到MXFP4中可表示的最大值6.0。然而,有了1的偏移量,该值被缩放到3.96并保持在可表示的范围内。
NBM的最终共享指数是通过应用裁剪函数$clip(e, \{e_{min},e_{max}\})$来确定的:$shared\_exp_{new} = clip (e, \{shared\_exp - 7, shared\_exp\})$,其中 $shared\_exp$ 和 $shared\_exp_{new}$ 分别表示MX++中BM和NBM的共享指数。下限确保了与BM共享指数($shared\_exp$)的差异能用3位表示。上限解决了BM和最大NBM的指数相同,因此由于偏移量导致 $e$ 超过 $shared\_exp$ 的情况。回到之前的例子,$shared\_exp_{new}$为-2使得NBM值-0.39能够缩放到-1.56并映射到-1.5,而之前它被量化为零,共享指数为1。
5 在GPU上对MX+进行软件集成
MX格式正通过软件和硬件支持日益集成到现有的DNN加速系统中。在缺少用于MX中低精度元素数据类型的计算单元的系统中,MX块通常被转换为硬件支持的更高精度格式[【21,Intel Unleashes Enterprise AI with Gaudi 3, AI Open Systems Strategy and New Customer Wins,2024,https://www.intc.com/news-events/press-releases/detail/1689/intel-unleashes-enterprise-ai-with-gaudi-3-ai-open-systems】,【57 ,Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations,2019,MAPL】,【61,Ladder: Enabling Efficient Low-Precision Deep Learning Computing through Hardware-aware Tensor Transformation,2024,OSDI】]。例如,存储在MXFP4中的数据可以通过软件支持转换为FP16,以便在Intel Granite Rapids上进行计算[21]。在这种情况下,MX+也可以通过对转换内核进行微小修改来轻松支持,如下所示:
其中$i$和$b$分别表示块内的元素索引和给定元素数据类型的指数偏差。此外,$s_i$,$m_i$和$e_i$分别表示输入元素$i$的符号、尾数和私有指数。请注意,尾数$m_i$在MX+中的BM和NBM元素之间的捕获方式不同。例如,在MXFP4+中,BM元素有三个有效位的尾数,而NBM元素只有一个位。
当系统配备了原生支持MX兼容格式的计算单元时——例如最近发布的NVIDIA Blackwell GPU中的Tensor Cores [41]——MX块中的元素数据可以直接在计算单元内处理,无需格式转换。在本节中,我们提出了一种将MX+扩展集成到支持MX精度格式的GPU系统中的方法,而无需任何硬件修改。为清晰起见,我们专注于激活以MXFP4+表示,权重以MXFP4表示的场景,这种配置实现了与两者都使用MXFP4+时相当的模型性能,我们将在第7.2节中讨论。然而,所提出的方法也适用于两个操作数都使用MXFP4+或其他格式(如MXFP6+)的情况。
5.1 在GPU上处理BM的挑战
NVIDIA的Tensor Core为预定义的一组矩阵块形状($m \times n \times k$)以及特定的输入和输出数据类型执行矩阵乘法累加(MMA)操作($D = A \times B + C$)。Tensor Core MMA操作通过PTX指令(如wmma.mma、mma和wgmma.mma_async)暴露给程序员,这些指令被翻译成设备特定的机器代码(即SASS指令),如HMMA(半精度)和IMMA(整数)。我们在本节中基于以下MXFP4 mma PTX指令进行讨论,不失一般性[46]。
该指令操作于维度分别为16×64和64×8的矩阵A和B,以及维度为16×8的矩阵C和D。矩阵$E_A$和$E_B$的维度分别为16×2和2×8,它们分别存储矩阵A和B的每行和每列两个MX块的共享指数。
当一个warp执行MMA操作的机器指令时,该warp内的所有32个线程会共同为特定的块形状执行矩阵乘法。为了实现这一点,warp中的每个线程在其寄存器中持有一部分操作数矩阵的元素,称为片段(fragment)。
图8说明了对于4位mma.m16n8k64 PTX指令,矩阵A(MXFP4+激活)、矩阵B(MXFP4权重)以及结果矩阵D(FP32)的元素如何在warp中的线程间分布。对于矩阵A,每个线程在四个32位寄存器中持有一个片段,每个寄存器包含八个4位元素,而对于矩阵B,每个线程在两个32位寄存器中持有一个片段,每个寄存器也包含八个4位元素。对于矩阵D,每个线程在四个32位寄存器中持有四个32位元素。请注意,矩阵A的每一行对应两个MX+块,而矩阵B的每一列代表两个MX块。
在MXFP4+中,BM被有效地表示为E2M3,其私有指数为$e_{max}$,而Tensor Core中的FP4计算单元操作的是E2M1。因此,我们不能简单地对输入矩阵A和B执行MMA操作。为了解决这个问题,我们将每个MXFP4+块中的BM值分解为两个值的和,$BM_H$和$BM_L$,如下所示:
其中$BM_H = (-1)^s \times 2^{e_{max}} \times um[3:2]$,
其中$um[3:0]$表示一个尾数表示,其前导1显式存储在$um[3]$中,如同x86 80位扩展精度格式[20]。如公式3所示,$BM_H$和$BM_L$实际上是E2M1,可以存储在FP4中。因此,我们可以通过以下步骤处理矩阵A的BM元素。
* 将BM拆分为$BM_H$和$BM_L$。
* 用$BM_L$替换BM并执行MMA操作。
* 将$BM_H$与矩阵B中相应的元素相乘,并将结果累加到矩阵D中。
请注意,除了MMA操作之外,还需要额外的计算来获得矩阵D的正确输出(即针对$BM_H$的第三步)。解决这个问题的一个可能方法是利用CUDA核心进行FMA操作,而Tensor Core则为第二步执行MMA操作。然而,我们观察到,在RTX 5090 GPU上,与仅使用MX块的矩阵乘法相比,这种方法导致使用MX+和MX块的整体矩阵计算速度下降超过5倍。这是因为每个FP4元素必须转换为更高精度(例如BF16或FP32)才能在CUDA核心中执行FMA。此外,这还需要每个线程通过线程间通信(例如,warp shuffling)从其他线程获取数据。例如,在图8中,线程0(T0)需要从线程1和2(T1和T2)获取矩阵A中的$BM_H$s,并从线程1、2、5和6(T1、T2、T5和T6)获取矩阵B中匹配的操作数,以将乘法结果累加到矩阵D的前两个元素中。
5.2 使用未充分利用的Tensor Cores
为了避免昂贵的转换并减少线程间通信,我们转而为$BM_H$执行一个额外的MMA操作,同时重用每个线程的寄存器。我们观察到,这种方法在计算单元未被充分利用的解码阶段能维持性能,而在预填充阶段(第7.3节)推理时间仅有适度增加。
具体来说,每个线程首先加载矩阵片段和BM索引,然后使用其线程ID检查自己是否持有BM元素。如果一个线程发现自己持有BM,它就用$BM_L$替换BM。例如,在图8中,前四个线程(T0-T3)为矩阵A的第一个MX+块加载相同的BM索引,即8,并将其线程ID与$\lfloor \frac{\text{BM index}}{8} \rfloor$进行比较。T1识别到匹配并用$BM_L$替换BM。这个过程在持有矩阵A片段的四个寄存器中重复进行。
为了使用$BM_H$及其匹配的操作数执行一个额外的MMA操作,每个线程需要一个单独的寄存器中的矩阵A片段,其中只包含$BM_H$值,所有其他元素都设置为零。为了实现这一点,我们首先将每个线程分配给一个唯一的、排他的MX+块,从中提取相应的$BM_H$。例如,在图8所示的16×64矩阵A中,它包含32个MX+块,一个warp中的32个线程每个都处理一个不同的块并检索其$BM_H$值。然后,线程在寄存器中准备它们的矩阵A片段以进行额外的MMA操作。那些需要$BM_H$值的线程从提取线程中检索它们,并将它们放置在片段中相应的BM位置。
算法1展示了执行MXFP4+和MXFP4矩阵乘法的过程(MMALoop)。ReplaceBM(第9行)识别a中的BM并用$BM_L$替换它。MakeFragment(第11行)从BM中提取$BM_H$并将其存储在aBM的BM位置。请注意,这个过程的开销被分摊到多个for循环迭代中。最后,执行一系列MMA操作,包括一个针对$BM_H$的额外MMA(mma.sp.m16n8k128)(第21行),同时重用持有矩阵B和D片段的寄存器。我们执行一个稀疏MMA操作,其速度是密集MMA的两倍,因为除了$BM_H$外,矩阵A中的所有元素都为零。
6 MX+的架构支持
第5.2节介绍的MX+软件集成避免了转换并减少了线程间通信开销。然而,与MX格式相比,它需要一个额外的MMA操作。在本节中,我们探讨了将MX+硬件集成到支持MX格式的GPU系统中的方法。
6.1 GPU集成概述
我们提出了一种硬件设计,该设计支持MX+扩展,而无需对Tensor Cores中的点积引擎(DPE)进行侵入式更改。在MMA操作期间,Tensor Cores从一个warp中的不同线程收集用于矩阵乘法的输入操作数。我们在DPE输入端捕获BM及其匹配的操作数,然后使用专用的低精度标量计算单元执行与BM相关的操作。这将所有修改都限制在DPE中核心点积流水线之外的区域。由于DPE执行点积时只需要计算少数与BM相关的操作数,因此延迟和面积开销都保持在可忽略的水平。下一节将详细描述这种方法。
6.2 Tensor Core集成
图9展示了带有MX+硬件集成的Tensor Core整体设计。我们的基线Tensor Core架构遵循先前工作 [50] 中的设计,不同之处在于每个warp在一个包含32个DPE的单个Tensor Core上执行。四个线程组成一个线程组,使用四个DPE,两个线程组结合形成一个八位组(octet)。一个warp由四个八位组组成,线程协同将操作数矩阵加载到中间缓冲区。
根据我们的基准测试,每个Tensor Core每16个周期完成一个FP4 mma.m16n8k64操作,这也可以从RTX 5090规范 [45] 中推断出来。由于一个warp中的每个线程在MMA执行期间为八对MXFP4块计算点积以产生四个输出元素(参见图8),我们将每个DPE配置为每两个周期处理一对MXFP4块;即,每个DPE每周期处理16个FP4输入对。每对MXFP6或MXFP8块每四个周期计算一次,因为FP8的吞吐量是FP4的一半,而FP6的吞torch量与FP8相匹配。当前正在处理的块的BM索引($A_{BMidx}$和$B_{BMidx}$)会相应地提供。
硬件扩展。DPE通过三个主要组件进行扩展:1)BM检测器,2)转发与交换单元(FSU),以及3)BM计算单元(BCU)。当MX+块和BM索引被送入DPE时,BM检测器检查BM索引($A_{BMidx}$和$B_{BMidx}$)并通过向相应的FSU发送1位$BMA$和$BMB$信号来激活它们。每个FSU由几个多路复用器和三态缓冲器组成,并与连接到BCU的数据通路共享。当BM(BMA,BMB)信号被设置时,FSU将BM输入及其匹配的操作数引导到BCU,同时向相应的DPE输入转发零。这确保了这些输入被排除在点积流水线的计算之外。为了支持FP6和FP8,我们可以配置FSU,使得偶数位置($2i$,其中 $i=0, 1, ..., 7$)的FSU共享一个数据通路,而奇数位置($2i+1$)的FSU共享另一个。这使得来自相邻FSU的4位输入也能被路由到BCU。然后,转发的输入在BCU中进行处理,如下所述。
Tensor Core内的BM计算。如图9(c)所示,BCU接收BM及其匹配的操作数以及BM索引作为输入。然后它执行以下计算:
其中$A_{BM}$和$B_{BM}$是矩阵A和B的BM,而$B_{NBM}$和$A_{NBM}$是它们匹配的操作数。第一和第二乘法项在MX++中根据$ \delta_A $和$ \delta_B $有条件地左移,这两个值是矩阵A和B在MX和MX++之间共享指数的差异。这些移位被编码在BM索引的保留3位中。请注意,这些操作完成得比DPE快,后者使用加法树执行元素级乘法和多级加法,并且不会导致流水线停顿;即,BM计算不影响MMA指令吞吐量。$Output_{BM}$然后被加到加法树的输出上,之后进行归一化并转换为FP32。
当A和B的BM索引相同时,我们简单地将$A_{NBM}$与$B_{BM}$交换,并将$B_{NBM}$设置为零,从而有效地只计算公式中两个相同项中的一个。我们设计的乘法器支持足够高的精度,以处理被乘数和乘数都是BM的情况。
SASS指令扩展。图10(a)显示了执行MXFP4 MMA操作的SASS指令,对应于第5.1节中描述的MXFP4 mma PTX指令。MMA指令中的寄存器标识符代表多个连续的寄存器[50],只有最低的寄存器标识符被编码在指令中。例如,第一个OMMA指令中的R12代表一个由四个寄存器组成的序列:<R12, R13, R14, R15>。请注意,包含共享指数的R0和R3是单个寄存器,而不是寄存器序列。
图10(b)展示了我们提出的SASS指令。MMA指令通过一个BM控制标志进行扩展,该标志指示输入操作数是否以MX+表示。我们通过为附加标志扩展1位控制信息;SASS指令包含未使用的位[22],我们通过nvdisasm [43] 对Blackwell指令也观察到了类似情况。我们还扩展了MMA指令以接受两个额外的源寄存器,每个寄存器包含矩阵A和B的两个8位BM索引,遵循共享指数的寄存器布局。
对于BM索引的指令编码,我们遵循稀疏MMA指令用于MX格式的方案。这些指令隐式地将持有有序元数据的寄存器与矩阵A共享指数的寄存器一起编码。这两个寄存器配对形成一个序列,指令中编码一个单一的寄存器标识符。我们采用相同的方法,将BM索引寄存器与共享指数寄存器配对。如图10(b)所示,分别包含矩阵A和B的BM索引的R32和R48被复制到寄存器R1和R4。在随后的OMMA指令中,R0和R3隐式地代表寄存器对<R0, R1>和<R3, R4>。
A4 实验环境
算法实现
* 在MX PyTorch仿真库 [39] 的CUDA扩展之上实现了MX+。
* 使用Hugging Face [62] 的预训练模型进行模型性能评估。
* 遵循先前工作 [52],将MX和MX+格式应用于LLM推理期间所有涉及点积操作的张量,包括语言模型头和KV缓存中的张量。
* 对于归一化等向量操作,使用BF16。
模型与工作负载
* 模型: 评估了多种大型语言模型:OPT-66B [70], Llama-3.1 (8B和70B) [13], Mistral-7B-v0.3 [23], Phi-4-14B [1], Qwen-2.5-14B-Instruct [64]。
* 评估任务:
* 准确率: 使用与先前工作 [52] 相同的lm-evaluation-harness任务来衡量任务特定准确率(%)。
* 困惑度: 在WikiText-2 [38]和C4 [10]数据集上评估语言建模性能。
* 推理时间: 使用不同大小的Llama-2模型 [58] 来测量MX+集成下的推理时间。
硬件配置
- 软件集成场景1 (转换计算): 使用NVIDIA RTX A6000 GPU [40],该GPU缺乏原生MX支持。
- 软件集成场景2 (直接计算): 使用NVIDIA RTX 5090 GPU [45],该GPU提供原生MX格式硬件支持。
- 硬件集成仿真: 使用扩展的AccelSim [27]进行性能评估,配置类似于NVIDIA RTX 5090 GPU。RTL设计使用Synopsys Design Compiler在商用28nm工艺节点上进行综合。
软件配置
* 量化方法: 采用训练后量化(PTQ),具体为“直接转换”(direct-cast)推理,即不进行任何重新训练或微调,直接将预训练的BF16模型转换为MX或MX+格式进行评估。
* 软件库:
* 场景1: 扩展Triton编译器 [57] 以支持MX+的转换。
* 场景2: 使用CUTLASS库 [56] 实现算法,并将其矩阵乘法核集成到vLLM [31]中。
* 仿真: 使用CUTLASS库生成矩阵乘法轨迹。
A4 实验结果
7.2 语言模型性能
任务准确率 (lm-evaluation-harness)
- 总体趋势: 如表2所示,MX+在其对应的MX格式上均实现了准确率提升。MXFP8+和MXFP6+的准确率提升分别高达10.00和4.00个百分点。
- 低比特优势: MXFP4和MXFP4+之间的准确率差异尤为显著(提升高达+42.15%,不包括MXFP4无法工作的OPT-66B)。
- 激活值量化的重要性: 即使仅对激活值使用MXFP4+(A-MXFP4+),而权重仍使用MXFP4,其性能也远超完全使用MXFP4的情况。这表明在低精度格式中,表示激活值离群值极具挑战性,而MX+通过为BM提供更高精度有效解决了此问题。
- MX++的效果: 在MXFP4+的基础上,MXFP4++通过更精确地表示NBM,进一步将准确率提高了多达+4.63%。
语言模型困惑度
* 结果: 与准确率结果一致,表3显示,在不同序列长度和数据集上,MX+和MX++的困惑度始终低于原始MX格式。
7.3 MX+软件集成性能
场景1:计算前转换
* 结果: 表4显示了使用BF16激活和MXFP4+(或MXFP4++)权重的矩阵乘法执行时间,该时间已对MXFP4权重的情况进行归一化。结果表明,在激活值较小(数据复用率低)时,处理BM的开销更为明显。在数据复用率高的情况下,BF16 MMA操作主导了总时间,从而摊销了转换开销。在两种情况下,转换过程中处理BM所需的额外开销相对于MX都非常小。
表4: BF16激活与MXFP4+或MXFP4++权重的矩阵乘法时间,相对于MXFP4权重情况进行归一化。
场景2:直接计算
* 执行时间分解: 图11(a)展示了Llama-2-13B在prefill和decode阶段的执行时间。A-MXFP4+的性能接近MXFP4。由于主导执行时间的decode阶段是内存密集型的,A-MXFP4+中额外的MMA操作仅带来可忽略的性能开销(6.71%)。在prefill阶段,A-MXFP4+表现出中等程度的减速(1.54倍),该阶段占总执行时间的18.78%。
* 不同输出长度下的性能: 图11(b)展示了不同输出token数下的执行时间,相对于MXFP4进行了归一化。A-MXFP4+最多导致1.13倍的减速,而MXFP8则高达1.85倍。随着输出token数的增加,decode阶段占比更大,MXFP4和A-MXFP4+之间的性能差距随之缩小。
7.4 MX+硬件集成
性能
* 结果: 图12显示了在prefill阶段,采用硬件集成的MXFP4+相对于MXFP4的归一化执行时间。总体而言,MXFP4+平均比MXFP4慢0.38%。这是因为BCU的计算不影响MMA指令的吞吐量,额外的寄存器文件访问和增加的指令延迟对性能的影响微乎其微。
面积与功耗
* 结果: 表5列出了每个Tensor Core为支持MX+而增加的组件的面积和功耗。在32个DPE中,我们增加了16个FSU、一个BM检测器和一个计算单元。该设计的面积为0.020mm²,功耗为12.11mW。在28nm工艺节点下,这个开销非常小。
表5: 每个Tensor Core支持MX+的面积和功耗。
端到端加速
* 结果: 图13展示了在vLLM中相对于BF16的加速比以及Llama-2-13B在lm-eval-harness任务上的平均准确率。
* 硬件支持: MXFP4+在prefill和decode主导的场景中分别实现了3.34倍和2.73倍的加速,与MXFP4相当,但准确率提高了20.17%。MXFP4++的性能也很有竞争力,仅比MXFP4慢1.00%-1.04%。
* 软件支持: 在长输出场景下,A-MXFP4+实现了接近MXFP4的加速比,同时准确率高出17.46%。
量化时间
- 结果: 表6显示了不同输入token长度下的量化时间。MXFP4+的量化时间与MXFP4相似,而MXFP4++略有增加。由于量化仅占推理时间的一小部分,此开销对整体性能影响可忽略不计。
表6: 相对于MXFP4归一化的总量化时间。
A7 补充细节
8.1 与其他方案的比较
在本节中,我们比较了MX+与其他纯算法或算法-硬件量化方案的模型性能。为了公平比较,我们仅量化权重和激活之间的矩阵乘法,不包括语言模型头——这是所有方案中量化操作的交集。
SmoothQuant [63] 重新缩放激活通道,而QuaRot [3] 使用正交矩阵旋转激活以减小整体幅度。Atom [73] 重排通道并用更高精度(INT8)量化带有离群值的通道。表7显示,SmoothQuant(SMQ)在4位精度下表现不佳,正如多项研究[【3,QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs,2024,NeurIPS】,【32,Tender: Accelerating Large Language Models via Tensor Decomposition and Runtime Requantization,2024,ISCA】,【35,LLM-FP4: 4-Bit Floating-Point Quantized Transformers,2023,EMNLP】]所讨论的。我们观察到QuaRot并不能完全消除离群值,并且性能比MXFP4+差;离群值的幅度在旋转后并未减小(例如,Llama-3.1中的下投影层)。与QuaRot不同,MX+专注于精确表示重要的离群值,因为MX的细粒度分组已经限制了离群值对其他值的影响。Atom显示出可比的模型性能,因为它用8位表示离群值,但仍然比MX+差。
表7: 在WikiText-2上通过直接转换推理的困惑度。
ANT [16] 和OliVe [15] 使用自定义格式,而Tender [32] 对相似范围的通道进行分组并使用标准的INT4。如表所示,由于粗糙的通道分组或张量级分组,它们在4位精度下性能受损。为了仅用于准确性比较,我们将这些方案扩展以支持更细粒度的分组,称为MX-ANT、MX-OliVe和MX-Tender,尽管这会显著增加它们的运行时开销。MX-Tender在运行时对每两行的通道进行分组。MX-ANT和MX-OliVe支持大小为32的组级量化。两者都自适应地为权重选择每组数据类型,为激活选择每张量数据类型。所有方案都使用在运行时按组计算的浮点缩放因子。尽管如此,MX+仍然表现出更好的性能,证明了用高精度表示BM的有效性。
8.2 MX+的更广泛适用性
仅权重量化。虽然MX+主要针对我们希望对权重和激活张量都使用低比特精度的场景,但它也为以权重量化为重点的场景提供了好处。表8显示了当权重采用4位数据格式,而激活使用AWQ [34] 的BF16或MXFP8时的困惑度。AWQ是一种仅权重量化方法,它将重要的权重通道缩放到更大的幅度以在低比特量化下保护它们。虽然直接将MXFP4与AWQ一起使用会降低模型性能,但MX+可以与AWQ协同工作。这是因为放大重要通道使得更多重要的权重元素被识别为BM。因此,与原始AWQ(权重INT4)相比,模型性能得到进一步提升。当使用MXFP8激活和MXFP4权重时,权重的精确表示可能比激活更关键,因为它使用的比特数只有一半。在这种设置下使用MXFP4+可以显著提高模型性能,如表所示。
表8: 在BF16激活与AWQ (A16W4) 和MXFP8激活 (A8W4) 下,不同权重格式的直接转换困惑度。
其他DNN工作负载。我们还在Vision Transformer [59] 和CNN模型 [18] 上评估了MX+在ImageNet数据集 [53] 上的图像分类任务。表9显示,在直接转换推理下,MXFP4+的准确性高于MXFP4,对于DeiT和ResNet模型,改进分别高达+4.81%和+13.38%。正如先前工作[【11,Packqvit: Faster Sub-8-bit Vision Transformers via Full and Packed Quantization on the Mobile,2023,NeurIPS】,【35,LLM-FP4: 4-Bit Floating-Point Quantized Transformers,2023,EMNLP】,【54,DRQ: Dynamic Region-based Quantization for Deep Neural Network Acceleration,2020,ISCA】,【72,Improving Neural Network Quantization Without Retraining Using Outlier Channel Splitting,2019,ICML】]所讨论的,我们观察到这些模型中也存在激活离群值,并且通常散布在MX块中。MX+更精确地表示了这些离群值,从而提高了准确性。我们还进行了量化感知(QA)微调,并评估了MXFP4+在微调模型上的有效性。与直接转换推理相比,MXFP4和MXFP4+之间的准确性差距变小了,因为用于图像分类任务的微调模型即使使用MXFP4也能达到接近FP32基线的准确性。然而,对于更复杂的模型和更具挑战性的任务,如果微调的MXFP4模型无法达到FP32级别的准确性,MXFP4和MXFP4+之间的准确性差异可能会更加明显。
表9: ImageNet上的Top-1准确率 (%)。
MX+对非FP微缩放格式的适用性。除了三种MXFP变体外,OCP MX规范还定义了一种具有整数元素数据类型的额外MX兼容格式:MXINT8。尽管MXINT8在其元素数据类型中没有指数域,但为BM元素增加额外精度的方法同样可以应用于MXINT8。例如,MXINT8中的INT8编码使用一个符号位、一个整数位和六个小数位。在这种配置下,公式1中的$e_{max}$等于零,因为元素值总是小于2。共享指数简单地成为BM值的指数,而BM元素则以±1.xxxxxx的格式表示。这使我们有可能将整数位设为隐式,并将其用作BM元素的额外小数位。表10显示了将此方法应用于MXINT8和假设的MXINT4格式(一个符号位、一个整数位、两个小数位)的困惑度结果。对于MXINT8,将小数位从六位增加到七位几乎没有帮助。相比之下,MXINT4从额外的小数位中获益,类似于MXFP4+或MXFP6+。如果MXINT4成为具体的MX兼容格式的一部分,这个方向也可能值得探索。
表10: 在直接转换设置中,非FP微缩放格式在WikiText-2上的困惑度。
MX+对NVFP4的适用性。NVIDIA最近推出了NVFP4,这是一种与MXFP4相似的4位浮点格式。两种格式都使用FP4 (E2M1) 元素,每个块共享一个缩放因子。然而,NVFP4的不同之处在于它使用更小的块大小(16个元素)和E4M3缩放因子。表11展示了NVFP4的直接转换准确性结果。与表2中的结果相比,MXFP4+和MXFP4++的性能优于或与NVFP4相当。这是因为由于BM的额外精度,离群值在MX+中通常被更准确地表示。MX+扩展可以类似地应用于NVFP4,因为MXFP4和NVFP4在计算缩放因子时都将BM尽可能地映射到FP4中可表示的最大幅值[42]。与MX+类似,我们扩展了NVFP4中BM元素的尾数位,除非BM的幅值极小,以至于元素数据类型中的指数未被设置为最大值(即当共享缩放因子$scale_{E4M3} \le 00000010_2$时)。在这种情况下,我们对该块使用原始的NVFP4表示。请注意,通过额外的每张量软件缩放步骤,可以将值移至更大的幅度以进行每块缩放,从而减少这种情况的发生频率。扩展后的NVFP4,称为NVFP4+,每块增加4位来存储BM索引,这可以与其他块的BM索引打包以实现字节对齐。如表11所示,NVFP4+的准确性高于NVFP4。
表11: NVFP4和NVFP4+ (BM具有额外精度的NVFP4) 在lm-eval-harness任务上的直接转换推理准确率。
在收缩阵列变体中支持MX+。除了在GPU上执行矩阵乘法,也可以考虑使用固定功能的矩阵流水线,如TPU [25]。这些流水线通常实现权重固定或输出固定的收缩阵列设计[【25,In-Datacenter Performance Analysis of a Tensor Processing Unit,2017,ISCA】,【26,Intel Gaudi 3 AI Accelerator: Architected for Gen AI Training and Inference,2024,HCS】],其中每个处理单元(PE)每周期执行一个MAC操作。在这些流水线中支持MX+可以通过向数据通路添加FSU和BCU来完成,类似于GPU集成。例如,在一个代表性的32×32 MX兼容收缩阵列中,每个PE附加一个FSU,每列的PE共享一个BCU。在权重固定的数据流中,一列中的PE共同执行一个MX块对的点积。BCU位于收缩阵列下方,接收由FSU转发的BM值及其匹配的操作数,以及部分和。然后它计算与BM相关的操作数,将结果添加到部分和中,并将值转发到累加器。对于输出固定的数据流,过程类似。每个PE在32个周期内执行一个MX块对的点积,FSU收集与BM相关的操作数。在这些周期之后,操作数和部分和被转发到BCU。更新后的部分和随后被路由回每个PE的累加器所在位置。
8.3 处理块中的多个离群值
虽然我们保持所提算法的简洁性以最小化最终用户的开销并实现与各种框架的无缝集成,但当MX+被优化以捕获同块共存的离群值时,模型性能可以得到进一步提升。
离群值分析。我们对每个块中幅值最大的k个元素使用MXFP6表示,而其他元素保持MXFP4。图14显示了困惑度和激活张量中以MXFP6表示的离群值百分比。我们使用3σ规则[15]来识别离群值,并专注于激活,因为权重对模型性能的影响通常较小。结果表明,扩展MX+以存储额外的BM索引来追踪多达两个离群值可以带来一些增益,而表示更多的离群值则显示出递减的回报。这是因为大多数激活离群值在top-2时已经被更高精度表示。为了平衡复杂性和模型性能,我们选择通道重排作为MX+之上的一个可选优化,以显式地分离同一块中的离群值。当应用通道重排与MX+时,困惑度和百分比紧密跟随top-2的情况,如图14所示,我们将在本节剩余部分详细讨论。
通过通道重排分散离群值。如图4所示,激活离群值通常集中在通道粒度上。为了让更多的离群值被识别为BM,我们也可以通过通道级重排来显式地将离群值分散到不同块中。例如,我们首先根据离群值的数量对每个激活的通道进行排序。离群值最多的通道然后被放置在每32个(即块大小)通道中的一个。剩余的已排序通道被分成两半,我们将下半部分的通道按降序排列在剩余位置,然后以同样的方式排列上半部分的通道。表12显示了通道级重排的准确性结果,表示为Reorder。改进源于更精确的离群值表示。例如,在一个采样的查询矩阵中,经过重排后,含有离群值的块中具有多个离群值的块的百分比从22.52%下降到4.58%。对于每个任务,我们通过使用10%的样本对查询和键矩阵之间的每个通道的离群值计数进行平均,来预先确定它们的通道顺序。两个矩阵使用相同的通道顺序以保持数学正确性。重排与量化融合在一起,通过将每个量化输出存储在其重排后的通道地址,使得重排开销可以忽略不计。
表12: lm-eval-harness任务上的直接转换推理准确率。Reorder表示对查询和键矩阵应用了通道重排的MXFP4+。
A5 结论
服务大型语言模型(LLM)需要大量的计算和内存资源,而由行业领先公司开发的MX数据格式正日益被采用以缓解这些挑战。在这项工作中,我们研究了采用MX格式进行LLM推理的影响,并发现由于包含离群值的激活张量的量化,模型性能在超低精度下会显著下降。基于块内绝对值最大元素不需要在元素数据类型中存储其指数这一洞察,我们提出了MX+,一种对MX的非侵入式扩展,它将指数字段重新用作扩展的尾数。MX+在不增加复杂性的情况下,显著提高了各种精度下的模型性能,在较低比特宽度下增益更大。它还能在软件集成场景中实现直接部署,推理过程中的开销极小,而通过硬件支持,这一开销几乎可以被消除。
💬 评论讨论
欢迎在这里分享您的想法和见解!