MiniMax Sparse Attention

发表时间: 2026-06 · arXiv:2606.13392 (MiniMax)

文章标题:MiniMax 稀疏注意力
作者/机构:Xunhao Lai (MiniMax, 北京大学), Weiqi Xu (MiniMax), Yufeng Yang (MiniMax), Qiaorui Chen (NVIDIA), Yang Xu (MiniMax, 浙江大学), Lunbin Zeng (MiniMax, 华中科技大学), Xiaolong Li (MiniMax, 浙江大学), Haohai Sun (MiniMax), Haichao Zhu (MiniMax), Vito Zhang (MiniMax, 北京大学), Pengyu Zhao (MiniMax)


A1 主要贡献

本文旨在解决前沿大语言模型(LLM)中日益增长的超长上下文处理需求所带来的计算瓶颈。诸如智能体工作流、代码库级别的代码推理和持久化记忆等应用,要求模型能够同时处理数十万到数百万的词元(token),而标准 Softmax 注意力的二次方计算成本使得这在部署规模上变得不可行。

研究目标
本文旨在提出一种高效、可扩展且能实际加速的稀疏注意力机制,以在保持模型性能的同时,打破超长上下文处理的计算瓶颈。

核心创新点
为了应对这一挑战,本文提出了 MiniMax 稀疏注意力(MiniMax Sparse Attention, MSA),这是一种基于分组查询注意力(Grouped Query Attention, GQA)构建的块级稀疏注意力机制。其核心设计如图1所示:

图1 | MSA概览。索引分支(左)用一个轻量级头为整个因果上下文打分,并为每个查询和GQA组独立选择一个Top-k的关键值块集合I;本地块无论得分如何总被包含。主分支(右)仅对选定的块执行精确的块稀疏注意力,并产生层输出。训练期间,一个KL损失将索引分布与在选定块上的组平均主分支分布对齐,并且索引分支的梯度与主分支分离。
图1 | MSA概览。索引分支(左)用一个轻量级头为整个因果上下文打分,并为每个查询和GQA组独立选择一个Top-k的关键值块集合I;本地块无论得分如何总被包含。主分支(右)仅对选定的块执行精确的块稀疏注意力,并产生层输出。训练期间,一个KL损失将索引分布与在选定块上的组平均主分支分布对齐,并且索引分支的梯度与主分支分离。

MSA 的主要贡献如下:
* 提出 MSA 机制:一种极简、可扩展且可加速的块级稀疏注意力机制。它既支持从头开始训练,也支持从预训练的 GQA 检查点进行近乎无损的转换。其架构包含一个轻量级的索引分支(Index Branch)和一个主分支(Main Branch)。索引分支为每个 GQA 组独立选择 Top-k 的键值(KV)块,而主分支仅在这些选定的块上执行精确的注意力计算。
* 协同设计高效 GPU Kernel:为了将 MSA 的理论计算节省转化为实际的端到端加速,本文协同设计了高效的训练和推理 GPU 核(kernel)。这些核包括用于 Top-k 选择的无指数(exp-free)计算核,以及采用 KV-outer 稀疏注意力来提高张量核心(tensor-core)利用率的计算核,从而在块级粒度访问下实现高效执行。
* 大规模实验验证:在一个包含原生多模态训练的 109B 参数混合专家(MoE)模型上进行了广泛的消融实验。实验结果表明,MSA 的性能与 GQA相当,但在 1M 上下文长度下,将每个词元的注意力计算量减少了 28.4 倍。结合协同设计的核,MSA 在 H800 GPU 上实现了 14.2 倍的预填充(prefill)和 7.6 倍的解码(decoding)壁钟时间加速。


A3 背景知识

2.1. 因果注意力与 GQA

因果 Softmax 注意力的计算与成本:对于序列长度为 $N$、隐藏维度为 $d_{model}$、头维度为 $d_h$ 的模型,在每个查询位置 $i$ 和头 $h$,因果 Softmax 注意力的计算公式如下:

该公式的计算成本为 $\Theta(2N_{q}N_{kv}d_{h})$ FLOPs,与序列长度 $N$ 呈二次方关系。

分组查询注意力(GQA):GQA(Grouped-Query Attention)【1,GQA: Training generalized multi-query transformer models from multi-head checkpoints,2023,EMNLP】使用 $N_{q}$ 个查询头,但将键值头的数量减少到 $N_{kv}$ 个,并将 $g = N_{q}/N_{kv}$ 个相邻的查询头绑定到一个共享的键值头。因此,每个键值头定义了一个 GQA 组。

