ByteScale: Efficient Scaling of LLM Training with a 2048K Context Length on More Than 12,000 GPUs

作者/机构: Hao Ge (Peking University), Junda Feng (ByteDance Seed), Qi Huang (ByteDance Seed), Fangcheng Fu (Peking University), Xiaonan Nie (ByteDance Seed), Lei Zuo (ByteDance Seed), Haibin Lin (ByteDance Seed), Bin Cui (Peking University), Xin Liu (ByteDance Seed)

A1 主要贡献

本文针对大规模语言模型(LLM)长短序列混合训练场景,提出了一个名为 ByteScale 的高效、灵活且可扩展的训练框架。现有框架通常采用静态的通信网格(如二维网格)来组织设备,将数据并行(DP)和上下文并行(CP)视为正交的技术,但这导致在处理长度可变的序列时,出现冗余通信和计算不平衡的问题,从而降低了训练效率。

核心问题:
1. 冗余通信: 静态并行策略强制所有序列(无论长短)都使用为最长序列配置的上下文并行(CP)组进行分区和通信,即使短序列并不需要跨设备分区,也产生了不必要的通信开销。
2. 计算不平衡: 尽管通过打包(packing)技术可以在设备间均匀分配令牌数量,但由于自注意力机制的计算复杂度为 $O(L^2)$($L$为序列长度),不同打包序列的实际计算负载(FLOPs)差异巨大,导致设备间出现计算不平衡,产生空闲等待时间(气泡)。

研究目标与创新点:
为了解决上述挑战,ByteScale 提出了以下核心贡献:

A3 背景知识与关键洞察

2.1 Transformer 与大语言模型

Transformer 架构【40, Attention is All you Need, 2017, NeurIPS 2017】是当前大语言模型(LLM)的主流基础架构。它由一系列 Transformer 层堆叠而成,每个层包含一个注意力模块和一个前馈网络(FFN)模块。如图 1 所示,自注意力机制需要在整个序列的所有令牌间进行计算以捕获上下文信息,而其他操作(如归一化、线性投影和激活函数)则是逐令牌计算的,每个令牌可以独立处理。

图 1. Transformer 层的架构
图 1. Transformer 层的架构

2.2 分布式 LLM 训练

2.3 填充与打包

为了在静态并行策略中支持可变长度序列,需要使用填充(padding)和打包(packing)技术。如图 2 所示,填充将同一批次中的序列补齐到相同长度,但这会造成计算浪费。打包【22, Efficient sequence packing without cross-contamination: Accelerating large language models without impacting performance, 2021】则将多个序列拼接成一个长序列,不使用填充令牌,并采用一种特殊的分段注意力掩码(segmented attention mask)来确保每个序列被独立处理。

图 2. 序列填充与打包
图 2. 序列填充与打包

2.4 长上下文训练

自注意力的时间和内存复杂度均为 $O(L^2)$,这成为扩展上下文长度的瓶颈。Flash Attention【7, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning, 2023】【8, FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, 2022, NeurIPS 2022】通过优化内存 I/O 和使用分块技术,将内存复杂度从 $O(L^2)$ 降至 $O(L)$,但时间复杂度仍为 $O(L^2)$。上下文并行(Context Parallelism, CP)【4, Striped Attention: Faster Ring Attention for Causal Transformers, 2023】【23, LightSeq: Sequence Level Parallelism for Distributed Training of Long Context Transformers, 2023】【25, Ring Attention with Blockwise Transformers for Near-Infinite Context, 2023】【31, NVIDIA: Context Parallelism, 2024】进一步将序列划分到 $C$ 个设备上,将内存从 $O(L)$ 降至 $O(L/C)$。CP 沿序列维度对 QKV 进行分片,跨令牌操作需要设备间通过环形点对点通信交换 KV 切片,并与计算重叠。该技术也适用于打包序列,如图 2(c) 和 3(a) 所示,每个子序列也必须被划分到所有 CP rank 上。

图 3. 带打包的上下文并行
图 3. 带打包的上下文并行

3.1 数据异构性

3.2 冗余通信

3.3 计算不平衡

A2 方法细节

4. ByteScale 概览

为了应对上述挑战,我们提出了 ByteScale。如图 7 所示,它包含三个主要组件:
1. Profiler (分析器): 用于分析环境、模型配置、数据分布,并为其他组件构建成本模型。
2. Communication Optimizer (通信优化器): 通过数据感知分片、动态通信和选择性卸载,来提高长短序列的通信效率。
3. Balance Scheduler (平衡调度器): 通过并行感知的数据分配,来解决计算不平衡问题。

