Tailing Yuan, Yuliang Liu, Xucheng Ye, Shenglong Zhang, Jianchao Tan, Bin Chen, Chengru Song, and Di Zhang, Kuaishou Technology
本文针对大语言模型(LLM)训练中计算资源和内存消耗巨大的问题,提出了无损加速训练的新方法。现有的工作主要集中在优化激活策略(如卸载和重计算)以及探索各种并行训练选项上,但在平衡计算与内存利用率,以及高效搜索最佳并行配置方面仍有改进空间。
本文的主要贡献如下:
1. Pipeline-Parallel-Aware Offloading(流水线并行感知卸载):提出了一种高效的激活重物化策略,该策略遵循流水线并行模式,最大化利用主机(Host)内存来存储激活值,且开销可忽略不计。
2. Compute-Memory Balanced Checkpointing(计算-内存平衡检查点):提出了一种在激活内存和计算效率之间寻求实际平衡的策略,通过推导内存成本和计算成本的帕累托前沿(Pareto frontier),实现最优的检查点选择。
3. 高效的混合并行参数搜索方法:提出了一种极高效的搜索方法,通过测量集群相关的原语信息(cluster-related primitive information)和模型相关的原语信息(model-related primitive information)构建性能模型,从而在考虑卸载和检查点策略的同时,详尽搜索最优的混合并行参数(t, c, p, d)。
4. 实验验证:通过在具有不同模型大小和上下文窗口大小的公共基准上进行广泛实验,证明了该方法的有效性。例如,对于上下文窗口大小为 32,768 的 175B Llama 类模型,在 256 个 NVIDIA H800 GPU 上,该方法将模型 FLOPs 利用率(MFU)从 32.3% 显著提高到 42.7%。
3. 预备知识
模型定义:大型语言模型通常由嵌入层、L 个 Transformer 层和头部层(head layer)组成。Transformer 层的隐藏状态维度记为 $h$。前馈网络中激活函数的输出维度称为“中间大小”,记为 $H$。表 1 列出了定义模型的更多符号。
混合并行组:在混合并行训练期间,集群中的 GPU 被划分为张量并行组(大小 $t$)、上下文并行组(大小 $c$)、流水线并行组(大小 $p$)和数据并行组(大小 $d$)的笛卡尔积,因此 GPU 总数 $\#GPUs = tcpd$。
内存占用:设备上的内存消耗可归类为以下四部分(忽略小两个数量级以上的项):
模型权重和梯度:
每个 Transformer 层的权重大小为:
在混合并行中,L 个 Transformer 层被交错流水线并行分割。每个设备持有 $v$ 个阶段,每个阶段包含 $l$ 个 Transformer 层,因此 $L = pvl$。权重和梯度沿张量并行组分布。每个设备上的大小为:
其中 $\mathbb{1}$ 是指示函数,$r_{pp}$ 是流水线并行 rank。系数 $6 = 2 + 4$ 包含 BF16 权重和 FP32 梯度。
优化器状态和主权重:
优化器在数据并行组中进一步分布,每个设备上的大小为:
其中系数 $12 = 4 + 4 + 4$ 是 FP32 主权重和两个 FP32 Adam 状态的总和。
最大存活激活大小:
激活分布在张量并行组和上下文并行组中。本文中,只要 $t \ge 2$,总是启用序列并行以确保激活完全沿张量并行组分布;采用 FlashAttention [6, FlashAttention: Fast and memory-efficient exact attention with IO-awareness, 2022] 减少注意力的内存开销。“激活块(activation block)”指一个流水线阶段(包含 $l$ 个 Transformer 层)为一个微批次(micro-batch)存储的所有激活。激活块的大小为:
在交错流水线并行中,第一个反向步骤之前的正向步骤数为 $(vp + p - 2r_{pp} - 1)$。因此,每个设备上的最大存活激活大小为:
其他缓冲区和开销:
这部分大小依赖于实现,包括 CUDA 上下文开销、内存管理碎片、NCCL 缓冲区、cuBLAS 工作区和临时变量。
为了说明问题,表 3 展示了不同并行配置的估计内存需求和计算时间。迭代时间基于 [19, Efficient large-scale language model training on GPU clusters using Megatron-LM, 2021] 的公式 (1) 近似:
4. 动机
4.1 激活瓶颈
训练大语言模型(尤其是长上下文窗口模型)时,GPU 内存面临巨大挑战。例如,训练具有 32k 上下文窗口的 Llama-175B,无论采用何种混合并行参数(假设 $tc \le 8$),第一个 rank 上至少需要 $M_a = 171.5$ GB 的激活内存。解决此问题的两种直接方法有显著副作用:全检查点(Full checkpointing)导致 1/3 的额外计算成本;增加张量或上下文并行大小会导致大量通信开销和计算强度降低。
4.2 混合并行调优的挑战
问题 1(基本问题):给定模型、序列长度 $s$、全局批量大小 $B$ 和节点数,搜索一组混合并行参数 $(t, c, p, l, ckpt)$ 以最小化每次迭代的时间 $T$。
问题 2(吞吐量最大化问题):给定模型、序列长度 $s$、满足条件的全局批量大小范围 $B \in [B_{min}, B_{max}]$ 和最大节点数,搜索一组混合并行参数 $(t, c, p, l, ckpt)$ 以最大化吞吐量 $Bs/T$。
即使利用先验知识(如固定微批次 $b=1$,避免跨节点张量并行,避免非 GQA 模型的跨节点上下文并行)来缩小搜索空间,空间仍然非常巨大(如表 4 所示)。针对问题 2 的详尽搜索将反复调用问题 1 的过程。
5.1 流水线并行感知卸载 (Pipeline-Parallelism-Aware Offloading)
调度原则:如图 3 所示,利用交错流水线并行(interleaved pipeline parallelism)[19] 的流程,根据两个原则设计卸载方案:
1. 尽早卸载:在每个流水线阶段正向传播(forward)结束后立即开始卸载。
2. 尽早重载:在上一个流水线阶段反向传播(backward)开始时开始重载(Reloading)。
实现细节:
* 调度粒度:卸载和重载的调度粒度为“流水线阶段”。在大型语言模型中,所有流水线阶段具有相同的计算时间和激活大小(以及相同的传输时间)。因此,完成卸载和计算的时间由两者中较慢的一个决定,较快的一个可以被完全重叠(overlap)。
* 事件序列化:为了复用正在传输中的 GPU 内存分配,使用 cudaStreamWaitEvent 使每次卸载等待上一次卸载完成。这确保了最多只有一个流水线阶段的激活正在进行卸载。
* Ping-pong 重载:在重载中使用两个缓冲区:一个用作重载的目标,另一个供当前反向步骤使用。在下一个反向步骤中,两个缓冲区的角色互换。激活张量是从上一个重载的缓冲区中原地(in-place)构造的。
* 带宽利用率增强:为了实现 GPU 和主机之间的最高带宽,为每个进程绑定非一致内存访问(NUMA)节点,并为 CPU 缓冲区使用锁页内存(page-locked memory)。
内存大小分析:
使用卸载比率 $\alpha (0 \le \alpha \le 1)$ 控制卸载到主机内存的激活量。
在第一个流水线并行 rank 上,卸载的激活块最大数量为 $(vp + p - 3)$,一个激活块正在卸载,一个激活块由当前正向步骤生成,两个缓冲区用于重载。
峰值 GPU 内存使用量为:
峰值主机内存使用量为:
选择卸载比率 $\alpha$ 时应尽可能小,原因有二:主机与设备间的内存拷贝可能会因资源竞争而减慢计算速度;卸载可能无法完全被计算重叠。首先利用最大可用 GPU 内存大小通过方程 7 求解 $\alpha$,然后计算 $M_{host}$ 并检查是否超出主机内存。
重叠分析:
随着卸载比率 $\alpha$ 的增加,卸载可能无法完全被计算重叠。流水线调度包含 3 个阶段:仅包含正向的热身阶段(warm-up)、包含成对正向和反向的稳定阶段(steady)、以及仅包含反向的冷却阶段(cooldown)。
非重叠的卸载/重载首先发生在热身阶段。原因是正向传播比反向传播快约 2 倍,因此稳定阶段有 3 倍的时间进行卸载和重载,而冷却阶段有 2 倍的时间进行重载。
热身阶段的关键路径中有 $vp - 1$ 个步骤。第一个正向步骤没有重叠,第 2 到第 $p$ 个步骤重叠 $T_{embF} + lT_F$(其中 $T_{embF}$ 是嵌入层的正向时间)。其他 $vp - p - 1$ 个步骤重叠 $lT_F$。
热身阶段的开销为:
其中 $BW_{DtoH}$ 是所有 GPU 并行进行内存拷贝时的设备到主机内存拷贝带宽。稳定阶段和冷却阶段的开销计算类似(见附录 B.3)。虽然利用主机内存会有一定开销,但通过下一节的方法可以减少激活大小。
5.2 计算-内存平衡检查点 (Compute-Memory Balanced Checkpointing)
概述:
现有的检查点策略(如 [1, 2, 11, 20])虽然能优化内存,但往往针对每个算子的临时张量进行精细评估,导致对并行配置变化敏感,搜索开销大。在混合并行场景中,这会导致难以接受的求解开销。
本文的方法关注存储的激活大小,因为动态的临时内存最多由 $l$ 层生成,而存储的激活由最多 $(vp + p - 2)l$ 层生成。本文通过枚举存储的激活集合,推导出内存成本和计算成本的帕累托前沿,从中选择一个计算-内存平衡的解决方案。
帕累托前沿 (Pareto Frontier):
在一个 Transformer 层中,每个子层存储的激活大小如表 5 所示。
首先,对于每个激活张量,找到重建它的最小计算成本。与以前的工作不同,流水线并行中可以忽略临时内存,因此逐层重建激活。
图 4(a) 展示了重建 Attention 输入的例子,需要重计算两层。图 4(b) 展示了更复杂的情况,重建 SiLU 输入需要重计算 Linear,且 Linear 输出也被 Mul 使用,因此重计算 Linear 层重建了两个激活。
重建每个激活张量的计算成本如表 6 所示。例如,重建 activation#2(即表 5 中第二层存储的激活)是高效的,因为它通过仅引入 RMSNorm#1 的重计算(成本仅 0.061 ms)节省了 $2bsh/(tc)$ 的内存。相比之下,重建 activation#7 虽然节省相同内存,但引入了 1.018 ms 的重计算,效率较低。
通过枚举存储激活的集合,可以得到内存成本和计算成本的帕累托前沿,如图 5 所示。本方法的粒度是每个 Transformer 层,不考虑为了重建某层输入而重计算上一层最后两层的情况,也不考虑将一层拆分的情况。
计算-内存平衡解决方案:
该方法使用帕累托前沿的拐点(inflection point)。存储的激活如表 6 所示。在我们的方法中,所有计算密集型层(Linear 和 Attention)都不进行重计算。被重计算的层集合包括 RMSNorm#1, RMSNorm#7, SiLU#9, 和 Mul#10。总重计算时间仅为正向和反向时间的 1.5%。
使用计算-内存平衡检查点后,每个 Transformer 层存储的激活大小变为:
与无检查点相比,它为 Llama-175B 和 Llama-65B 节省了 39% 的激活内存,为 Llama2-70B 节省了 44%。
有趣的是,全检查点(Full checkpointing)不在帕累托前沿上,因为它总是重计算所有层。如图 6 所示,SuperNeurons [31] 存储 Linear 的输出,因此比本文方法重计算了更多的层。
5.3 混合并行参数调优
5.3.1 调优算法
为了减少实验总数,测量一些原语信息并构建性能模型。优化问题的公式为:
其中 $M_{gpu}^{thresh}$ 和 $M_{host}^{thresh}$ 分别是设备 GPU 和主机内存的阈值。
5.3.2 性能建模
为了准确建模,必须识别表 7 所示的关键性能基准。
1. 流水线热身阶段 (Warm-up phase):
从第一个流水线阶段(在第一个 rank 上)的第一个微批次正向开始,到最后一个流水线阶段(在最后一个 rank 上)的第一个微批次正向开始。
其中 $T_{p2p}$ 是流水线并行点对点通信的时间,需针对每个通信消息大小 $2bsh/(tc)$ 进行测量。
流水线稳定阶段 (Steady phase):
紧接热身阶段,结束于最后一个流水线阶段上最后一个微批次反向的尾部。
流水线冷却阶段 (Cooldown phase):
紧接稳定阶段,结束于第一个流水线阶段上最后一个微批次反向的尾部。
优化器通信和计算:
包括梯度的 reduce-scatter、优化器步骤和模型权重的 all-gather。传输时间由带宽决定,忽略延迟。
其中 $BW_{opt}$ 是优化器通信组的算法带宽,$\omega_{adam}$ 是参数数量与 Adam 优化器执行时间的比率(在 NVIDIA H800 80GB 上为 53.4 GHz)。
卸载开销和检查点开销:
总卸载开销记为 $T_{offload}$(根据 5.1.3 节计算)。在一个 Transformer 层中重计算选定子层的时间记为 $T_{ckpt}$。如果使用计算-内存平衡检查点,$T_{ckpt}$ 加到 $T_B$ 中。若使用全检查点,则将 $T_{ckpt}$ 替换为 $T_F$。
由重叠引起的计算减速:
关键执行路径上与计算重叠的流水线并行点对点通信数量为 $(4mv - 2m + 2p - 2)$。计算减速引起的额外时间被认为与重叠通信时间成正比,比率为 $\beta_{p2p}$。与计算重叠的卸载数量为 $(mv + p - 2)$,重载亦然。由卸载引起的计算减速时间与卸载大小成正比,比率为 $\beta_{offload}$。
在硬件上测得 $\beta_{p2p} = 0.05$,$\beta_{offload} = 0.0016$ sec/GB。
因此,总时间的估计公式为:
5.3.3 参数搜索
在搜索之前,进行以下原语测量:
1. 对于每个 $(t, c)$ 组合,记录模型层计算时间(可在层数较少的模型上进行以节省时间)。
2. 对选定子层的每个输入形状评估内存平衡检查点的重计算时间 $T_{ckpt}$。
3. 测量每个 $2bsh/(tc)$ 配置的点对点通信时间 $T_{p2p}$。
4. 评估每个 $(t, cd)$ 组合的优化器通信速度。
5. 内存拷贝带宽和参数 $\omega_{adam}, \beta_{opt}, \beta_{offload}$ 仅测量一次。
获取原语信息后,详尽枚举所有 $(t, c, p, l, ckpt)$ 组合,验证其有效性及内存约束,并使用方程 17 估计 $T_{model}$。问题 1 可以在 0.001 秒内解决。
硬件配置:
软件配置:
训练设置:BF16 精度,FP32 梯度累积。Adam 优化器(FP32 状态)。始终启用交错流水线并行。微批次大小固定为 1。
卸载和检查点效果:
图 7 展示了 GPU 内存使用与卸载比率的关系。例如,运行 Llama-65B (context window=8192) 时,若无卸载会导致 OOM。使用 $\alpha \ge 0.3$ 的卸载技术可以训练该模型。使用检查点技术也能训练。随着卸载比率增加,GPU 内存大小减少。
性能模型准确性:
图 8 评估了性能模型的准确性。结果表明:
* 模型对不同的 $t$ 和 $c$ 是准确的 (Fig 8 a,b,c)。
* 对不同的 $p$ 和 $l$ 是鲁棒的 (Fig 8 a,e,f)。
* 对三种检查点方法(无、平衡、全)均保持正确性 (Fig 8 a,g,h)。
* 能适应不同的全局批量大小 (Fig 8 a,d)。
* 在所有情况下,测量时间与 $T_{model}$ 的差异不超过 2.0%。
端到端性能调优:
表 8 比较了基线系统和本文系统在三个模型和不同序列长度下的最佳性能。
* 本文方法通过卸载和平衡检查点的帮助,有更多空间权衡并行配置,从而获得显著性能提升。
* 例如,对于 Llama-175B (32k context),在 256 个 H800 GPU 上,本文方法将 MFU 从 32.3% 提升至 42.7%。
* 观察:对于 GQA(分组查询注意力)模型,上下文并行优于张量并行,因为其通信量更小。但跨节点的上下文并行开销显著,而卸载的开销远小于跨节点上下文并行。
最佳扩展 (Optimal Scaling):
图 9 展示了最佳扩展方法与数据并行扩展方法的对比。
* 最佳扩展算法会为每个节点数量调整参数 $(B, t, c, p, l, ckpt)$。
* 结果显示,最佳扩展方法在许多节点数量下优于数据并行扩展。例如,在 24 个节点上训练 Llama-65B (s=4096),最佳扩展达到 $1.80 \times 10^5$ TPS,而数据并行扩展仅能利用 20 个节点达到 $1.48 \times 10^5$ TPS。
本文提出了两种激活重物化方法:流水线并行感知卸载(最大化主机内存利用率)和计算-内存平衡检查点(寻求激活内存与计算效率的平衡)。为了优化巨大的并行参数搜索空间,提出了一种基于少样本原语信息测量的性能模型,能够高效地详尽搜索最佳参数组合。
局限性与未来工作:计算-内存平衡检查点方法支持的最大序列长度小于全检查点方法。未来将探索更多优化策略,并考虑 GPU 温度对时间测量精度的影响。
B 实现细节
B.1 软件版本与差异
本文代码库 fork 自 2023 年 6 月 9 日的 Megatron-LM (commit db71a33),并在其之上实现了卸载和检查点。表 9 展示了本文代码与最新 Megatron-LM 的主要区别。
B.2 激活的详细计数
表 10 展示了第一个流水线并行 rank 上 $p=4, v=2$ 的调度示例。激活块由 (微批次ID, 流水线阶段ID) 表示。统计数据 "# living act." 是唯一存储的激活块数量。在卸载模式下,分为 "@GPU" 和 "@host"。
B.3 未重叠的卸载
稳定阶段的开销为:
冷却阶段的开销为:
C 额外评估
C.1 额外的卸载和检查点结果
图 10 在两个额外的基准模型上进一步验证了所提出的方法,观察到与图 7 类似的内存开销模式。
C.2 训练损失 (Training Loss)
为了验证技术是否影响收敛,从头训练了一个 Llama2-70B 模型(context window=4096, global batch=1024)。参数设置为 $t=2, c=2, p=8, l=2, ckpt=ours, \alpha=0.5$。如图 11 所示,本文方法的训练损失与最新 Megatron-LM 一致,证明加速技术未损害模型性能。