2.2. 作为两阶段过程的稀疏注意力

稀疏注意力的分解:一个稀疏注意力层将因果注意力分解为两个阶段:一个索引器(indexer)选择要关注的键(key),以及在选定的键上进行稀疏注意力计算。对于每个查询位置 $i$,其计算过程可以表示为:

其中,$\text{Index}_i$ 由 $\theta$ 参数化(对于固定规则的索引器,$\theta$ 为空;对于可训练的索引器,$\theta$ 是学习的),$I_i \subseteq \{1, . . . , i\}$ 表示选定的索引集,而 $\text{Attn}$ 表示仅限于该索引集的标准缩放点积 Softmax 注意力。我们将第一阶段称为索引分支(Index Branch),第二阶段称为主分支(Main Branch)。在多头注意力中,每个由位置 $i$ 和查询头 $h$ 指定的查询可以选择一个不同的键/值索引集,记为 $I^{(h)}_i$。

2.3. 基于 GQA 的块稀疏注意力

效率与粒度的权衡:逐头的词元级选择提供了最细的粒度,但这种细粒度的计算很难高效地映射到 GPU 的矩阵运算上。为了提高效率,基于 GQA 的稀疏注意力可以在每个 GQA 组内共享索引结果。设 $H_r$ 表示由第 $r$ 个键值头服务的 $g$ 个查询头,组共享的索引集可以写为:

块级选择:选择键/值块而不是单个词元,可以减少路由开销并使稀疏注意力更加规整。对于块大小为 $b_{kv}$,定义块划分如下:

对于查询位置 $i$ 和 GQA 组 $r$,集合 $I^{(r)}_i \subseteq \{1, . . . , N\}$ 表示选定的块索引集。然后,组 $r$ 中任何查询头的稀疏注意力输出都是在选定块中的因果可见词元上计算的,使用同一组的键值头。MSA 遵循这种基于 GQA 的块稀疏公式。


A2 方法细节

3. MSA

本文介绍了 MiniMax 稀疏注意力(MSA),一种基于 GQA 的双分支稀疏注意力机制,如图1所示。对于每个查询词元,一个轻量级的索引分支从因果上下文中选择一小组键块,然后主分支在这些块中的词元上计算 Softmax 注意力。索引分支仅在标准 GQA 的基础上增加了两个投影矩阵,在块粒度上操作,并为每个 GQA 组独立进行选择。

3.1. 架构

MSA 的两阶段实例化:MSA 在 GQA 组和块粒度上实例化了第2.2节中的两阶段稀疏注意力公式。对于每个查询词元,索引分支为每个 GQA 组选择 $k$ 个大小为 $b_{kv}$ 的键块,主分支仅关注选定块中的词元,其预算最多为 $k \cdot b_{kv}$。设 $X \in \mathbb{R}^{N \times d_{\text{model}}}$ 为输入隐藏状态。根据第2.1节,我们用 $N_q$ 和 $N_{kv}$ 分别表示查询头和键值头的数量,因此每个键值头服务于 $g = N_q / N_{kv}$ 个查询头。

索引分支(Index Branch):索引分支为每个 GQA 组引入一个索引查询头,并为所有组共享一个单一的索引键头:

对于查询词元 $i$ 和组 $r$,索引分支首先对可见的键词元进行评分,然后将这些分数聚合到块级别。使用第2.3节中定义的块划分 $B_1, . . . , B_N$:

这里 $r$ 索引 GQA 组,$j \le i$ 强制执行因果关系,没有可见词元的块被赋予 $-\infty$ 的分数。然后,索引分支选择得分最高的 $k$ 个块索引:

这里 $\text{TopK}(\cdot, k)$ 返回在 $s_{\text{idx},(r)}^{i, \cdot}$ 下 $k$ 个最大块的索引。我们总是包含包含位置 $i$ 的本地块,并且 $I^{(r)}_i$ 由组 $r$ 中的所有 $g$ 个查询头共享。

主分支(Main Branch):给定由索引分支选择的块索引集 $I^{(r)}_i$,主分支仅关注选定块中的因果可见词元。对于任何查询头 $h \in H_r$,它应用标准缩放点积注意力,仅限于这些词元,并使用与 GQA 组 $r$ 关联的键值头:

