TensorRT-LLM Large-scale Expert Parallelism Optimizations

Enwei Zhu (朱恩伟), NVIDIA 加速计算专家团队 高级工程师
Jinyang Yuan (袁劲飏), NVIDIA 加速计算专家团队 高级工程师

议程 (Agenda)

动机:为何需要大规模专家并行 (Large-scale EP)?

大规模专家并行(EP)可以在以下方面提供帮助:

  1. 提升计算强度:在给定的单GPU工作负载下,扩展EP规模可以减少分组GEMM(通用矩阵乘法)的内存访问,从而提高计算强度。
  2. 改善内存密集型操作性能:对于分组GEMM中受内存带宽限制的范围(即高每用户吞吐率/tps/user),大规模EP可以提升其性能。
  3. 提升并发能力:对于受内存容量限制的GPU,大规模EP可以通过实现更高的并发度,在帕累托曲线的左侧区域(高并发)提供帮助。

下图展示了在GB200上的预期帕累托曲线,比较了EP规模为4和8时的性能差异。

Page 3 - GB200上的EP规模性能帕累托曲线
注:此图仅供技术讨论和参考,实际性能可能因产品组合不同而异。

大规模专家并行的挑战

通信 (Communication)

在不同并行策略组合下,通信开销是关键挑战。

Page 5 - 不同EP规模下的通信模式
Page 5 - 不同EP规模下的通信模式

AllGather + ReduceScatter 模式

这种模式下,每个令牌(token)的隐藏状态(hidden states)会被发送到所有的计算排名(rank),但实际上只有拥有被选中专家(topK experts)的rank才需要这些数据,造成了通信冗余。

Page 6 - AllGather + ReduceScatter 通信流程图
Page 6 - AllGather + ReduceScatter 通信流程图

AlltoAllv 模式

AlltoAllv是一种更高效的通信原语。在这种模式下,每个令牌的隐藏状态只会被发送到那些拥有其所需专家的rank,从而减少了不必要的数据传输。

Page 7 - AlltoAllv 通信流程图
Page 7 - AlltoAllv 通信流程图

专家负载不均衡 (Expert Load Imbalance)

专家负载不均衡是大规模EP面临的另一个核心挑战。

下图展示了在翻译任务(模型:DeepSeek-R1,EP规模:32)中,不同rank和层的令牌分配数量热力图,以及特定rank上专家负载的柱状图,清晰地揭示了负载不均衡现象。

Page 8 - 专家负载不均衡的可视化分析
Page 8 - 专家负载不均衡的可视化分析

专家并行负载均衡器 (Expert Parallelism Load Balancer)

为了解决专家负载不均衡问题,我们设计了专家并行负载均衡器(EPLB)。

离线负载均衡器 (Offline Load Balancer)

设计思路

实验结果

下图对比了禁用和启用EPLB后的专家负载情况。启用后(右图),负载分布明显更加均匀。

Page 11 - 离线负载均衡器效果对比
Page 11 - 离线负载均衡器效果对比

在线负载均衡器 (Online Load Balancer)

观察

离线均衡器依赖于数据分布的稳定性。然而,实际场景中数据分布可能动态变化,这促使我们研究在线负载均衡。

Page 12 - 不同数据集中专家负载模式的迭代相似性
Page 12 - 不同数据集中专家负载模式的迭代相似性
Page 13 - 单个请求内及不同请求间的专家负载模式
Page 13 - 单个请求内及不同请求间的专家负载模式

设计思路

基于以上观察,我们设计了在线负载均衡器,以适应动态变化的负载模式。

后台更新机制

Page 15 - 在线负载均衡器的后台更新机制示意图
Page 15 - 在线负载均衡器的后台更新机制示意图

在线负载均衡器:设计

在线负载均衡器通过在运行时动态调整专家(Expert)的放置和权重,以解决负载不均衡问题。其工作流程涉及CPU和GPU的协同,具体如下:

Page 16
Page 16

在线负载均衡器:结果

在线负载均衡器能有效解决负载不均衡问题。实验结果表明:
- 在权重更新后,负载在不同的槽(slots)/等级(ranks)之间变得均衡,不再出现热点槽/等级。
- 下图展示了在翻译数据集上的效果,其中红色横线表示权重更新发生的层。更新前,存在明显的负载不均衡(黄色条带);更新后,负载分布变得非常均匀(紫色/蓝色区域)。
- 实验设置:翻译数据集,ep_size=32, local_bs=256, num_slots=288 (32个冗余专家)。

Page 17
Page 17