图 7. ByteScale 概览
图 7. ByteScale 概览

5. 通信优化器

本节描述 ByteScale 如何优化通信开销。首先,通过动态序列分片和通信减少短序列的冗余通信。其次,通过选择性卸载进一步压缩长序列的通信成本。

5.1 数据感知分片与通信

5.2 数据感知的选择性卸载

# 代码清单 1. act_ctx 的用法
from activation_offload import act_ctx

# 启用激活值卸载
with act_ctx(offload_ratio=ratio):
    # Transformer 前向传播
    output = model(input)
# 反向传播
loss.backward()
图 10. 逐层激活值卸载
图 10. 逐层激活值卸载

5.3 整体流程

ByteScale 的整体流程如算法 1 所示。简而言之,算法遍历全局批次中的每个序列 $S_i$。对于长序列,它推导出卸载比例 $\alpha$ 和所需的 rank 数量 $C(S_i)$(1-6行)。对于短序列,它将它们打包以填满每个 rank 的容量 $T$(7-9行)。处理后的序列被分配给 $D_{hdp}$ 个 rank,算法返回每个 rank 的微批次和 offload_ratio 用于执行(10-12行)。

算法 1: 朴素 HDP 解决方案
算法 1: 朴素 HDP 解决方案

6. 平衡调度器

本节介绍平衡调度器如何解决 DP 和 PP 的不平衡问题。通过精心设计数据分配(替代算法 1 的第 10 行),它在保持最小通信开销的同时缓解了这些不平衡。

6.1 重新定义微批次

梯度累积要求不同 DP rank 执行相同数量的微批次,这是基于所有微批次计算负载相同的假设。然而,实际执行时间差异很大。在 ByteScale 中,我们重新定义了一个更灵活的策略,允许不同的 HDP rank 处理不同数量的微批次(大小相同但工作负载不同),以缓解不平衡问题。如图 13 所示,这使得所有 rank 能在同一时间完成计算。更重要的是,该策略不影响模型收敛,因为我们最终计算的是全局批次中所有令牌的梯度总和,保证了数学等价性。

图 13. 平衡策略
图 13. 平衡策略

6.2 解决 PP 不平衡

6.3 解决 DP 不平衡

6.4 平衡策略

算法 2 描述了平衡策略。
1. 首先,按长度降序对全局批次 B 中的序列进行排序。然后将这些序列划分为 FLOPs 总和近似相等的桶(buckets),因此平均长度较长的桶包含的序列较少(3-5行)。
2. 其次,确定哪些 rank 的执行时间较短,以便后续分配(7-9行)。
3. 第三,如果使用 DP-Balance 策略,则从同一个桶中选择序列;如果使用 PP-Balance 策略,则从所有桶中顺序选择序列。实际上,执行时间较短的 rank 会被分配更多序列(12-15行)。
4. 最后,重复第二和第三步,直到所有桶都为空。

算法 2: HDP 的平衡策略
算法 2: HDP 的平衡策略

7. 实现细节

ByteScale 基于 Python, C++ 和 CUDA 实现,代码约 16K 行,并已集成到高性能 LLM 训练框架 MegaScale【18, MegaScale: scaling large language model training to more than 10,000 GPUs, 2024, NSDI’24】中。

A4 实验环境

A4 实验结果

8.2 端到端评估

图 17. 端到端评估(单位:令牌/秒)
图 17. 端到端评估(单位:令牌/秒)

8.3 案例研究

图 18. 案例研究
图 18. 案例研究

8.4 消融研究

图 19. 网络流量和张量核心利用率
图 19. 网络流量和张量核心利用率

图 20. 消融研究
图 20. 消融研究

图 21. 激活值卸载的有效性
图 21. 激活值卸载的有效性

A5 结论

本文提出了 ByteScale,一个为大规模长短序列混合训练设计的高效、灵活且可扩展的分布式 LLM 训练框架。通过开发的通信优化器消除了冗余通信,并通过平衡调度器缓解了计算不平衡。在超过 12,000 个 GPU 的生产集群上,对从 7B 到 141B 的模型和从 256K 到 2M 的上下文长度进行了评估,实验结果显示 ByteScale 相较于 MegaScale 实现了高达 7.89 倍的性能提升。