其中,$q^{(h)}_i$ 表示位置 $i$ 和查询头 $h$ 的查询向量,而 $K^{(r)}$ 和 $V^{(r)}$ 表示第 $r$ 个 GQA 组的键和值矩阵。符号 $K^{(r)}[I^{(r)}_i]$ 和 $V^{(r)}[I^{(r)}_i]$ 表示从选定块中收集因果可见的词元。块索引集 $I^{(r)}_i$ 由 $H_r$ 中的所有查询头共享,而每个头保留自己的查询投影。由于选定的块最多包含 $k \cdot b_{kv}$ 个因果可见词元,每个查询的注意力成本从 $\mathcal{O}(N)$ 减少到 $\mathcal{O}(k \cdot b_{kv})$,随着序列长度的增加而保持固定。

3.2. 训练

训练挑战与解决方案:公式7中的 Top-k 选择是不可微的,因此语言建模损失不能直接训练索引的查询/键投影矩阵 $W_{\text{idx}}^Q, W_{\text{idx}}^K$。因此,我们通过 KL 对齐损失来训练索引分支,并使用三种机制来稳定稀疏训练:梯度分离(Gradient Detach)、索引器预热(Indexer Warmup)和强制本地块(forced Local Block)。

KL 损失:KL 损失通过将其分数与主分支在选定词元上的分数相匹配,为索引分支提供了直接的学习信号。记 $I^{(r)}_{i, \text{tok}} = (\cup_{j \in I^{(r)}_i} B_j) \cap \{1, . . . , i\}$ 为由选定块索引引起的因果可见词元,对于每个查询位置 $i$ 和 GQA 组 $r$,我们在这个词元索引集上定义索引分支分布 $p_{\text{idx}}$ 和主分支教师分布 $\pi$:

其中 $s_{\text{idx}, (r)}^{i, j} = (Q_{\text{idx}})^{(r)}_i (K_{\text{idx}})_j^\top / \sqrt{d_{\text{idx}}}$ 是词元级索引分数,而 $s^{(\ell)}_{i, j} = q^{(\ell)}_i (K^{(r)}_j)^\top / \sqrt{d_h}$ 是查询头 $\ell \in H_r$ 的主分支分数。教师分布 $\pi$ 在概率层面平均了每个头的主分支分布。然后,索引器被训练来匹配 $\pi$,并在所有查询位置和 GQA 组上取平均:

其中 $N$ 是序列长度,教师分布 $\pi^{(r)}_{i, \cdot}$ 从梯度计算中分离。这个辅助损失将索引分布与主分支的注意力模式对齐,使得后续的块选择在语义上有意义。

梯度分离(Gradient Detach):为了将辅助目标与主干网络隔离,我们对索引分支的输入应用 stop-gradient:

公式9中的教师分布 $\pi$ 是分离的,因此 $L_{\text{KL}}$ 不会影响主分支的投影;公式11进一步阻止它通过 $X$ 到达主干网络。在此规则下,$L_{\text{KL}}$ 仅更新 $W_{\text{idx}}^Q$ 和 $W_{\text{idx}}^K$,使 KL 损失成为索引器的一个纯粹的对齐信号。

索引器预热(Indexer Warmup):我们使用一个两阶段的训练计划来初始化索引分支并避免早期的随机选择。在最初的几次迭代中,模型在两个分支中都运行全注意力,并用 $L_{\text{KL}}$ 训练新添加的索引投影。预热后,模型切换到稀疏注意力,并在 Top-k 选定的位置上计算 $L_{\text{KL}}$。当稀疏化一个预训练的全注意力检查点时,也使用相同的计划,这有助于在新增的索引投影控制主分支路由之前将其对齐。

本地块(Local Block):对于每个查询位置 $i$ 和 GQA 组 $r$,包含 $i$ 的本地块在训练和推理期间总是作为 $I^{(r)}_i$ 的一部分被选中。这种固定分配保留了一个块槽位,并将剩余的槽位留给索引分支选择,防止了忽略查询紧邻区域的退化选择。

完整的层级训练过程总结在算法1中。

3.3. 计算复杂度

FLOPs 对比:在相同的 $N_q, N_{kv}, d_h$ 和序列长度 $N$ 下,GQA 和 MSA 的因果注意力 FLOPs 分别为:

GQA 的主注意力路径随整个上下文长度扩展,而 MSA 使用固定的选择预算 $k \cdot b_{kv}$ 加上一个轻量级的索引计算。因此,当 $k \cdot b_{kv} \ll N$ 且 $d_{\text{idx}} \ll d_h$ 时,FLOPs 的差距会随着 $N$ 的增加而增大。

