Wenxuan Li, Chengruidong Zhang, Huiqiang Jiang, Yucheng Li, Yuqing Yang, Lili Qiu
A1 主要贡献
本文针对超长上下文(Ultra-Long Contexts)大语言模型(LLMs)训练中计算成本过高的问题,提出了一种名为 MTraining 的高效分布式训练框架。尽管动态稀疏注意力(Dynamic Sparse Attention, DSA)在推理阶段能有效降低成本,但在分布式训练(特别是涉及 Context Parallelism)场景下,由于 Worker 级(Worker-level) 和 Step 级(Step-level) 的负载不平衡以及通信瓶颈,直接应用 DSA 面临巨大挑战。
本文的主要贡献包括:
1. 动态稀疏训练模式(Dynamic Sparse Training Pattern):基于 RoPE 注意力的理论分析和观察,识别出训练过程中存在 "Vertical-Slash"(垂直-斜线)的稀疏模式,并设计了在线近似预算机制来动态适应这种稀疏性。
2. 平衡稀疏环状注意力(Balanced Sparse Ring Attention):提出了一种基于条带(Stripe-based)的布局设计,有效地解决了分布式环境下的 Worker 级和 Step 级负载不平衡问题。
3. 分层稀疏环状注意力(Hierarchical Sparse Ring Attention):针对异构带宽环境(节点内 vs 节点间),设计了分层通信策略,掩盖了跨节点的通信开销。
实验表明,MTraining 在 32 张 A100 GPU 上成功将 Qwen2.5-3B 模型的上下文长度从 32K 扩展至 512K。在 RULER、PG-19、InfiniteBench 和 Needle In A Haystack 等基准测试中,MTraining 在保持甚至超越基线模型精度的同时,实现了高达 6 倍 的训练吞吐量提升。
图 1:Striped 和 Zigzag Ring Attention 在 4 个 CP Workers (GPUs) 上的工作负载分布。
A3 背景知识/关键Observation/设计原则
3.1 长上下文训练是动态稀疏的
注意力的动态稀疏性在预训练 LLMs 中已有充分记录,而在训练过程中,这种现象变化更为剧烈。
- 观察结果:如 Fig 2b 所示,注意力稀疏度在不同的训练步骤和输入样本之间波动显著。不同的模型检查点(Checkpoints)即使对于相同的输入也会产生不同的稀疏模式,反映了训练的时间动态性。反之,单个检查点对不同输入也会产生多样化的稀疏区域。
- 结论:这些观察强调了在训练期间进行动态稀疏适应的必要性。
图 2:(a) 训练阶段的延迟分解。(b) 不同样本和训练步骤中 128K 上下文的 top-k (k=1024) 注意力召回率。(c-d) 训练期间注意力权重 (c) 及其梯度 (d) 的可视化。结果基于使用 4×8 A100 集群训练的 Qwen2.5-3B。
3.2 注意力训练稀疏性呈现特定模式
基于注意力计算公式,作者推导了注意力权重 ($S = QK^\top / \sqrt{d_k}, A = \text{softmax}(S)$) 及其梯度($Q, K, V$)的关系。
- 梯度依赖性:通过将 $\frac{\partial L}{\partial S}$ 代入注意力的梯度表达式,可以观察到反向传播中的所有矩阵运算(GEMMs)都依赖于注意力权重 $A$。因此,反向传播中的动态稀疏性可以视为前向阶段稀疏性的叠加。
3.3 分布式动态稀疏注意力是不平衡的
分布式动态稀疏注意力引入了单节点设置中不存在的新挑战,最显著的是 Worker 级 和 Step 级 的不平衡。
- Worker 级不平衡:如 Fig 7 所示,动态稀疏性导致不同 Workers 之间的 FLOPs 分布不均。处理较快 Workers 必须在同步屏障处空闲等待,导致不平衡。例如,使用 xAttention 在 95% 稀疏度和 32 路 Context Parallelism 下,不平衡度达到 3.17,将实际加速比降低到理论最大值的三分之一。
图 7:使用 XAttention (Xu et al., 2025) 时,不同 CP Workers 之间的计算不平衡(FLOPs)。不平衡度 = 最大值/平均值。
图 3:Step 级不平衡导致气泡的示意图,此时计算和通信无法重叠。
A2 方法细节
MTraining 旨在加速超长上下文 LLM 的分布式训练,由三个核心组件构成:适应训练期高动态稀疏性的动态稀疏训练模式、解决 Worker/Step 级不平衡的平衡稀疏环状注意力,以及利用异构带宽的分层稀疏环状注意力。
图 4:分布式场景下的 MTraining 概览。
4.1 平衡稀疏环状注意力 (Balanced Sparse Ring Attention)
在全注意力(Full Attention)和因果掩码(Causal Mask)下,Ring Attention 的 ZigZag 和 Striped 实现都能达到负载平衡。但在动态稀疏注意力设置中,它们不同的激活模式导致了严重的不平衡。
- 现有问题分析:如 Fig 5a 和 Fig 8 所示,ZigZag 沿着反核心对角线(Anti-diagonal)跨 Workers 分配计算,并随 Steps 沿对角线移动;而 Striped 则相反,沿对角线分配并沿反核心对角线移动。由于数据依赖的动态稀疏性,这导致了显著的负载不平衡。
主要组件设计:
Striped Sparse Ring Attention(条带化稀疏环状注意力):
Block-level Striped Sparse Ring Attention(块级条带稀疏环状注意力):
Step-level Balanced Ring Attention(Step 级平衡环状注意力):
图 5:4 个 CP Workers 下 Striped Ring Attention (a) 和 Hierarchical Striped Ring Attention (b) 的 Step 级计算调度。
4.2 分层平衡稀疏环状注意力 (Hierarchical Balanced Sparse Ring Attention)
Ring Attention 通常通过并发执行矩阵乘法(matmul)和通信 Kernel 来重叠计算与通信。然而,在动态稀疏性下,单 Worker 计算量的减少放大了通信开销,使其成为主要瓶颈。特别是在具有异构通信链路(如 25 GB/s IB HDR vs 300 GB/s NVLink)的分布式训练中,节点间通信往往成为瓶颈。
主要组件设计:
Inner- and Outer-Ring Hierarchical Ring Attention(内外环分层环状注意力):
Hierarchical Balanced Sparse Ring Attention(分层平衡稀疏环状注意力):
附录 B.1 动态稀疏训练模式 (Dynamic Sparse Training Pattern)
受训练期间 Vertical-Slash 模式的观察和理论验证启发(见 §3.2 和 Appendix A),本文将动态稀疏注意力扩展到训练阶段。作者提出了面向训练的动态稀疏模式,包含以下关键组件:
Online Budget Approximation(在线预算近似):
Kernel-Aware Approximation Granularity(Kernel 感知近似粒度):
对齐:由于垂直和斜线模式在 Kernel 中以不同粒度运行,近似分辨率与其匹配:
效果:这种对齐确保了预算估计与实际 Kernel 执行之间的保真度。
Algorithm 1: Dynamic Sparse Training Head
该算法描述了动态稀疏训练头的核心逻辑:
last_q 近似注意力 Ab。kv 和 Top-K 索引 iv(基于 Token 级)。ks 和 Top-K 索引 is(基于 64x64 块级池化)。ivs。A4 实验环境
硬件配置:32 张 Nvidia A100 40GB GPU(4个节点,每个节点8张卡)。
软件配置:
A4 实验结果
1. 长上下文扩展训练 (Long-context Extension Training)
结果与分析:
Training Loss (Fig 6a):
Throughput (Fig 6b):
图 6:Qwen2.5-3B 在 ProLong 数据集上进行 512K 上下文窗口持续预训练期间,不同方法的训练 Loss 和吞吐量比较。
2. 长上下文下游任务 (Long-context Downstream Tasks)
3. 效率分析 (Efficiency Analysis)
图 13:在 32 个 GPU 上使用不同方法处理 512K Tokens 时的注意力计算时间分布:(a) 固定 Ring Attention 步骤内跨 CP Workers 的分布,(b) 固定 Worker 跨 Ring Attention 步骤的分布。
A5 结论
本文提出的 MTraining 框架通过解决 Worker 级和 Step 级的负载不平衡问题,成功实现了动态稀疏注意力在分布式设置下的大规模扩展。MTraining 包含三个关键组件:动态稀疏训练模式、平衡稀疏环状注意力和分层稀疏环状注意力。实验证明,MTraining 能够将 Qwen2.5-3B 高效扩展至 512K 上下文窗口,在 32 张 A100 GPU 上实现了高达 6 倍的吞吐量提升,同时在多个长上下文基准测试中保持或提升了模型精度。
A6 附录
A. 理论证明 (Proof of Theory)
A.1 注意力梯度
作者将 Vertical-Slash 模式的出现归因于 RoPE 的使用。定义 $z_{n,m}$ 为位置 $n, m$ 处 RoPE 变换后的查询和键向量的点积。
定理 A.1:应用 RoPE 后,注意力权重的期望仅依赖于相对位置 $n-m$。即 $E[z_{n,m}] = \sum_{i=0}^{d-1} \phi(i){n-m} A_i + \sum B_i$。}^{d-1} \psi(i)_{n-m
基于定理 A.1,得出两个关键见解:
A.2 定理 3.1 的详细推导
通过定义三角基函数并对 Key 向量建模为随机变量(包含均值部分和波动部分),推导了点积 $z_{n,m}$ 的期望。如 Equation 14 所示,点积期望是 $(n-m)$ 的多个正弦函数的叠加,从理论上支持了 Vertical-Slash 模式的必然性。
Algorithm 2: Balanced Sparse Ring Attention fuse w/ Hierarchical Sparse Ring Attention
该算法提供了融合分层通信的平衡稀疏环状注意力的伪代码:
- 输入包括 World size、Rank、数据 Q/K/V 以及垂直/斜线索引。
- 外环循环:处理跨节点的 P2P 通信(P2Pouter)。发送本地 KV 到下一个外环节点,接收前一个节点的 KV。
- 内环循环:在等待外环通信的同时,处理节点内的 P2P 通信(P2Pinner)和计算。
- 计算核心:调用 block_bar_sparse_attention_forward 执行前向计算,并合并输出。
- 通过这种嵌套循环结构,实现了计算与跨节点通信的重叠。
A7 补充细节
基线实现细节
1. MoBA:将 KV 序列划分为固定大小的块,对每个 Query 使用 MoE 风格的 Gate 选择 Top-K 相关块(始终包含 Query 自身所在的块)。实验中块大小设为 4096,TopK 为 12,512K 上下文下的稀疏率为 0.9。代码被适配以运行 Zigzag Ring Attention。
2. XAttention:通过沿反核心对角线每隔一定步长求和来对块进行评分,仅保留高分块。实验设置块大小为 128,步长 16,阈值 0.9。
ZigZag 调度可视化
Fig 8 补充展示了 ZigZag Ring Attention 的 Step 级计算调度,用于与 Striped Ring Attention 进行对比,直观显示了其在稀疏设置下导致的不平衡模式(沿反核心对角线分布)。
图 8:Zigzag Ring Attention 的 Step 级计算调度。