文章标题: MTraining: 用于高效超长上下文训练的分布式动态稀疏注意力
作者/机构: Wenxuan Li (剑桥大学, 微软研究院), Chengruidong Zhang (微软研究院), Huiqiang Jiang (微软研究院), Yucheng Li (萨里大学), Yuqing Yang (微软研究院), Lili Qiu (微软研究院)
核心问题:随着大语言模型(LLM)的上下文窗口不断扩展,注意力计算的二次方复杂度导致训练成本急剧增加。当上下文超过300K tokens时,注意力计算的前向和后向传播占据了超过90%的训练开销(如图2a所示)。动态稀疏注意力是一种有前景的降低计算成本的方法,但在分布式训练(如上下文并行)中,由于存在工作节点(worker)和步骤(step)级别的不平衡问题,其效率远低于理论值。
研究目标:本文旨在解决分布式训练中动态稀疏注意力的效率问题,实现其线性扩展,从而显著加速超长上下文LLM的训练。
创新点:本文提出了MTraining,一个算法-系统协同设计的框架,通过集成了三个关键组件来解决计算不平衡和通信开销问题:
1. 动态稀疏训练模式:通过经验观察和理论验证,发现带有旋转位置编码(RoPE)的注意力权重在训练中呈现出独特的“垂直-斜线(Vertical-Slash)”局部性模式。基于此,本文提出了一种在线近似稀疏预算机制,以在训练期间动态适应稀疏模式。
2. 平衡的稀疏环形注意力(Balanced Sparse Ring Attention):基于条带化环形注意力(Striped Ring Attention),设计了一种与观察到的稀疏结构对齐的块级平衡稀疏环形注意力机制,以解决工作节点级和步骤级的负载不平衡问题。
3. 分层稀疏环形注意力(Hierarchical Sparse Ring Attention):为了在异构分布式网络中进一步减少通信开销,本文采用了一种分层设计,将全局环形通信分解为节点内(inner ring)和节点间(outer ring)两个层次,有效隐藏了节点间的通信延迟。
通过这些设计,MTraining能够实现动态稀疏注意力的近线性扩展,在32个A100 GPU上将Qwen2.5-3B模型的上下文窗口从32K扩展到512K,训练吞吐量最高提升6倍,同时保持了模型的准确性。
环形注意力(Ring Attention)。长上下文训练越来越受限于注意力延迟。环形注意力【【19,Ring attention with blockwise transformers for near-infinite context,2024,The Twelfth International Conference on Learning Representations】、 【20,Striped attention: Faster ring attention for causal transformers,2023,arXiv】】通过将长序列分布在不同设备上,并将键值(key-value)通信与分块注意力计算【【26,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022,Advances in neural information processing systems】】重叠,提高了可扩展性,使得序列长度可以随设备数量扩展。存在两种主要变体:ZigZag【【27,[Feature request] balancing computation with zigzag blocking,2024,GitHub issue #2; accessed 13 May 2025】】和Striped【【20,Striped attention: Faster ring attention for causal transformers,2023,arXiv】】。如图1所示,ZigZag折叠了查询(query)维度,并在工作节点间镜像数据块,而Striped则按行或块循环划分查询。计算过程中,每个工作节点的Q(查询)和O(输出)保持固定,而K(键)和V(值)通过点对点(P2P)通信进行循环传递——这对于分组查询注意力(Grouped Query Attention)至关重要。在因果全注意力(causal full attention)下,两种变体都能在工作节点间保持均衡的工作负载。
注意力矩阵的动态稀疏性。预训练LLM中注意力矩阵的动态稀疏性——尤其是在长上下文设置下——已有充分的文献记载【【13,QUEST: Query-aware sparsity for efficient long-context LLM inference,2024,Forty-first International Conference on Machine Learning】、 【14,Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention,2024,Advances in Neural Information Processing Systems】、 【15,Flexprefill: A contextaware sparse attention mechanism for efficient long-sequence inference,2025,The Thirteenth International Conference on Learning Representations】、 【28,Xattention: Block sparse attention with antidiagonal scoring,2025,arXiv】】。这种现象在训练过程中持续存在,并且通常具有更大的可变性。如图2b所示,注意力稀疏度在不同训练步骤和输入样本之间波动显著。不同的模型检查点即使对于相同的输入也会产生不同的稀疏模式,反映了训练过程中的时间动态性。反之,单个检查点在不同输入上可能产生多样的稀疏区域。这些观察强调了在训练期间进行动态稀疏性适应的必要性。
反向传播中的稀疏性与前向传播的关联。基于注意力计算公式,我们可以推导出注意力权重($S = QK^T / \sqrt{d_k}$, $A = \text{softmax}(S)$)以及Q、K和V的梯度,如公式1和公式2所示。
$$\frac{\partial \mathcal{L}}{\partial S}=A \odot\left(\frac{\partial \mathcal{L}}{\partial A}-\sum_{j} \frac{\partial \mathcal{L}}{\partial A_{i j}} A_{i j}\right)$$反向传播中的稀疏性继承。通过将$\frac{\partial L}{\partial S}$代入注意力的梯度表达式(公式2),我们观察到反向传播中的所有矩阵运算(即GEMM)都依赖于注意力权重$A$。因此,反向传播中的动态稀疏性可以看作是前向传播稀疏性的叠加。如图2c和图2d所示,梯度$\frac{\partial L}{\partial S}$展现出的稀疏模式与前向传播中的模式非常相似。值得注意的是,反向梯度显示出结构化稀疏性,在整个训练过程中始终遵循一种“垂直-斜线(Vertical-Slash)”的局部性模式。
RoPE导致Vertical-Slash模式。我们进一步将这种模式的出现归因于相对位置嵌入的使用,特别是RoPE【【29,Roformer: Enhanced transformer with rotary position embedding,2024,Neurocomputing】】。设查询向量$q_n \in R^{1 \times d}$和键向量$k_m \in R^{1 \times d}$分别表示长度为$N$的序列中位置$n, m \in \{0, \dots, N-1\}$的token表示。我们定义$z_{n,m}$为经过RoPE变换后的位置$n$的查询向量和位置$m$的键向量的点积。
定理3.1。应用RoPE后的注意力权重的期望值仅依赖于相对位置$n-m$,即$E[z_{n,m}] = \sum_{i=0}^{d-1} \phi^{(i)}_{n-m} A_i + \sum_{i=0}^{d-1} \psi^{(i)}_{n-m} B_i$。
定理推论。基于定理3.1(证明见附录B),我们得出两个关键见解:1) 带有RoPE的注意力矩阵呈现出“垂直-斜线”覆盖模式。“斜线”结构源于期望注意力权重对相对位置$n-m$的依赖,而“垂直”部分则是由查询/键分布中的异常值造成的,如公式10所述;2) 带有RoPE的注意力矩阵倾向于形成带状稀疏激活模式。由于$\phi^{(i)}_{n-m}$和$\psi^{(i)}_{n-m}$在相对位置$n-m$上是连续的,且$E[z_{n,m}]$中的系数$A_i$和$B_i$与位置无关,因此激活倾向于局部聚集在特定的相对位置周围。
工作节点级和步骤级不平衡。分布式动态稀疏注意力引入了单节点设置中不存在的新挑战——最突出的是工作节点级(worker-level)和步骤级(step-level)的不平衡。如图3a所示,动态稀疏性导致不同工作节点的浮点运算量(FLOPs)不均匀,从而引起工作节点级不平衡,其中速度较快的工作节点因同步屏障而空闲。例如,使用xAttention【【28,Xattention: Block sparse attention with antidiagonal scoring,2025,arXiv】】在95%的稀疏度和32路上下文并行下,不平衡度达到3.17,这使得实际加速比降至理论最大值的三分之一。
步骤级不平衡导致通信-计算重叠困难。相比之下,步骤级不平衡指的是单个工作节点在环形注意力(Ring Attention)的不同步骤中计算负载的波动,这是由变化的稀疏模式和样本复杂性驱动的。如图3b所示,这种变化导致工作负载随时间不均匀。当计算量因高稀疏度而减少时,其耗时可能低于通信延迟,使得计算与通信更难重叠,从而导致性能下降的“气泡(bubbles)”。
图3:分布式LLM训练中由动态稀疏注意力引入的工作节点级和步骤级负载不平衡问题的图示。
MTraining框架概述。基于第3节的分析,我们提出了MTraining,以加速超长上下文LLM的分布式训练。MTraining包含三个组件:1) 动态稀疏训练模式,专为训练中观察到的高度动态稀疏性而设计;2) 平衡的稀疏环形注意力,采用基于条带(stripe)的布局来解决工作节点级和步骤级的不平衡问题;3) 分层稀疏环形注意力,利用InfiniteBranch拓扑中的异构节点内/节点间带宽进行高效的稀疏通信。
训练阶段的Vertical-Slash模式。受到训练期间“垂直-斜线”模式的经验观察和理论验证(见§3.2和附录B)的启发,我们将这种最初为推理设计的动态稀疏注意力【【14,Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention,2024,Advances in Neural Information Processing Systems】、 【15,Flexprefill: A contextaware sparse attention mechanism for efficient long-sequence inference,2025,The Thirteenth International Conference on Learning Representations】】扩展到训练阶段。在MInference和FlexPrefill的基础上,我们提出了一种新颖的、以“垂直-斜线”结构为指导的面向训练的动态稀疏模式。如算法1所详述,我们的方法包含两个关键组成部分:
在线预算近似与核函数感知粒度。
* (i) 在线预算近似(Online Budget Approximation)。为了适应训练步骤和上下文之间稀疏模式的动态变化,并消除离线搜索的开销,我们提出了一种在线预算近似方法。具体来说,我们跟踪一个观察窗口内的注意力权重统计数据,并估算召回目标比例的注意力权重所需的最少垂直线和斜线条数。
* (ii) 核函数感知的近似粒度(Kernel-Aware Approximation Granularity)。由于垂直线和斜线模式在核函数中以不同的粒度运行,我们相应地匹配近似分辨率:垂直线在token级别进行估计,而斜线则在64x64的块上进行池化。这种对齐确保了预算估计与实际核函数执行之间的一致性。
算法1 动态稀疏训练头
输入: Q, K, V ∈ R^(S×d_h), pv, ps ∈ [0, 1],
Bs ∈ N
# 使用 last_q 近似注意力
 ← softmax(Q[-last_q:]K^T / √d + m_casual)