4. Kernel 设计

本节描述了我们稀疏预填充(prefill)实现中使用的 GPU 核,包括索引 TopK 核、KV-outer 稀疏注意力前向传播核以及稀疏 KL 损失反向传播核。

4.1. 索引与 TopK

无指数选择(Exp-free selection):为了高效地选择 Top-k 的 KV 块,索引模块直接对索引分数 $s$ 进行排序。由于 softmax 是保序的($s_i \le s_j \iff \text{softmax}(s)_i \le \text{softmax}(s)_j$),分数的相对顺序得以保留,从而使 Top-k 索引保持不变。因此,前向传播绕过了 softmax 的 max/exp/sum 步骤,直接将原始分数传递给选择过程。

每线程寄存器 Top-k:块大小 $b_{kv}$ 和选择大小 $k$ 与 Top-k 核协同设计:较大的 $b_{kv}$ 提高了注意力计算的算术强度(第4.2节),而在此 $b_{kv}$ 下较小的 $k$ 使得每行的候选块数 $N$ 和 $k$ 都低于通用 Top-k 核的最佳点。我们采用 $b_{kv} = 128, k = 16$。warp 中的32个 lane 各自以 1/32 的步幅流式处理输入行,并在共享内存中维护一个 $k$ 元素的最小堆。堆的根缓存在寄存器中,并使用延迟写入执行插入。最后,一个 $k$ 轮的 shuffle merge 合并32个局部 TopK 结果。共享内存布局将每个 lane 映射到一个固定的 bank,避免了冲突。

基准测试:我们在 H800 GPU 上,使用 fp32 输入和未排序输出,与 torch.topk 和 TileLang【67,Tilelang: A composable tiled programming model for ai systems,2025】的基数选择 Top-k 进行了比较;延迟是50次预热后迭代的中位数。表1显示,我们的专用核在所有测试设置中都是最快的,在部署设置 $k=16$ 时增益最大。


表1 | 对形状为 $(B, N)$ 的 fp32 输入进行 Top-k 计算的延迟(μs),行独立处理。部署设置使用 $b_{kv} = 128, k = 16$,作为参考,我们还报告了 $b_{kv} = 64, k = 32$ 的情况。所有实现都产生相同的索引集。

4.2. 稀疏注意力

迭代顺序选择:我们重新审视在查询和键/值长度相等的稀疏预填充下迭代顺序的选择。设 $N_q, N_{kv}, g = N_q/N_{kv}, d_h, N, b_{kv}, k$ 分别为查询头数、键值头数、GQA 比例、头维度、序列长度、KV 块大小和每个查询选择的块数。为简单起见,下面的 IO 估计假设元素为2字节(bfloat16大小的流量)。

KV-outer 实现:由于在实践中 $2 \cdot b_{kv} \gg d_h$,我们选择 KV-outer 迭代与 Q 收集来最大化算术强度。该核作为一个持久化网格在 (kv_block, kv_head) 瓦片(tile)上执行。对于每个瓦片,一个来自 TopK 选择的反向稀疏索引识别出相关的查询位置。这些查询通过 TMA 拷贝加载到共享内存中,每个查询词元一个,由一个 warp 的32个 lane 并行分派。

预调度瓦片分块(Pre-scheduled tile chunking):直接的“一个CTA一个瓦片”映射会被“沉没行”(sink rows)主导——即一个早期的 KV 块被几乎每个查询选中。为了解决这个问题,一个 GPU 调度核沿着查询维度将每个 KV 瓦片分割成最多约 $2 \cdot b_{kv}$ 个查询的块,将热门瓦片分散到多个共享相同 K/V 加载的 CTA 上。由于每个查询的 $O$ 部分现在由 $c$ 个 CTA 产生,调度器还为每个 (查询, 块) 对预分配一个在 Obuf 中的槽位 $s \in [0, c)$,与查询索引 $i$ 一起打包成一个32位句柄。这样,注意力核就可以在没有原子操作的情况下将其部分结果写入预分配的偏移量。合并核读取每个查询的槽位计数,以知道要合并多少个部分结果。

两阶段前向传播(Two-phase forward):KV-outer 的分割方式禁止了在线 softmax 归一化。因此,前向传播被分成两个核,由 HBM 缓冲区 Obuf(局部归一化的部分输出)和 LSEbuf(每个部分结果的 logsumexp)隔开。注意力核执行上述工作列表并将每个部分结果写入其预分配的槽位。合并核读取每个查询的有效槽位,计算全局 logsumexp,然后形成归一化的权重来组合部分结果,最终输出 $O[i, h]$ 和最终的 LSE。

