MeshSlice: Efficient 2D Tensor Parallelism for Distributed DNN Training

A1 主要贡献

本文旨在解决大规模DNN模型分布式训练中张量并行(TP)的通信瓶颈问题。现有的1D TP因通信成本高而可扩展性有限,而2D TP虽然能通过将矩阵分片到2D加速器网格中来减少通信,但其核心的通用矩阵乘法(GeMM)算法存在效率问题。具体来说,Cannon算法通信流量大;SUMMA算法同步开销高;而使用集体通信操作的2D GeMM无法将通信与计算重叠。此外,优化2D TP的众多参数(如数据流、网格形状、分片方式)非常困难,通常需要专家手动配置。

为应对这些挑战,本文做出了以下核心贡献:

  1. 提出新颖的MeshSlice算法:这是一种为分布式DNN训练中的2D TP设计的高效2D GeM M算法。MeshSlice通过将AllGather (AG) / ReduceScatter (RdS) 等集体通信操作切分为多个部分的集体操作,从而实现了通信与计算的重叠。这种方法有效隐藏了大部分通信延迟,解决了现有算法无法在行列两个维度上同时实现重叠的问题。
  2. 开发MeshSlice LLM自动调优器(Autotuner):该工具能够自动为大型语言模型(LLM)的训练找到最优的2D TP配置。它首先选择一个高效的2D GeMM数据流,然后利用分析性成本模型协同优化网格形状和通信粒度,从而替代了繁琐的人工调优过程。
  3. 全面的评估与实现:通过模拟训练GPT-3和Megatron-NLG模型的TPUv4集群,本文验证了MeshSlice的性能。结果显示,MeshSlice在高达256路的2D TP中仍保持高效率。在一个256个TPU的集群中,MeshSlice训练GPT-3和Megatron-NLG模型的速度分别比现有最先进的算法快12.0%和23.4%。此外,本文还在真实的Google TPUv4集群上实现了MeshSlice,验证了其切片操作的开销很小,且自动调优器的成本模型能准确估算通信和计算成本。

A3 背景知识/关键Observation/设计原则

2.1 分布式训练方法

2.2 2D张量并行

2.3 2D GeMM 算法

2.3.1 通用方面

2.3.2 Cannon算法

2.3.3 SUMMA 算法

2.3.4 Collective 2D GeMM

A2 方法细节

本文为2D TP做出了两项贡献。首先,提出了一种新的2D GeMM算法,解决了现有2D GeMM算法的局限性。其次,设计了一个LLM自动调优器,为LLM训练找到一个高效的2D TP配置。该LLM自动调优器优化了数据流、网格形状和通信粒度的配置。

我们提出的2D GeMM算法称为MeshSlice。图4可视化了先前算法的时间线,并与MeshSlice进行了比较。该图显示了计算、行间通信和列间通信的时间进展。Cannon需要进行倾斜操作且只支持方形网格形状,因此其流量高于其他算法,增加了总执行时间。SUMMA使用低效的bcast/reduce通信操作,由于细粒度的数据包而产生流水线气泡和同步开销。Collective算法不将集体通信与计算重叠。Wang的算法只划分了一个方向上的集体通信,因此另一个方向上的通信没有被重叠。最后,MeshSlice能够在两个方向上都将通信与计算重叠,从而实现最快的执行速度。

图4:五种2D GeMM算法的时间线对比:Cannon、SUMMA、Collective、Wang和MeshSlice。
图4:五种2D GeMM算法的时间线对比:Cannon、SUMMA、Collective、Wang和MeshSlice。

3.1 MeshSlice 2D GeMM 算法

3.1.1 MeshSlice算法的数学描述

公式1
公式1

算法1
算法1

3.1.2 MeshSlice算法的详细实现

公式2
公式2

公式3
公式3

公式4
公式4

3.2 MeshSlice LLM 自动调优器

3.2.1 阶段1:数据流和分片

3.2.2 阶段2:网格形状和切片数

A4 实验环境

A5 实验结果

5.1 分布式GeMM算法性能

5.2 LLM自动调优器和成本模型

5.3 在真实硬件上的MeshSlice性能

A6 结论

本文提出了MeshSlice算法,一种为分布式DNN训练设计的高效2D张量并行方法。MeshSlice通过将通信操作切分为多个部分,并利用软件流水线在行列两个维度上都实现了通信与计算的高效重叠,从而解决了现有2D GeMM算法(如Cannon、SUMMA、Collective GeMM)存在的流量大、同步开销高或无法重叠等问题。此外,本文还设计了MeshSlice LLM自动调优器,该工具能够通过选择高效的数据流,并利用精确的成本模型协同优化加速器网格形状和通信粒度,从而自动化了复杂的性能调优过程。

在模拟的256个TPUv4集群上的评估表明,MeshSlice在训练GPT-3和Megatron-NLG模型时,端到端性能分别比当前最先进的算法快12.0%和23.4%。

未来的工作方向包括:
1. 扩展到GPU集群:通过在GPU集群的物理网络上构建逻辑网格,将MeshSlice应用于更广泛的硬件平台,并相应调整自动调优器以考虑网络竞争。
2. 应用于推理场景:调整MeshSlice及其自动调优器以适应推理任务中更可能出现的内存瓶颈。
3. 支持其他DNN层:将MeshSlice应用于可转换为GeMM操作的其他层,如卷积层,或用于优化GNN中的2D分布式稀疏GeMM。
4. 结合专家混合(MoE)模型:将MeshSlice的2D TP与MoE的专家并行(EP)相结合,以支持更大规模模型的训练。