发表时间: 2026-06 · arXiv:2606.08476
文章标题:FlashCP: 面向LLM训练的负载均衡、通信高效的上下文并行
作者/机构:Zheng Wang 1, Eric Liu 2, Linan Jiang 1, Zhongkai Yu 1, Zaifeng Pan 1, Yue Guan 1, Yuke Wang 3, Yufei Ding 1
本文旨在解决大规模、长上下文语言模型(LLM)训练中上下文并行(Context Parallelism, CP)面临的挑战。
核心问题:现有的CP方法通过对序列进行分区来减少内存开销,但存在以下三个核心问题:
1. 工作负载不均衡:由于每个token的注意力计算量随其在序列中的位置变化而不同,静态的序列分片会导致GPU间的工作负载不均衡。
2. 通信效率低下:CP需要跨工作节点(worker)通信键值(KV)张量,现有方法会传输整个KV张量,导致大量冗余通信。
3. 计算核效率低:像FlashAttention这样的高效注意力计算核依赖于足够长的查询(Query)和键值(KV)长度才能实现高GPU利用率,而输入序列被分区后,这种效率难以维持。
研究目标:设计一个负载均衡且通信高效的CP训练框架,能够同时解决上述三个问题,以实现近乎最优的CP训练性能。
创新点 (FlashCP):
FlashCP通过整体优化分片策略和通信流,实现了工作负载均衡、最大化注意力核效率和最小化通信开销。
1. 分片感知的通信机制 (Sharding-aware communication mechanism):该机制能够识别每个CP工作节点实际需要的KV张量部分,从而只通信必要的数据,消除了冗余的数据传输。
2. 全新的Whole-Doc分片策略:该策略将较短的文档作为一个整体保留在单个CP工作节点上,从而最大程度地减少了通信需求并保持了高计算核效率。对于剩余的文档,它会进行自适应分片,以实现工作负载的均衡。
3. 启发式分片算法:为了有效结合Whole-Doc和Per-Doc(逐文档)分片策略,FlashCP设计了一种启发式算法。该算法能够高效地搜索近乎最优的分片方案,克服了分片问题的NP-hard特性。
主要贡献总结:
* 揭示了现有CP框架的局限性,即它们无法保持高计算核效率,并遭受冗余KV张量通信的困扰。
* 提出了FlashCP框架,该框架通过整体优化输入分片和通信流,实现了均衡的工作负载、减少的通信和高计算核效率。
* 实验证明,在不同数据集上,FlashCP相比最先进的CP框架实现了高达1.63倍的加速。
输入数据长度分布不均与打包技术。在LLM训练中,输入文档的长度分布高度倾斜【2, Why does the effective context length of llms fall short?, 2024, arXiv preprint arXiv:2410.18745】【12, DynaPipe: Optimizing Multi-task Training through Dynamic Pipelines, 2024, Proceedings of the Nineteenth European Conference on Computer Systems】【31, Wlb-llm: Workload-balanced 4d parallelism for large language model training, 2025b, 19th USENIX Symposium on Operating Systems Design and Implementation (OSDI 25)】。早期的填充(padding)方法会引入冗余的计算、通信和内存开销【23, Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism, 2019, arXiv preprint arXiv:1909.08053】。为了解决这个问题,研究者提出了输入打包(input packing)技术,将多个短文档拼接成一个长序列【35, Analysing the impact ´ of sequence composition on language model pre-training, 2024, arXiv preprint arXiv:2402.13991】【22, Exploring the limits of transfer learning with a unified text-to-text transformer, 2020, Journal of machine learning research】【14, Efficient Sequence Packing without Cross-contamination: Accelerating Large Language Models without Impacting Performance, 2021, arXiv preprint arXiv:2107.02027】【29, Packing analysis: Packing is more appropriate for large models or datasets in supervised fine-tuning, 2024, arXiv preprint arXiv:2410.08081】。在此基础上,文档掩码(Document Mask),也称为文档内因果掩码(intra-document causal masking),被提出来屏蔽跨文档的注意力计算,以确保正确的注意力行为【21, FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention, 2024, https://pytorch. org/blog/flexattention/】【19, NVIDIA NeMo Framework: Sequence Packing, , https://docs.nvidia.com/ nemo-framework/user-guide/latest/ sft_peft/packed_sequence.html】【35, Analysing the impact ´ of sequence composition on language model pre-training, 2024, arXiv preprint arXiv:2402.13991】【15, Enhancing training efficiency using packing with flash attention, 2024, arXiv preprint arXiv:2407.09105】。输入打包和文档掩码的结合已成为大规模、长上下文LLM训练的广泛采用范式,并成功应用于Llama3等工业级模型中【31, Wlb-llm: Workload-balanced 4d parallelism for large language model training, 2025b, 19th USENIX Symposium on Operating Systems Design and Implementation (OSDI 25)】【7, Bytescale: Efficient scaling of llm training with a 2048k context length on more than 12,000 gpus, 2025, arXiv preprint arXiv:2502.21231】【5, The Llama 3 Herd of Models, 2024, arXiv preprint arXiv:2407.21783】。
上下文并行的基本原理。上下文并行(CP)通过沿序列长度维度对输入序列进行分区,并将其分布到多个工作节点上,从而缓解长上下文窗口带来的巨大激活内存问题【20, Megatron Core: Context Parallelism, 2023, https: //http://docs.nvidia.com/megatron-core/ developer-guide/latest/api-guide/ context_parallel.html】【9, Loongtrain: Efficient training of long-sequence llms with headcontext parallelism, 2024, arXiv preprint arXiv:2406.18485】【5, The Llama 3 Herd of Models, 2024, arXiv preprint arXiv:2407.21783】【31, Wlb-llm: Workload-balanced 4d parallelism for large language model training, 2025b, 19th USENIX Symposium on Operating Systems Design and Implementation (OSDI 25)】。每个工作节点处理一部分token并本地计算注意力,同时通过交换KV张量来获取完整的注意力上下文。图1(b)展示了遵循Llama3 CP实现【5, The Llama 3 Herd of Models, 2024, arXiv preprint arXiv:2407.21783】的一个例子,其中输入序列被分割成2倍CP大小的分片,每个工作节点处理两个分片。这种静态分片策略可能导致工作节点之间严重的工作负载不平衡,凸显了优化文档分片和token分配的必要性。
现有CP框架的不足。Llama3 CP【5, The Llama 3 Herd of Models, 2024, arXiv preprint arXiv:2407.21783】采用粗粒度的输入分片策略,即均匀地分割整个输入序列,这通常导致CP工作节点之间的严重工作负载不平衡。为了解决这个问题,两种更先进的框架被提出:Per-Doc CP【31, Wlb-llm: Workload-balanced 4d parallelism for large language model training, 2025b, 19th USENIX Symposium on Operating Systems Design and Implementation (OSDI 25)】和Ring-Attn (Zigzag)【36, ring-flash-attention: Ring attention implementation with flashattention, 2025, https://github.com/zhuzilin/ring-flash-attention】。这两种框架都采用细粒度的、逐文档的分片策略,每个输入文档被分成2倍CP大小的块,第i个块和第(2N-1-i)个块被分配给第i个CP工作节点。如图2所示,这种细粒度的文档分片策略可以在CP工作节点之间实现均衡的工作负载分配。尽管这两种方法都有效地实现了工作负载平衡,但它们存在一些共同的局限性 。
降低的计算核效率。细粒度分片导致每个注意力计算核在更短的文档分片上操作,从而降低了计算核效率【31, Wlb-llm: Workload-balanced 4d parallelism for large language model training, 2025b, 19th USENIX Symposium on Operating Systems Design and Implementation (OSDI 25)】。为了证明这一效应,本文通过两种输入模式对计算核执行延迟进行了性能分析:一种是长度为128K的单个文档,另一种是16个8K长度的文档。如图3所示,Per-Doc分片导致了明显更高的注意力延迟,尤其是在由短文档主导的工作负载中。Ring-Attn (Zigzag)的延迟甚至更高,因为它逐块计算注意力,引入了额外的部分结果聚合开销。
冗余通信。尽管这两种框架采用不同的通信方法,但它们都存在冗余通信的问题。Per-Doc CP依赖于集体通信(如AllGather和ReduceScatter),而Ring-Attn (Zigzag)使用点对点(P2P)通信。在这两种情况下,整个KV张量都在所有CP工作节点之间传输。然而,每个工作节点仅需要KV张量的一个子集来计算注意力,这导致了不必要的数据传输,如图2所示。
目标与约束定义。我们的目标是确定一种输入分片和分配策略,该策略能在CP工作节点之间平衡计算负载,同时最小化KV通信开销和计算核效率的下降。我们考虑一个CP大小为N,上下文窗口为C的上下文并行场景。给定一个由n个文档D = [$d_1, d_2, \dots, d_n$]组成的输入序列,其中$d_i$表示第i个文档的长度,这些文档被进一步划分为m个文档分片S = [$s_1, s_2, \dots, s_m$],其中$s_i$是分片长度。每个分片还关联一个前缀长度$p_i$,表示其起始位置之前的token数量。
输入分片分配。对于一个输入文档分片$s_i$,其分配由一个二元变量数组表示:$x_i = [x_{i1}, x_{i2}, \dots, x_{iN}]$,其中如果分片$s_i$被分配给第j个CP工作节点,则$x_{ij} = 1$。每个分片仅被分配给一个工作节点,这由以下约束强制执行:
等量Token约束。为了平衡CP工作节点之间的计算,每个工作节点必须处理相同数量的token,因为非注意力层(如FFN)的成本与token数量成线性关系。此要求由以下约束强制执行:
平衡计算工作负载。由于训练以同步方式进行,整体性能受限于最慢的CP工作节点。因此,目标是在满足公式1和公式2定义的约束条件下,最小化所有工作节点中的最大注意力计算工作负载:
此处,$W_i$是第i个输入分片的注意力计算工作负载,计算方式为$W_i = (2 \cdot p_i + s_i + 1) \times s_i / 2$。
现有通信开销分析。3.1节中的问题形式化缺少了通信开销这一环节。现有的CP实现采用静态通信模式,在所有CP工作节点之间传输完整的KV张量。关键路径上的总通信量可以表示为:
这里,H表示注意力头的数量,D表示头的维度,N表示CP工作节点的数量。4倍的因子是因为在前向和后向传播中,K和V张量都需要通信。由于CP通常应用于跨节点场景,其通信带宽远低于节点内的NVLink,因此通信开销可能成为一个显著的瓶颈。
动态分片感知通信策略。这种方法效率低下,因为它引入了冗余通信。为了减轻CP训练中的通信开销,我们提出了一种动态分片感知通信策略,该策略根据分片方案自适应地确定通信缓冲区的大小,从而避免冗余的KV张量传输。具体来说,如果一个输入文档被分配给单个CP工作节点,则在通信中跳过该文档,因为该CP工作节点已持有该文档的完整KV张量,可以本地执行注意力计算。对于被分片并分发到多个CP工作节点的输入文档,通信大小与所有分片中的最大前缀长度成正比。这是因为每个token的Q张量只与输入文档中KV张量的前缀部分计算注意力。此外,由于每个文档分片可能有不同的前缀长度,这会导致每个输入文档的通信缓冲区中出现零填充。为避免这种情况,我们不单独处理每个文档,而是使用一个单一的连续通信缓冲区,并将所有文档分片前缀部分对应的KV张量紧凑地存入该缓冲区。通过我们的动态分片感知策略,通信大小减少为:
此处,$\hat{S}$表示除每个输入文档的最后一个分片外的所有输入文档分片集合。一个未被进一步分割的完整文档也被视为最后一个分片,因此不包含在$\hat{S}$中。通过应用动态分片感知通信策略,两类输入分片的通信得以减少:(1)完全分配给单个CP工作节点的文档,以及(2)每个输入文档的最后一个分片,从而实现了显著的通信节省。
Whole-Doc分片的核心思想与挑战。为了最大化通信节省,理想情况下我们希望将每个输入文档完整地保留在单个CP工作节点上,从而消除KV通信并通过单核注意力计算来保持计算核效率。如图4(1)所示,在两个等长文档的理想情况下,最优策略是保持每个文档完整并将其分配给不同的CP工作节点,这样既实现了计算平衡,又消除了通信。然而在实践中,文档长度各不相同,这使得在均匀分配工作负载的同时保持文档完整性变得困难。在图4(2)中,当输入包含不同长度的文档时,必须分割较长的文档以满足公式2中的等量token约束。虽然这种分片方案产生的通信开销相对较低(与$\Delta l$成正比),但它引入了显著的工作负载不平衡。为了应对这一挑战,我们提出了一种通信感知的Whole-Doc分片方法,该方法自适应地对文档进行分片以平衡工作负载,同时最小化通信开销。具体来说,我们不只是简单地划分最长的文档来实现CP工作节点间的等量token分配,而是提倡自适应地划分多个文档。如图4(2)右侧所示,Doc-2和Doc-3都被部分划分,以实现均衡的工作负载和每个工作节点等量的token计数。此外,由于只需要交换文档的下半部分,通信量受限于$\max(\Delta l_1, \Delta l_2)$,显著减少了总体通信开销。这种方法被称为Whole-Doc分片,因为它旨在将每个文档作为一个整体来保留,仅在必要时应用自适应分片以平衡工作负载分布。
结合Per-Doc和Whole-Doc分片。尽管Whole-Doc分片减少了通信开销并有助于实现CP工作节点间的工作负载平衡,但它无法有效处理所有输入文档的组合。在某些情况下,例如当输入序列包含一个占绝大多数token的极长文档时,由于该长文档与其余短文档之间的显著差异,Whole-Doc分片无法实现均衡的工作负载分布。为了解决这个问题,我们设计了一种结合了Per-Doc和Whole-Doc分片的混合策略。具体来说,我们根据每个文档的长度为其选择合适的分片方法。对于导致工作负载不平衡的极长文档,应用Per-Doc分片以均匀分配计算量,同时由于其token数量充足,仍能保持较高的注意力计算核效率。相比之下,对于较短的文档,当文档长度差异适中时,则使用Whole-Doc分片来减少通信开销并实现均衡的工作负载。
基于ILP的分片。基于第3.1节中的问题形式化,我们可以通过将通信开销项(公式5)纳入目标函数,将分片任务建模为一个整数线性规划(ILP)问题。这种形式化允许我们使用ILP求解器为给定的分片粒度获得最优的输入分片和分配方案。然而,解决ILP的计算成本对于实际使用来说高得令人望而却步,这促使我们设计一种更高效的启发式分片算法来搜索近乎最优的解。
FlashCP启发式分片算法。为了高效地搜索近乎最优的分片方案,我们提出了一种贪心启发式分片算法。算法的细节在算法1中给出。该算法以文档序列D = [$d_1, d_2, \dots, d_n$]和目标不平衡率R = $\frac{\text{max workload}}{\text{avg workload}}$作为输入。这里,max workload和avg workload分别是所有CP工作节点中的最大和平均注意力计算工作负载。算法首先按长度降序对文档进行排序,然后通过将每个文档分配给工作负载最小的工作节点来迭代构建一个临时方案tmpp(第5-9行)。在此过程中,算法将整个文档分配给特定的CP工作节点而不进行分片,这有助于最大化通信节省并保持计算核效率。之后,算法检查分配给每个工作节点的token数量。如果token数量不相等,它会应用Whole-Doc分片来平衡token分布和注意力计算工作负载(第10-16行)。之后,tmpp成为一个满足等量token约束(公式2)的有效分片方案,算法计算当前分片方案tmpp的不平衡率Cur R(第18行)。如果tmpp的不平衡率大于R,则表明可能存在一些极长的序列阻碍了负载均衡。为了解决这个问题,算法从输入序列中移除最长的文档,并对其应用Per-Doc分片(第19-23行)。剩余的文档用于后续的迭代。这个过程重复进行,直到实现的不平衡率Cur R小于目标比率R。然后,算法返回tmpp作为最终的Whole-Doc分片方案,以及分配给Per-Doc分片的一组文档。
算法 1 FlashCP的启发式分片算法
输入: 输入 D = [d1, d2, · · · , dn], 目标不平衡率 R
输出: Per-Doc分片方案 Per Doc P 和 Whole-Doc分片方案 Whole Doc P
1: 按长度降序对D进行排序。
2: 初始化空的Per-Doc分片方案 Per Doc P
3: 初始化当前不平衡率 Cur R = Inf
4: while Cur R > R do
5: 初始化空的临时Whole-Doc分片方案 tmp P
6: for doc d in D do
7: // 将d添加到负载最轻的工作节点:
8: tmp P.Min Worker Add(d)
9: end for
10: // 确保满足公式2:
11: while tmp P 的token不相等 do
12: // 从token最多的工作节点中弹出文档:
13: docs = tmp P.Max Token Worker Pop()
14: // 应用Whole-Doc分片:
15: tmp P.Whole Doc Shard and Add(docs)
16: end while
17: // 更新当前不平衡率:
18: Cur R = T.Compute Imba Ratio(tmp p)
19: if Cur R > R then
20: // 弹出最长的文档并应用Per-Doc分片:
21: d = D.Pop Front()
22: P er Doc P.Add(d)
23: end if
24: end while
25: Whole Doc P = tmp P
26: 返回 Per Doc P 和 Whole Doc P
基线方法:
数据集:
评估方法:从每个数据集中随机抽样10万个输入序列进行评估。每个输入序列由多个文档组成,如果总长度超过上下文窗口大小,则截断最后一个文档。
我们在三个基准数据集上,使用两种模型配置(16和32个注意力头,头维度128)和128K的上下文窗口,评估了FlashCP及所有基线的CP训练和推理性能。CP大小设置为4和8。
Table 2. ILP求解器与启发式算法的比较。
随着LLM上下文窗口的持续增长【26, Gemini: a family of highly capable multimodal models, 2023, arXiv preprint arXiv:2312.11805】【27, Gemini 1.5: Unlocking multimodal understanding across millions of tokens of context, 2024, arXiv preprint arXiv:2403.05530】【24, Meta llama 4: The future of multimodal ai, 2025, Available at SSRN 5208228】【37, Skyladder: Better and faster pretraining via context window scheduling, 2025, arXiv preprint arXiv:2503.15450】,训练越来越受到注意力层激活内存快速增长的限制【13, Reducing Activation Recomputation in Large Transformer Models, 2023, Proceedings of Machine Learning and Systems】【30, Lemo: Enabling less token involvement for more context fine-tuning, 2025a, arXiv preprint arXiv:2501.09767】。CP最近成为一种有效的策略,它沿着序列维度划分输入序列和激活,将注意力计算分布到多个GPU上【20, Megatron Core: Context Parallelism, 2023, https: //http://docs.nvidia.com/megatron-core/ developer-guide/latest/api-guide/ context_parallel.html】【9, Loongtrain: Efficient training of long-sequence llms with headcontext parallelism, 2024, arXiv preprint arXiv:2406.18485】【5, The Llama 3 Herd of Models, 2024, arXiv preprint arXiv:2407.21783】【31, Wlb-llm: Workload-balanced 4d parallelism for large language model training, 2025b, 19th USENIX Symposium on Operating Systems Design and Implementation (OSDI 25)】。早期的CP方法采用基于环的通信来交换工作节点间的KV张量【16, Ring Attention with Blockwise Transformers for Near-Infinite Context, 2023, arXiv preprint arXiv:2310.01889】,这使得计算和通信能够部分重叠,但难以支持输入打包所需的复杂注意力掩码。为了解决这个限制,引入了zigzag风格的分片【36, ring-flash-attention: Ring attention implementation with flashattention, 2025, https://github.com/zhuzilin/ring-flash-attention】。然而,由于逐块执行注意力,它仍然存在计算核效率低下的问题。更新的方法利用了集体通信原语,如AllGather 【20, Megatron Core: Context Parallelism, 2023, https: //http://docs.nvidia.com/megatron-core/ developer-guide/latest/api-guide/ context_parallel.html】【5, The Llama 3 Herd of Models, 2024, arXiv preprint arXiv:2407.21783】和AlltoAll【11, DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models, 2023, arXiv preprint arXiv:2309.14509】,通过为每个工作节点提供全局KV视图来支持带有文档掩码的高效注意力计算。然而,这些方法通信整个KV张量,引入了显著的冗余通信。
本文介绍了FlashCP,一个为大规模、长上下文LLM训练设计的负载均衡且通信高效的上下文并行框架。具体而言,FlashCP提出了一个分片感知的通信机制,以最小化不必要的通信;以及一个新颖的Whole-Doc分片策略,旨在最大化通信节省的同时保持CP工作节点间的负载均衡。此外,FlashCP还引入了一个启发式分片算法,以高效地搜索近乎最优的分片方案。大量的实验表明,在各种数据集上,FlashCP相较于当前最先进的CP框架可提供高达1.63倍的加速。