查询拼接(Query concatenation):在 KV-outer 迭代中,一个 KV 瓦片通常只与少数几个到几十个查询位置相关联。逐个处理这些位置会使分数矩阵乘法(MMA)填充不足。然而,在 KV-outer 迭代下,为给定瓦片收集的所有位置共享相同的 KV 操作数。因此,该核将 $\lceil 128/g \rceil$ 个查询位置及其 $g$ 个关联的查询头(都在同一个 KV 头下)打包成一个 128 × 128 的分数 MMA,以提高利用率。

4.3. 稀疏 KL 损失

LSE 融合(LSE fusion):我们优化了 KL 散度计算,通过在主传播过程中直接将 $LSE_{\text{main}}$ 和 $LSE_{\text{idx}}$ 输出到全局内存,从而完全跳过 KL 损失的前向传播过程。此外,在索引分支计算期间,我们保存每个块的 LSE,并在 Top-k 块上执行归约以获得 $LSE_{\text{idx}}$。反向传播核然后直接将这些标量加载到 softmax 中,消除了冗余的前向计算。

动态负载均衡(Dynamic load balancing):在可变长度序列和数据依赖的稀疏性下,每个瓦片的工作量差异巨大。该核作为一个持久化网格运行,其中 CTA 通过全局原子计数器声明工作。每个瓦片沿着其收集的查询维度被划分为子瓦片,其数量与每个瓦片的查询数成比例,并受最小子瓦片粒度的限制,以分摊每个子瓦片的开销。


A4 实验环境


A4 实验结果

5.2. 训练动态

从头预训练(MSA-PT):如图2所示,MSA-PT与全注意力基线在3T词元的训练过程中,其语言模型损失曲线几乎无法区分,表明MSA相对于全注意力没有引入明显的优化退化。梯度范数曲线在整个训练过程中也保持在相同范围内,说明MSA并未导致异常的梯度波动或训练不稳定。

图2 | 实验模型的预训练动态。显示了全注意力和MSA-PT在3T训练词元上的LM损失和梯度范数。 (a)中的插图放大了最后50B词元窗口,两条LM损失曲线几乎重叠。

继续预训练(MSA-CPT):如图3所示,从全注意力检查点过渡到稀疏继续预训练时,索引器预热阶段在启用稀疏注意力之前迅速降低了KL损失。切换到稀疏CPT后,KL损失保持在低位。块召回率(Block recall)和分数召回率(Score recall)都保持在较高水平,表明索引器能可靠地恢复重要块,并且检索到的块占据了主分支绝大部分的注意力权重。

图3 | 稀疏继续预训练动态。(a) MSA-CPT期间的平均KL损失。实线段表示索引器预热,虚线段表示稀疏继续预训练;垂直虚线标记了两个阶段的切换。(b) 稀疏继续预训练期间MSA-CPT索引器的平均块召回率和分数召回率。

5.3. 主要结果

基准评估:如表2所示,两种稀疏模型(MSA-PT和MSA-CPT)在广泛的预训练评估中与全注意力基线相比具有竞争力,表明用MSA替换密集注意力不会显著降低模型的通用语言、推理、多模态或面向智能体的能力。
* MSA-PT 在许多数学、图像、视频和长上下文检索基准上取得了最好的结果,表明从头稀疏预训练可以使模型表示适应稀疏注意力模式。
* MSA-CPT 更为保守,保留了大部分全注意力检查点的行为,在大多数文本、代码和PPL评估上表现接近,是已有密集检查点时的实用转换路径。

表2 | 3T词元训练预算下的代表性评估结果。Full表示全注意力基线,MSA-PT表示从头稀疏预训练,MSA-CPT表示稀疏继续预训练。每行最佳结果加粗;PPL越低越好,其他越高越好。

长上下文扩展:如表3所示,对MSA-CPT模型进行约140B词元的长上下文训练后,在HELMET和RULER上的表现依然接近全注意力基线。这表明即使在每个查询和GQA组只关注2048个KV词元的极度紧张的注意力预算下,MSA也能保持长上下文能力。

表3 | MSA-CPT在HELMET和RULER上的长上下文扩展结果。Δ报告了MSA-CPT与全注意力基线之间的差异。“Overall”分数是各细分任务的平均值。所有指标越高越好。