# 在线近似垂直预算 kv
# 和 Top-K 索引
kv ← topp(sum_v(Â), pv)
iv ← argtopk(sum_v(Â), kv)
# 在线近似斜线预算 ks
# 和 Top-K 索引
ks ← topp(Pool(sum_v(Â), Bs), ps)
is ← argtopk(sum_s(Â), ks)
# 构建稀疏注意力索引
ivs ← sparseformat(iv, is)
# 动态稀疏 Flash-Attention
y ← sparse(softmax(QK^T / √d)V, ivs)
返回 y
动态稀疏下Ring Attention的不平衡问题。如§2和§3.3所讨论,ZigZag和Striped两种Ring Attention的实现在带有因果掩码的全注意力下都能实现均衡计算。然而,在动态稀疏注意力的设置中,它们不同的激活模式会导致工作节点级和步骤级的不平衡。如图5a和图13所示,ZigZag沿反对角线在工作节点间分布计算,并沿对角线在步骤间移动;而Striped则相反:沿对角线分布,沿反对角线移动。这些不同的时空模式在动态的、数据依赖的稀疏性下会导致严重的负载不平衡。
平衡稀疏环形注意力的设计。为了解决这个问题,我们提出了平衡稀疏环形注意力(Balanced Sparse Ring Attention),这是一种系统-算法协同设计方法,包含以下关键组件:
* (i) 条带化稀疏环形注意力(Striped Sparse Ring Attention)。如§3.2和§4.1所示,训练期间基于RoPE的注意力主要呈现出“垂直-斜线”稀疏模式,其中由于块级GPU操作,斜线部分主导了计算。为了平衡工作节点间的工作负载,我们将它们沿对角线方向对齐,并提出了一种条带化的动态稀疏环形注意力方案。如图5a所示,这种设计将斜线均匀地分布在工作节点间,使每个工作节点在每一步都能处理连续的斜线区域。
* (ii) 块级条带化稀疏环形注意力(Block-level Striped Sparse Ring Attention)。由于斜线操作的块级计算及其空间局部性,我们引入了块级条带化稀疏环形注意力。我们采用64个token的条带粒度,以保持相干性,避免因token级条带化带来的碎片化,并维持核的稀疏性和效率。这种对齐还减少了索引开销并改善了运行时性能。
* (iii) 步骤级平衡环形注意力(Step-level Balanced Ring Attention)。我们的块级条带化设计也缓解了步骤级的不平衡。在超长上下文设置中,工作节点在每一步处理细粒度的条带——例如,有128个工作节点和512K序列,每个工作节点顺序处理64个块条带。这种重复的、细粒度的划分稳定了各步骤的计算,确保了更一致的工作负载分布。
动态稀疏下通信成为瓶颈。环形注意力通常通过并行执行矩阵乘法(matmul)和通信核来重叠计算与通信【【19,Ring attention with blockwise transformers for near-infinite context,2024,The Twelfth International Conference on Learning Representations】】。然而,在动态稀疏性下,每个工作节点的计算量减少,放大了通信开销,使其成为主要瓶颈。因此,在稀疏机制下,降低通信成本对于高效的分布式训练至关重要。
异构网络中的通信瓶颈与分层设计思想。在具有异构通信链路的分布式训练中,节点间通信通常成为环形注意力的瓶颈。例如,节点间带宽(如25 GB/s的InfiniBand HDR)通常比节点内链路(如300 GB/s的NVLink 3.0或PCIe 5.0)慢3-12倍。近期的工作【【9,Deepseek-v3 technical report,2024,arXiv】、 【30,Loongtrain: Efficient training of long-sequence llms with head-context parallelism,2024,arXiv】】探索了分层通信拓扑以在这种带宽不对称下减少延迟。受【【30,Loongtrain: Efficient training of long-sequence llms with head-context parallelism,2024,arXiv】】启发,我们提出了分层平衡稀疏环形注意力(Hierarchical Balanced Sparse Ring Attention),以减轻稀疏环形注意力中的节点间通信开销。
分层平衡稀疏环形注意力的具体设计。具体来说,如图5b所示,我们的方法包含以下设计:
* (i) 内外环分层环形注意力(Inner- and Outer-Ring Hierarchical Ring Attention)。我们将全局环形通信分解为两个层次:一个内环和一个外环。在内环中,键值(KV)块在每个计算节点内的$G_{\text{node}}$个GPU之间循环。外环通过交换聚合的KV缓冲区来处理$N_{\text{node}}$个节点间的通信。在每个外环步骤中,调度过程如下:1) 发布外环P2P。启动一个非阻塞的P2P通信操作,将本地节点的当前KV块传输到下一个节点,并发布一个匹配的接收请求。2) 内环注意力。在节点间传输进行的同时,GPU进入一个长度为$G_{\text{node}}$的循环,在节点内的本地KV切片上执行稀疏环形注意力计算。3) 同步。在每个外环步骤结束时,计算和通信进行同步,然后进入下一个外环迭代。
* (ii) 分层平衡稀疏环形注意力(Hierarchical Balanced Sparse Ring Attention)。与全注意力不同,在稀疏设置中应用分层环形注意力会改变键/值块在工作节点间的传播顺序,可能影响注意力计算模式。然而,如图5b所示,即使采用两级KV传输(内环和外环),计算在各步骤中仍然保持对角线对齐,保留了“垂直-斜线”模式并维持了负载平衡。通过将这种分层设计集成到MTraining的稀疏环形注意力中,节点间的KV传输被内环计算完全重叠,有效减轻了由节点间数据移动引起的通信开销。
硬件配置:
软件配置:
实验内容:在ProLong数据集上对Qwen2.5-3B进行长上下文扩展训练(从32K到512K),比较MTraining与多种基线方法(密集注意力、MoBA等)的训练损失和吞吐量。
实验结果:
实验内容:在RULER、Needle In A Haystack (NIAH)、PG-19和InfiniteBench等多个长上下文基准上,评估使用MTraining训练的Qwen2.5-3B模型的性能。
实验结果:
- RULER:如表1所示,MTraining在不同上下文长度下均优于基线。与密集训练相比,当推理时分别使用密集注意力和MInference时,MTraining的总体性能分别提升了3%和13.4%。
- Needle In A Haystack (NIAH):如图7所示,MTraining在NIAH测试中实现了近乎完美的检索性能,总体检索准确率优于基线模型。
- PG-19 (语言建模):如图8所示,MTraining在不同上下文长度下保持了与密集基线相当的困惑度(perplexity)。
- InfiniteBench:如表2所示,MTraining在InfiniteBench基准上表现优于密集基线,特别是在编码和摘要能力上有所提升,同时在问答任务上保持了竞争力。
表格与图表引用:
表1:在16K到512K上下文长度下,使用长上下文扩展的Qwen2.5-3B在RULER [22]上的各种训练-推理组合的性能(%)。
表2:在InfiniteBench [24]上的性能(%)。
实验内容:分析MTraining在减少工作节点级和步骤级不平衡方面的效果。
实验结果:
- 如图12所示,MTraining显著降低了动态稀疏注意力的工作节点级和步骤级不平衡。最大与平均计算时间之比分别下降了2.4倍和2.3倍。
- 其中,平衡稀疏环形注意力将工作节点级不平衡降低了2.1倍,步骤级不平衡降低了2.2倍。
- 分层稀疏环形注意力进一步将工作节点级不平衡降低了1.2倍,步骤级不平衡降低了1.03倍。
本文提出了MTraining,一种利用动态稀疏注意力来支持超长上下文LLM高效大规模训练的分布式训练方法。MTraining通过其三个核心组件——动态稀疏训练模式、平衡稀疏环形注意力和分层稀疏环形注意力,成功解决了分布式动态稀疏注意力中的工作节点级和步骤级不平衡的关键挑战。实验表明,平衡和分层设计分别将工作节点级不平衡降低了2.1倍和1.2倍,将步骤级不平衡降低了2.2倍和1.03倍。通过在32个A100 GPU上将Qwen2.5-3B的上下文窗口扩展到512K的实验验证,MTraining在处理512K token长度的数据时,实现了高达6倍的训练吞吐量提升,同时保持甚至提高了模型在多个长上下文基准测试(RULER、PG-19、InfiniteBench和大海捞针)上的准确性。
更大规模模型的验证。为了更全面地评估模型大小和训练token数量更大规模下的效果,我们使用MTraining和密集注意力在ProLong数据集【【12,How to train long-context language models (effectively),2024,arXiv】】上对Llama-3.1-8B-Instruct【【11,The llama 3 herd of models,2024,arXiv】】进行了2B tokens、512K token长度的训练。其他设置,包括模型初始化、RoPE、学习率和优化器,均遵循【【12,How to train long-context language models (effectively),2024,arXiv】】中的最终配方(持续预训练的第二阶段)。训练损失和在RULER上的下游评估结果分别呈现在图10和表3中。训练损失曲线表明,当应用于8B规模的模型时,MTraining与密集注意力相比仍然只有微小的训练损失差距,并在整个训练过程中保持相同的趋势。此外,使用MTraining训练的模型在RULER上仍然表现出几乎无损的性能,并对稀疏推理具有更好的鲁棒性。这些结果为我们的稀疏注意力算法在训练更大、不同架构模型时的普适性提供了经验支持。
表3:在16K到512K上下文长度下,使用长上下文扩展的Llama-3.1-Instruct-8B在RULER [22]上的各种训练-推理组合的性能(%)。
RoPE点积计算。设 $\vec{q}_n \in \mathbb{R}^{1 \times d}$ 和 $\vec{k}_m \in \mathbb{R}^{1 \times d}$,其中 $n, m \in [0, N)$,分别为应用RoPE之前的查询和键向量。应用RoPE后,它们的点积 $z_{n,m}$ 计算如下:
$$z_{n,m} = \operatorname{RoPE}(\vec{q}_n, n) \operatorname{RoPE}(\vec{k}_m, m)^T = \vec{q}_n \vec{W}_n \vec{W}_m^T \vec{k}_m^T = \vec{q}_n \vec{W}_{n-m} \vec{k}_m^T,$$点积简化。根据旋转矩阵的定义,点积 $z_{n,m}$ 可以进一步简化如下:
$$\begin{aligned} \begin{aligned} z_{n, m} & =\vec{q}_n \vec{W}_{n-m} \vec{k}_m^T \\ & =\vec{q}_n^{\left[0: \frac{d}{2}\right]} \cos ((n-m) \vec{\theta})\left(\vec{k}_m^{\left[0: \frac{d}{2}\right]}\right)^T+\vec{q}_n^{\left[\frac{d}{2}: d\right]} \cos ((n-m) \vec{\theta})\left(\vec{k}_m^{\left[\frac{d}{2}: d\right]}\right)^T \\ & +\vec{q}_n^{\left[0: \frac{d}{2}\right]} \sin ((n-m) \vec{\theta})\left(\vec{k}_m^{\left[\frac{d}{2}: d\right]}\right)^T-\vec{q}_n^{\left[\frac{d}{2}: d\right]} \sin ((n-m) \vec{\theta})\left(\vec{k}_m^{\left[0: \frac{d}{2}\right]}\right)^T, \end{aligned} \end{aligned}$$其中 $\vec{q}^{[a:b]}_n$ 是 $\vec{q}_n$ 从第 $a$ 个元素(含)到第 $b$ 个元素(不含)的子向量。$\vec{k}^{[a:b]}_m$ 的定义类似。通过定义三角基函数:
$$\phi_{n-m}^{(i)}=\cos ((n-m) \theta_{i \% \frac{d}{2}}), \quad \text { and } \quad \psi_{n-m}^{(i)}=(-1)^{i \geq \frac{d}{2}} \sin ((n-m) \theta_{i \% \frac{d}{2}}),$$公式4可以进一步简化为:
$$z_{n,m} = \sum_{i=0}^{d-1} \phi_{n-m}^{(i)} q_n^{(i)} k_m^{(i)} + \sum_{i=0}^{d-1} \psi_{n-m}^{(i)} q_n^{(i)} k_m^{(i + \frac{d}{2} \% \frac{d}{2})}.$$键向量的随机变量建模。我们将键向量 $\vec{k}_m$ 建模为一个随机变量,如下所示:
其中 $\mu_k^{(i)} = E_{m \in [0,N)}[k_m^{(i)}]$ 是键向量第 $i$ 个通道在所有位置上的均值,而 $\chi_m^{(i)}$ 是均值为零、方差为 $\sigma_i^2$ 的随机变量。
点积分解。通过用随机变量模型替换键向量,公式6中的点积得分 $z_{n,m}$ 可以进一步简化为两部分,均值部分 $\bar{z}_{n,m}$ 和波动部分 $\tilde{z}_{n,m}$:
$$z_{n, m}=\bar{z}_{n, m}+\tilde{z}_{n, m},$$其中均值部分 $\bar{z}_{n,m}$ 是:
$$\bar{z}_{n,m} = \sum_{i=0}^{d-1} \phi_{n-m}^{(i)} q_{n}^{(i)} \mu_{k}^{(i)} + \sum_{i=0}^{d-1} \psi_{n-m}^{(i)} q_{n}^{(i)} \mu_{k}^{(i + \frac{d}{2} \% \frac{d}{2})},$$波动部分 $\tilde{z}_{n,m}$ 是:
$$\tilde{z}_{n, m}=\sum_{i=0}^{d-1} \phi_{n-m}^{(i)} q_{n}^{(i)} \chi_{m}^{(i)}+\sum_{i=0}^{d-1} \psi_{n-m}^{(i)} q_{n}^{(i)} \chi_{m}^{\left(i+\frac{d}{2} \% \frac{d}{2}\right)} .$$注意力得分计算。注意力得分 $a_{n,m}$ 是通过对点积得分 $z_{n,m}$ 逐行应用softmax函数计算得到的:
$$a_{n,m} = \frac{\exp(z_{n,m})}{\sum_{j=0}^{L-1} \exp(z_{n,j})},$$其中 $L$ 是序列的长度。
查询和键的分布。我们假设查询和键来自一个随机分布,其均值为 $E[q_n^{(i)}]$ 和 $E[k_m^{(i)}]$,协方差为 $\sigma_{i,j}$,如下所示:
$$\sigma_{i,j} = E[(q_n^{(i)} - E[q_n^{(i)}])(k_m^{(j)} - E[k_m^{(j)}])].$$乘积的期望。乘积 $q_n^{(i)}k_m^{(j)}$ 的期望如下:
$$E[q_{n}^{(i)} k_{m}^{(j)}] = \mu_{i,j}^2 + \sigma_{i,j}.$$其中 $\mu_{i,j}^2 = E[q_n^{(i)}]E[k_m^{(j)}]$ 是 $q_n^{(i)}$ 和 $k_m^{(j)}$ 均值的乘积。因此,公式6中点积 $z_{n,m}$ 的期望如下:
$$\begin{aligned} \begin{aligned} E[z_{n,m}] &= \sum_{i=0}^{d-1} \phi_{n-m}^{(i)} E[q_{n}^{(i)} k_{m}^{(i)}] + \sum_{i=0}^{d-1} \psi_{n-m}^{(i)} E[q_{n}^{(i)} k_{m}^{((i+\frac{d}{2})\% \frac{d}{2})}] \\ &= \sum_{i=0}^{d-1} \phi_{n-m}^{(i)} (\mu_{i,i}^{2} + \sigma_{i,i}) + \sum_{i=0}^{d-1} \psi_{n-m}^{(i)} (\mu_{i,(i+\frac{d}{2})\% \frac{d}{2}}^{2} + \sigma_{i,(i+\frac{d}{2})\% \frac{d}{2}}). \end{aligned} \end{aligned}$$如公式14所示,点积 $z_{n,m}$ 的期望是多个关于 $(n-m)$ 的正弦函数的叠加。
延迟分解与通信开销分析。为了提供更详细的延迟分解和通信量数据,表4提供了在4节点设置下训练Qwen2.5-3B单次注意力计算的延迟分解。同时,以下对前向过程的分析展示了MTraining的优势:
形式化分析。正式地,令 $T_{\text{comp}} = 0.51$ms 为内环计算时间, $T_{\text{intra}} = 0.13$ms 为节点内(NVLink)通信时间, $T_{\text{inter}} = 0.98$ms 为节点间(InfiniBand)通信时间。
无分层设计的延迟。在没有分层设计的情况下,每个环形注意力步骤的通信时间为:
$$T_{step} \approx max\{T_{comp}, T_{intra}, T_{inter}\} = T_{\text{inter}} = 0.98ms$$考虑到通信发生 $32-1=31$ 次,总时间包括稀疏索引构建(1.13ms)、CPU操作(2.08ms)和最后一次注意力计算(0.51ms),总计达到34.10ms。
分层设计的延迟。通过有效重叠节点间通信和内环注意力,分层平衡稀疏环形注意力使得每一步的延迟由以下公式决定:
$$T_{step}^{hier} \approx max\{T_{comp}, T_{inner}\} = T_{comp} = 0.51ms$$这种情况发生32次。加上稀疏索引构建和CPU操作时间,总时间达到19.53ms,削减了42.7%的前向注意力时间。类似的分析也适用于反向过程。结合图6,这证实了我们的方法在最小化端到端训练时间方面的有效性。
表4:MTraining中单次注意力计算的前向和后向传播延迟分解。
表5:不同训练策略的平均不平衡度(ID)和计算比率。
ZigZag Ring Attention。图13提供了ZigZag环形注意力步骤级计算调度的可视化,补充了图5中对条带化环形注意力和分层平衡稀疏环形注意力的说明。
分层平衡稀疏环形注意力的伪代码。所实现的伪代码见算法2。
算法2 分层平衡稀疏环形注意力
输入:
世界大小和排名: w_outer, w_inner, r
输入数据: Q, K, V
垂直和斜线索引: Iv, Is
# 为当前排名转换稀疏索引
I_block, I_bar = convert_index(Iv, Is, w_outer * w_inner, r)
# 外环
for i ← 1 to w_outer do
if i < w_outer then
# # 开始外环通信
next_outer_rank = (r + w_inner) % (w_outer * w_inner)
P2P_outer.async_send(K, next_outer_rank)
P2P_outer.async_send(V, next_outer_rank)
prev_outer_rank = (r - w_inner) % (w_outer * w_inner)
K'' = P2P_outer.async_recv(prev_outer_rank)
V'' = P2P_outer.async_recv(prev_outer_rank)
end
# # 内环
for j ← 1 to w_inner do
if j < w_inner then
# # 开始内环通信
next_inner_rank = (r + 1) % w_inner
P2P_inner.async_send(K, next_inner_rank)
P2P_inner.async_send(V, next_inner_rank)
prev_inner_rank = (r - 1) % w_inner
K' = P2P.async_recv(prev_inner_rank)
V' = P2P.async_recv(prev_inner_rank)
end
# # # 稀疏注意力计算
Out', LSE' ← block_bar_sparse_attention_forward(Q, K, V, I_block[i * w_inner + j], I_bar[i * w_inner + j])
Out, LSE ← merge_out_and_lse(Out, LSE, Out', LSE')
if j < w_inner then
# 等待内环通信
P2P_inner.wait()
K ← K', V ← V'
end
end for
if i < w_outer then
# 等待外环通信
P2P_outer.wait()
K ← K'', V ← V''
end
end for
额外的实现细节。所有实验都在一个4x8的NVIDIA A100-40GB集群上进行,其中每个节点内的八个GPU通过NVLink通信,节点间通过HDR InfiniBand互连。由于本研究旨在独立评估上下文并行的好处,因此在训练和性能分析中,每个GPU都仅作为CP工作节点,没有启用额外的数据、流水线或张量并行。我们使用nnScaler框架【【31,{nnScaler}:{Constraint-Guided} parallelization plan generation for deep learning training,2024,18th USENIX Symposium on Operating Systems Design and Implementation (OSDI 24)】】,该框架首先将模型追踪为计算图,然后搜索最优的并行执行计划;其搜索空间被限制为只将所有GPU分配给CP。训练使用ZeRO-2【【32,Zero: Memory optimizations toward training trillion parameter models,2020,SC20: International Conference for High Performance Computing, Networking, Storage and Analysis】】、64步梯度累积【【33,Gpipe: Efficient training of giant neural networks using pipeline parallelism,2019,Advances in neural information processing systems】】、模型权重、梯度和激活使用bfloat16精度,优化器状态使用float32精度;优化器为Adam【【68,Adam: A method for stochastic optimization,2014,arXiv】】;梯度检查点和重计算【【34,Training deep nets with sublinear memory cost,2016,arXiv】】被应用于峰值激活内存。效率分析会话复制了相同的并行执行配置。MTraining中的自注意力是使用基于FlashAttention【【26,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022,Advances in neural information processing systems】】、BlockSparse【【35,Block Sparse Attention,2024,https://github.com/mit-han-lab/Block-Sparse-Attention】】和PIT动态稀疏编译器【【36 ,Pit: Optimization of dynamic sparse deep learning models via permutation invariant transformation,2023,Proceedings of the 29th Symposium on Operating Systems Principles】】的自定义CUDA核实现的。对于外部稀疏算法如MoBA和XAttention,我们修改了它们的原始代码,使其能在Zigzag Ring-Attention调度下运行。
基线细节。
1) MoBA 【【18,Moba: Mixture of block attention for long-context llms,2025,arXiv】】。MoBA将键值序列划分为固定大小的块,并为每个查询使用一个MoE风格的门来选择top-k个最相关的块(总是包括查询自身的块),然后在每个选定的块内运行FlashAttention。在我们的实验中,块大小设置为4096,topK值为12,使得在512K上下文下的稀疏比为0.9。我们调整了其官方仓库发布的实现,使其能与Zigzag Ring Attention一起运行。但由于官方发布代码的效率不佳,我们在效率相关的实验中忽略了与它的比较。
2) XAttention 【【28,Xattention: Block sparse attention with antidiagonal scoring,2025,arXiv】】。XAttention通过沿方块的反对角线以一定步长求和来对方块进行评分,并只保留高分块,提供了一种即插即用、无需训练的块稀疏注意力,可加速预填充(prefill)同时保持与密集注意力相当的准确性。在我们的实验中,我们使用以下设置:粒度为128作为块大小,步长16作为采样间距,阈值为0.9用于选择块。