专家并行负载均衡器:结论

下表对比了离线和在线两种负载均衡器的优缺点及适用场景:

Page 18
Page 18

通信优化及其他

在解决了负载均衡问题后,下一个优化重点是通信开销。

通信优化:基线分析 (AllGather + ReduceScatter)

在优化前,基线采用 AllGather + ReduceScatter 的通信模式。性能分析显示:
- 计算部分:MoE的分组GEMM(2 moe_gemm)和Attention(attn)核函数的执行时间表现良好。随着专家并行(EP)规模的增大,GEMM的耗时甚至会减少,而Attention的耗时保持不变。
- 通信部分AllGatherReduceScatter操作的耗时随着EP规模的增大而显著增加。然而,在每个GPU批处理大小固定的情况下,所需的通信消息大小应为常数,这表明当前的通信方式存在优化空间。

Page 20
Page 20
Page 21
Page 21

性能剖析图显示,AllGather(152us + 149us)和 ReduceScatter(390us)占用了大量的执行时间,成为性能瓶颈。

Page 22
Page 22

通信优化:AlltoAll 初步实现

为了优化通信,引入了 AlltoAll 操作,其特点包括:
- 通过NVLink进行点对点(P2P)访问。
- 支持设备内存中的发送/接收计数,实现 AlltoAll 风格的通信。
- 支持CUDA Graph。

新的流程变为 Prepare -> A2A -> MoE -> A2A -> Local Reduce。性能剖析显示,All2All 的通信时间(fp4: 18us+12us, bf16: 174us)远低于之前的 AllGatherReduceScatter。但该实现可能会因 recv_m ~selectedExpert * m 而引入新的负载不均衡问题。

Page 23
Page 23

通信优化:AlltoAll 优化

在初步实现的基础上,进行了三项进一步的优化:
1. 使用 AlltoAll 重构 prepare AllGather:将用于准备专家ID和令牌尺度的AllGather操作也替换为更高效的AlltoAll
2. 低精度合并:在发送端进行量化(Quant),在接收端进行反量化(DeQuant),以减少数据传输量。
3. 融合 AlltoAll 核函数:将多个小的AlltoAll核函数融合成一个,减少核函数启动开销。

Page 24
Page 24

优化1的效果显著,如下图左侧所示,“current”(当前)版本的耗时远低于“previous”(之前)版本。优化2和3的实现细节如右侧图所示,通过共享内存(SM)和P2P工作空间进行高效的数据交换。

Page 25
Page 25

其他优化

MoE 辅助核函数优化

Page 26
Page 26
Page 27
Page 27
Page 28
Page 28

MTP LM Head 张量并行 (Tensor Parallelism)

Page 29
Page 29
Page 30
Page 30

端到端性能

Page 33: ISL/OSL = 1k/1k, MTP 禁用时的吞吐量分析图。该图表展示了在不同专家并行(EP)等级(Rank 4, 8, 16, 32)下,每个生成GPU的输出吞吐量(Y轴)随每用户吞吐量(X轴)变化的帕累托曲线。
Page 33: ISL/OSL = 1k/1k, MTP 禁用时的吞吐量分析图。该图表展示了在不同专家并行(EP)等级(Rank 4, 8, 16, 32)下,每个生成GPU的输出吞吐量(Y轴)随每用户吞吐量(X轴)变化的帕累托曲线。
Page 34: ISL/OSL = 8k/1k, MTP 禁用时的吞吐量分析图。该图表展示了在不同专家并行(EP)等级(Rank 4, 8, 16, 32)下,每个生成GPU的输出吞吐量(Y轴)随每用户吞吐量(X轴)变化的帕累托曲线。
Page 34: ISL/OSL = 8k/1k, MTP 禁用时的吞吐量分析图。该图表展示了在不同专家并行(EP)等级(Rank 4, 8, 16, 32)下,每个生成GPU的输出吞吐量(Y轴)随每用户吞吐量(X轴)变化的帕累托曲线。
Page 35: ISL/OSL = 8k/1k, MTP 启用时的吞吐量分析图。该图表展示了在不同生成等级(Gen Rank 4, 8, 16, 32)下,每个生成GPU的输出吞吐量(Y轴)随每用户吞吐量(X轴)变化的帕累托曲线。
Page 35: ISL/OSL = 8k/1k, MTP 启用时的吞吐量分析图。该图表展示了在不同生成等级(Gen Rank 4, 8, 16, 32)下,每个生成GPU的输出吞吐量(Y轴)随每用户吞吐量(X轴)变化的帕累托曲线。