5.4. 效率

理论与实际加速:在实验模型配置下,MSA的注意力预算固定为2048个词元。如图4所示,与GQA相比,MSA显著减少了每个词元的注意力FLOPs,在1M词元上下文时,FLOPs减少了28.4倍。
* 实际加速:测量的运行时加速也遵循相同的扩展趋势。尽管由于稀疏注意力的索引、选择、查询收集等开销以及不规则的内存访问模式,运行时加速小于理论FLOPs减少,但随着上下文长度增加,加速效果愈发明显。在1M上下文时,MSA在H800上实现了14.2倍的预填充(prefill)7.6倍的解码(decode)壁钟时间加速。


图4 | GQA与MSA在共享实验模型配置下的效率比较。左子图报告了理论上的每词元注意力FLOPs。中、右子图分别报告了预填充和解码的实测实现加速。所有测试均在64个查询头、4个键值头和128的头维度下进行。MSA使用$b_{kv}=128$和$k=16$,对应每个查询2048个词元选择预算。


A7 补充细节

6. 相关工作

高效注意力机制:长上下文效率催生了大量关于高效注意力的研究,主要分为两个方向:
1. 替代方案:用计算成本更低的线性或循环替代方案替换稠密的Softmax注意力。例如,线性注意力【8,Rethinking attention with performers,2021,ICLR】和状态空间模型如Mamba【19,Mamba: Linear-time sequence modeling with selective state spaces,2023】。混合架构【40,MiniMax-01: Scaling foundation models with lightning attention,2025a】则交错使用线性和全注意力块。
2. 限制感受野:保留Softmax注意力,但限制其感受野。固定模式注意力包括局部窗口、全局词元【4,Longformer: The long-document transformer,2020】和带滑动窗口的注意力汇(attention sinks)【70,Efficient streaming language models with attention sinks,2024b】。

自适应稀疏注意力:与固定模式不同,自适应稀疏注意力使关注的支持集依赖于输入。现有方法主要区别在于支持集的构建时机以及选择器是否作为模型一部分进行训练。
* 推理时稀疏化:在预训练的全注意力主干上操作,仅在服务时构建稀疏支持。例如,H2O【76,H2O: Heavyhitter oracle for efficient generative inference of large language models,2023】和SnapKV【28,SnapKV: LLM knows what you are looking for before generation,2024】在解码时修剪KV缓存。这些方法继承了全注意力的训练成本。
* 原生训练的稀疏注意力:在预训练期间训练索引器,是与MSA最接近的先前工作。例如,NSA【74,Native sparse attention: Hardware-aligned and natively trainable sparse attention,2025】、InfLLM-V2【77,Infllm-v2: Dense-sparse switchable attention for seamless short-to-long adaptation,2025a】、MoBA【33,Moba: Mixture of block attention for long-context llms,2025】和DSA【15,Deepseek-v3.2: Pushing the frontier of open large language models,2025】。MSA与这些方法的不同之处在于两个协同采用的轴心:每个GQA组的Top-k共享与块级选择相结合,这在保持KV读取连续性的同时,实现了多组块粒度的检索。

高效核(Kernels):高效的核对于将稀疏注意力的理论FLOP减少转化为壁钟时间加速至关重要。FlashAttention【12,FlashAttention: Fast and memory-efficient exact attention with IO-awareness,2022】和FlashAttention-2【11,FlashAttention-2: Faster attention with better parallelism and work partitioning,2024】引入了IO感知的瓦片式Softmax注意力。MSA的核重用了FlashAttention的算法骨架,但其循环顺序针对MSA产生的GQA原生、块粒度访问模式进行了调整。


A5 结论

本文介绍了MSA,一种与分组查询注意力(GQA)协同设计的稀疏注意力机制。该架构在标准GQA层上附加了一个轻量级的索引分支:每个GQA组通过一个块级点积索引器独立选择一小组键值块,主分支则仅在选定的块上执行Softmax注意力。索引分支是一个纯粹的选择器,通过KL对齐损失进行训练,并采用两阶段预热计划和对索引输入应用停止梯度(stop-gradient)的策略,将辅助损失限制在索引投影上。

在109B-MoE规模上,MSA在大多数预训练和智能体基准测试中保持了GQA全注意力基线的能力,同时在1M上下文中将每个词元的注意力计算量减少了28.4倍,这正是长上下文推理成为部署瓶颈的领域。

展望:MSA的核心决策——每个GQA组独立选择、块级粒度、以及用KL对齐目标训练的索引器——与当前大多数开源前沿模型共享的GQA主干兼容,因此该方法应能以很少的修改进行迁移。未来的两个自然方向是:1)通过更长的稀疏训练、推理时更大的选择预算或更丰富的索引器评分函数来弥合剩余的长上下文检索差距;2)将相同的纯选择器设计扩展到预训练之外的场景,如强化学习后训练和智能体部署,在这些场景中,长上下文成本是主要的操作限制。


A6 附录

A. 可视化

索引器选择模式:为了理解学习到的索引器选择了什么,图5可视化了每个头的索引分支在所有查询块和键块对上的选择概率。图中展示了来自早期层(第1层)和后期层(第18层)的四个头,对应四个不同的GQA组。跨层来看,学习到的稀疏模式恢复了稠密注意力的主要结构:所有头都在局部对角线上放置高概率,一致地选择“汇”列(sink column),并将剩余预算用于少数长程相对位置。同时,不同GQA组的非局部选择并不相同,它们关注不同的长程条纹,表明学习到的索引器捕获了特定于组的稀疏注意力模式,而不是收敛到单一的全局选择模式。


(a) 第1层,四个GQA组。每个组在共享的局部对角线和汇列之外,产生了不同的长程选择模式。


(b) 第18层,四个GQA组。长程选择锐化为每个组的几条条纹;四个组选择了明显不同的条纹。

注意力汇现象:我们进一步研究了MSA模型中的注意力汇(attention sink)现象。即使没有明确强制索引器选择第一个键值块,我们观察到学习到的索引分支在所有层和头上都自然地为初始块分配了很高的选择概率。图6显示了两个代表性层(第4层和第24层)的结果,每层采样八个头。在两层中,每个头都将相当一部分注意力权重导向第一个词元。这证实了即使在我们的稀疏注意力机制中,注意力的焦点也会自然出现,并且普遍存在于不同的头和层中。


图6 | 第4层和第24层中每个注意力头在第一个词元上的平均注意力分数。所有头都将显著部分的注意力分配给第一个词元,证实了跨头和层的普遍注意力汇效应。

B. 初步实验

本节介绍了在一个10B参数的试点模型上进行的小规模消融研究,旨在确定对稳定优化和强大下游性能至关重要的训练设计选择。

B.1. 实验设置:所有消融实验使用一个10B参数的试点Transformer,具有16层,架构与主论文模型类似。模型使用GQA(32个查询头,4个KV头),MoE(64个专家,top-4路由)。

B.2. 索引分支的梯度来源:为了给不可微的Top-k选择提供训练信号,我们研究了两种机制:1)让索引分支贡献一个额外的注意力输出(通过LM损失训练);2)使用辅助KL损失直接监督索引分支。实验配置包括:仅LM损失、仅KL损失、LM损失+KL损失。如图7所示,单独使用任一信号都有其弱点:仅LM损失在长上下文检索上表现不佳,仅KL损失会损害短上下文能力。LM损失+KL损失在两者之间取得了最佳平衡。


图7 | 试点设置中三种索引器训练信号相对于GQA基线的评估分数增量。正值表示优于基线,负值表示退化。

B.3. 将KL梯度限制在索引分支:默认情况下,KL梯度会流回主干网络,可能导致训练不稳定(梯度尖峰)或性能下降(自蒸馏效应)。我们通过在索引分支输入处停止KL梯度(即detach)来解决此问题。如图8和图9所示,这种分离策略避免了梯度尖峰,并消除了在通用基准上的性能退化,使训练更加稳定。


图8 | 在有无将KL梯度与主干网络分离的情况下的训练LM损失和梯度范数。分离将辅助损失限制在索引分支,避免了无分离时观察到的梯度尖峰。


图9 | 在有无将KL梯度与主干网络分离的情况下的通用基准分数。分离辅助损失减少了当KL梯度更新主干网络时观察到的通用能力退化。

B.4. 索引器预热:训练初期,主分支的注意力分布变化迅速(熵急剧下降,见图10),使得稀疏选择在初始化时很脆弱。我们采用了一个简短的索引器预热阶段:在此期间,主分支使用全注意力,而索引分支通过KL损失对全序列的主分支分布进行训练。预热后,再切换到Top-k稀疏选择。如图11所示,预热后的模型在短上下文性能和长上下文检索上都取得了更好的结果。


图10 | 早期稀疏训练期间主分支注意力分布的逐层熵。熵在最初几百步内迅速下降,然后部分恢复并稳定,这为索引器的短暂全注意力预热提供了动机。


图11 | 有无索引预热的MSA评估结果。在报告的训练范围内,索引预热提高了通用任务和长上下文检索的分数。

B.5. 可学习的注意力汇:我们测试了添加一个GPT-OSS风格的可学习注意力汇参数。如图12和图13所示,虽然这个可学习的汇吸收了一部分注意力,但并未完全消除第一个词元的汇行为,并且在下游任务的困惑度上没有带来明确或一致的改进。因此,最终配方中不包含可学习的注意力汇


图12 | 引入GPT-OSS风格汇参数后,可学习汇和第一个词元接收到的注意力。在某些头中,可学习汇吸收了大部分类汇注意力;在其他头中,第一个词元仍然是主导的汇,表明显式汇并未完全消除第一个词元的汇行为。


图13 | 在下游面向智能体评估中,有无学习注意力汇的困惑度比较。困惑度越低越好。添加可学习汇并未比默认MSA设计提供一致的优势。

B.6. 动态稀疏选择 vs. 滑动窗口:为了评估动态选择的价值,我们将MSA与一个FLOP匹配的滑动窗口基线进行比较。该基线使用固定的稀疏模式(第一个块+局部窗口)。如图14所示,在相同的稀疏选择预算下,MSA在下游智能体任务上的困惑度始终低于滑动窗口模型,表明内容相关的动态词元选择优于位置固定的稀疏模式


图14 | MSA与FLOP匹配的滑动窗口基线在下游面向智能体评估中的困惑度比较。较低的困惑度表示在相同稀疏选择预算下更好的建模性能。

C. 额外的消融研究

C.1. 块大小:我们研究了不同块大小($b_{kv}$)对模型性能和效率的影响。如表4所示,改变块大小对模型质量影响有限。PPL结果几乎不变,RULER分数也没有明显下降。这表明MSA可以使用较大的KV块来提高核效率,而质量损失有限

表4 | 不同键值块大小的困惑度和长上下文检索分数。困惑度越低越好,RULER分数越高越好。

C.2. 强制汇与局部选择:我们发现,即使移除强制选择第一个块(汇)和固定局部窗口的硬编码规则,训练后的模型仍然能自然地学习到这两种模式。如表5所示,移除这些强制选择对标准模型质量(推理、代码、PPL)和长上下文检索几乎没有影响。因此,最终配方中仅强制选择包含查询自身的那个不完整的本地块

表5 | 强制汇和局部窗口选择的消融。除非标记为↓,否则越高越好。

C.3. 索引分支的值头:初步实验表明,索引分支的值头有助于从零开始的稀疏训练。然而,在引入索引器预热后,我们进一步消融了这个值头。如表6所示,移除值头并未导致系统性的性能下降,两个变体在不同基准上各有优劣,但差异很小。这表明在有预热的情况下,值头并非关键,其主要作用是提供额外的早期训练信号。因此,出于效率考虑,最终设计去除了索引分支的值头

表6 | 索引分支值头的继续预训练消融。


方法细节中的引用汇总

本文的方法细节章节(主要为第2、3、4节)引用了以下关键文献,具体如下:

  1. 引用编号: [1]

    • 文献: Joshua Ainslie, et al. GQA: Training generalized multi-query transformer models from multi-head checkpoints. EMNLP, 2023.
    • 引用段落: 第2.1节 "因果注意力与 GQA"
    • 原文描述: 用于定义分组查询注意力(Grouped-Query Attention, GQA)的概念,即通过共享键值头来减少计算和内存开销。原文描述为“Grouped-Query Attention (Ainslie et al., 2023) uses $N_q$ query heads and reduces the number of key-value heads to $N_{kv}$, tying $g = N_q/N_{kv}$ adjacent query heads to a single shared key-value head.”
  2. 引用编号: [67]

    • 文献: Lei Wang, et al. Tilelang: A composable tiled programming model for ai systems. 2025.
    • 引用段落: 第4.1节 "索引与 TopK"
    • 原文描述: 在Top-k核的基准测试中,将本文设计的专用核与TileLang的基数选择(radix-select)Top-k实现进行性能比较。原文描述为“We compare against torch.topk and the TileLang (Wang et al., 2025) radix-select top-k on an H800 GPU...”。