TokenWeave: Efficient Compute-Communication Overlap for Distributed LLM Inference

作者/机构: Raja Gond, Nipun Kwatra, and Ramachandran Ramjee (Microsoft Research India)

A1 主要贡献

核心问题: 大型语言模型(LLM)的分布式推理即使在使用如NVLink等高速互连的GPU上,也可能引入高达20%的通信开销。现有缓解这些开销的技术通常将计算分解为更细粒度的任务,并将通信与子任务重叠,但这种细粒度分解会在GPU上因波次量化效应(wave quantization effects)而产生额外开销,同时通信本身也占用大量流式多处理器(SM),进一步增加了开销。因此,目前主流的开源服务系统(如vLLM、SGLang、TensorRT-LLM)均不支持对张量并行(tensor-parallel)模式下的大模型进行计算-通信重叠。

研究目标: 本文旨在提出一种高效的方法TokenWeave,以最小化LLM推理中的通信开销,特别是在中等规模token批次下也能实现显著性能提升,并解决现有细粒度分解方法带来的计算效率低下和资源竞争问题。

创新点:
TokenWeave通过三项关键技术来应对上述挑战:

  1. 令牌分割(Token-Splitting): 这是一种粗粒度的令牌分割技术,它将推理批次中的令牌以感知GPU波次(wave-aware)的方式,划分为两个大致相等的子集。这种方法允许一个子集的通信与另一个子集的计算重叠。通过“智能分割”(Smart-splitting)策略,确保两个子任务的内核执行的总波次数不超过完整计算内核的波次数,从而几乎完全消除了因计算规模变小而产生的开销。

  2. 优化层归一化(Layer Normalization): 现有工作(如Tilelink, Flux, Nanoflow)未对归一化操作进行优化,但本文发现RMSNorm操作在H100上可能占总延迟的8%。TokenWeave通过将RMSNorm操作的计算顺序调整到AllReduce操作的ReduceScatter和AllGather之间,并将其与通信操作融合,来优化归一化计算。

  3. 新颖的融合AllReduce–RMSNorm内核: TokenWeave利用现代硬件特性(Hopper和Blackwell GPU上的Multimem指令)实现了一个创新的融合内核。该内核将通信(AllReduce)和RMSNorm操作结合起来,仅需2-8个SM即可完成,远低于先前工作所需的16-20个SM。这不仅减少了SM占用,还使得内存带宽受限的RMSNorm操作可以与另一个批次的计算重叠,带来额外增益。

核心成果:
* TokenWeave在多种模型和工作负载下,实现了高达1.29倍的延迟降低和1.26倍的吞吐量提升。
* 在多个场景中,TokenWeave的性能甚至优于一个移除了所有通信的等效模型(vllm-nocomm),这得益于其融合的AllReduce–RMSNorm内核。
* 与现有技术(如TileLink)在小序列长度下会产生性能下降不同,TokenWeave即使在1024令牌的序列长度下也能实现1.18倍的延迟改进,使其能与现代调度器(如Sarathi)中的分块预填充(chunked prefill)技术结合使用。


图1. 三种模型在8xH100 DGX上AllReduce的通信开销与序列长度的关系(误差棒显示5次运行的标准差)。即使有NVLink/NVSHARP,通信开销也可能超过20%。


图2. Llama-3.3-70B在8xH100 DGX上不同序列长度的推理延迟。vllm-Multimem对应于使用Multimem和NVSHARP支持的优化AllReduce实现的vLLM。vllm-nocomm是一个反事实基线,仅对应于没有任何通信的计算时间。虚线显示了相对于vllm-Multimem基线的归一化性能。TokenWeave实现了高达1.29倍的加速。即使在较短的序列长度下,TokenWeave也提供了显著的增益,例如,在1K tokens的序列长度下为1.18倍,而先前的方案会产生开销。在序列长度≥4K时,TokenWeave的性能优于vllm-nocomm,不仅完全弥补了通信开销,还因我们融合的AllReduce–RMSNorm内核而提供了额外增益。


图3. 不同模型在8xH100 DGX上RMSNorm开销与序列长度的关系(误差棒显示5次运行的标准差)。在AllReduce之后执行的RMSNorm具有不可忽略的开销,范围可能在5-9%之间。


图4. 在8xH100 DGX上,针对8192维隐藏大小、bf16精度,三种方法在不同序列长度下执行单个AllReduce和RMSNorm操作的延迟。简单地在ReduceScatter (RS)和AllGather (AG)之间重排RMSNorm,除了在非常高的序列长度下,其性能比在AllReduce之后执行RMSNorm更差,因为将AllReduce (AR)拆分为RS和AG的开销抵消了收益。我们融合的AllReduce–RMSNorm内核在整个序列长度范围内实现了高达1.40倍的改进。

A3 背景知识与相关工作

Transformer解码器架构。标准的仅解码器Transformer【14, Attention Is All You Need, 2017, arXiv:1706.03762 [http://cs.CL]】通过多个Transformer块处理输入序列,每个块由一个注意力(Attention)层和一个前馈网络(FFN)层组成。在每个注意力层和FFN层之前都会插入一个RMSNorm层。FFN层通常包含两个线性变换,中间夹杂一个非线性激活函数如GELU【6, Gaussian Error Linear Units (GELUs), 2023, arXiv:1606.08415 [cs.LG]】。注意力层由多个注意力头组成,首先执行QKV-preprojection操作为每个头计算查询(query)、键(key)和值(value),然后在每个头中执行自注意力操作,最后通过一个post-projection步骤合并所有头的输出。

2.1 分布式推理

分布式推理的必要性。由于许多现代LLM模型尺寸巨大,推理必须在多个GPU上运行才能容纳模型参数。即使模型能装入单个GPU,为了满足交互式工作负载的严格延迟服务等级目标(SLO),也可能需要分布式推理。此外,在许多情况下,分布式推理更高效,因为它通过释放KV缓存的内存来允许更大的批处理大小。

张量并行(TP)策略。对于分布式推理,最常见的并行策略是张量并行(TP)。在TP中,FFN的权重矩阵在GPU之间进行分区——第一个MLP按列划分,第二个按行划分【13, Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism, 2019, arXiv:1909.08053 [http://cs.CL]】。每个GPU使用本地的部分权重矩阵进行计算。然后,各个输出通过AllReduce操作进行组合,以获得最终输出。对于注意力层,分区是沿着头的维度进行的。每个GPU为本地的头执行QKV-preprojection和自注意力操作。最终的post-projection矩阵乘法的输出最终通过AllReduce操作在GPU之间再次组合。

TP中的通信开销。因此,TP在每个Transformer块中需要两次位于关键路径上的AllReduce操作。这会显著增加推理延迟并降低GPU效率。例如,如图1所示,即使在具有高速NVLink互连的8xH100 DGX上,通信成本也可能增加高达23%的开销。

2.2 计算-通信重叠

计算-通信重叠的基本策略与挑战。减少通信开销的常用策略是将通信与其他计算重叠。然而,由于数据依赖性,待通信的数据只有在计算步骤(FFN/Attention)完成后才准备好。为了解决这个问题,一种方法是将计算分解为更小的子任务。已完成子任务的通信可以与剩余子任务的计算重叠。然而,分解为子任务会导致计算效率降低。现代GPU由于提供高度并行性,运行大型计算内核远比运行多个小型内核高效,分解为小型内核可能导致波次量化效应(wave quantization effects),即最后一波计算中可能有一部分GPU SM无工作可做。许多技术【2, FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion, 2024, arXiv:2406.06858 [cs.LG]】、【7, Breaking the computation and communication abstraction barrier in distributed machine learning workloads, 2022, Proceedings of the 27th ACM International Conference on Architectural Support for Programming Languages and Operating Systems】、【19, Overlap communication with dependent computation via decomposition in large deep learning models, 2022, Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 1】、【22, Distributed w/ TorchTitan: Introducing async tensor parallelism in PyTorch, 2024, https://discuss.pytorch.org/t/distributed-w-torchtitanintroducing-async-tensor-parallelism-in-pytorch/209487】、 【26, TileLink: Generating Efficient Compute-Communication Overlapping Kernels using Tile-Centric Primitives, 2025, arXiv:2503.20313 [cs.DC]】通过融合内核实现来解决此问题,即在计算各个tile完成时,从计算内核内部协调通信。

基于XLA和TPU的重叠方法。例如,Wang等人【19, Overlap communication with dependent computation via decomposition in large deep learning models, 2022, Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 1】将AllReduce分解为ReduceScatter和AllGather操作。ReduceScatter与前一层的计算重叠,而AllGather与下一层的计算重叠。为了协调重叠,他们依赖XLA和TPU支持,从一个分块的GEMM内核内部调用异步集合通信API。然而,对XLA的依赖使得移植到PyTorch + CUDA变得困难,导致主流采用有限。

基于CUDA的融合内核方法。Flux【2, FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion, 2024, arXiv:2406.06858 [cs.LG]】为基于CUDA的实现提供了类似于【19】的解决方案。他们也把AllReduce分解为ReduceScatter和AllGather,并通过CTA级别的流式方案采用融合内核方法。每个GEMM CTA将其MMA指令与位于对等GPU上的子tile的远程TMA加载和存储交错进行。利用Hopper GPU上的TMA和NVSHMEM指令,可以使通信停顿被下一个计算warp自动填充,从而实现计算通信重叠。TileLink【26, TileLink: Generating Efficient Compute-Communication Overlapping Kernels using Tile-Centric Primitives, 2025, arXiv:2503.20313 [cs.DC]】采用了与Flux相同的思想,但通过基于Triton的实现将融合内核的实现推入编译器。这大大减少了手写代码量,并允许更多的优化策略。然而,Flux和TileLink方案都导致GEMM内核的寄存器和共享内存占用增加,这可能导致CTA占用率降低并损害计算性能。

融合内核方法的三个问题。融合内核方法【2, FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion, 2024, arXiv:2406.06858 [cs.LG]】、【5, Distributed GEMM, n.d., https://blog.shi-labs.com/distributed-gemm-88be6a481e2b】、 【19, Overlap communication with dependent computation via decomposition in large deep learning models, 2022, Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 1】、【21, Introducing Async Tensor Parallelism in PyTorch, n.d., https://discuss.pytorch.org/t/distributed-w-torchtitanintroducing-async-tensor-parallelism-in-pytorch/209487】、 【26, TileLink: Generating Efficient Compute-Communication Overlapping Kernels using Tile-Centric Primitives, 2025, arXiv:2503.20313 [cs.DC]】存在三个问题。首先,由于这些技术依赖于GEMM内核内的重叠,非GEMM操作(如注意力)期间的通信重叠是不可行的。例如,FFN层后的AllReduce是通过将ReduceScatter部分与第二个MLP重叠,并将AllGather部分与下一个注意力层的QKV-preprojection重叠来执行的。这限制了重叠的机会,对于小模型或小批次,QKV-projection耗时很短,没有足够的时间来重叠AllGather。其次,将操作分解为ReduceScatter和AllGather,与等效的AllReduce操作相比效率较低。图5绘制了单个AllReduce内核的性能与等效的ReduceScatter+AllGather成本在不同通信大小下的对比。如图所示,这可能增加显著的开销,高达50%以上。第三,较小的tile大小的通信粒度比单次大传输效率低。图6绘制了实现的ReduceScatter带宽与张量大小的关系。如图所示,较小的通信大小实现的带宽要低得多。由于这些因素,【2】、【19】、【26】只有在GEMM足够大时才能有效隐藏通信,这要求批次中包含大量的token。


图5. 将AllReduce (AR) 拆分为ReduceScatter (RS) 和AllGather (AG) 会导致不可忽略的开销。图中显示了在8xH100 DGX上这些操作的各自时间和相对性能(线图)。所有运行的隐藏层大小为8192,精度为bf16。


图6. 大型集合操作更高效。图中显示了在8xH100 DGX上,不同序列长度下(隐藏层大小8192,bf16)ReduceScatter (RS)的带宽。更大的张量带来更好的带宽,表明将输入拆分为更小的部分会导致开销。

NanoFlow的调度方法。NanoFlow【28, NanoFlow: Towards Optimal Large Language Model Serving Throughput, 2024, arXiv:2408.12757 [cs.DC]】从调度的角度解决这个问题:它将一个传入的批次切分成纳米批次(nano-batches)——在整个内核(FFN、注意力、集合通信)的粒度上,而不是CTA tile——并将每个纳米批次分配给一个绑定到固定SM子集的专用CUDA流。通过协同调度资源配置文件互补的纳米批次(例如,计算密集型的FFN与通信受限的ReduceScatter),NanoFlow重叠了GPU计算、HBM流量和NVLink传输。然而,这种方法依赖于高批次大小,以便将输入批次分解为足够大的纳米批次,因为纳米批次的较小内核可能导致显著的开销。

同期工作DeepSeek的比较。在同期的工作中,DeepSeek在其推理系统代码库中使用了将一个批次分成两部分的想法【3, Two batch overlap, n.d., DeepSeek, https://github.com/deepseek-ai/profile-data】来实现计算-通信重叠。然而,由于他们的推理系统避免了张量并行,而是依赖于专家并行,他们需要支持更昂贵的all-to-all通信,而不是张量并行所需的更廉价的all-reduce。因此,他们的系统将一个批次的all-to-all通信与第二个批次的计算重叠,通信需要约20个SM,而TokenWeave中只需要4-8个SM。此外,他们没有将层归一化与通信一起重叠 。

2.3 NV-SHARP、多播和SymmetricMemory

Hopper GPU的硬件特性。Hopper GPU上可用的第四代NVSwitch系统(NVLink4)集成了专用的SHARP(可扩展分层聚合和归约协议)引擎,称为NVLink SHARP或NVLS。NVLS使GPU能够向多播地址发出multimem PTX加载/存储指令。这些指令利用交换机结构来(i)将数据包复制到每个订阅的GPU,以及(ii)在转发聚合结果之前执行网络内归约。因为算术运算直接在交换机ASIC内执行,通信集合操作显著减少了NVLink带宽使用和GPU SM资源消耗。

PyTorch的SymmetricMemory API。PyTorch 2.6.0通过其SymmetricMemory API【20, PyTorch SymmetricMemory: Harnessing NVLink Programmability with Ease, 2025, https://dev-discuss.pytorch.org/t/pytorch-symmetricmemoryharnessing-nvlink-programmability-with-ease/2798/1】暴露了NVLS。SymmetricMemory通 过symm_mem.empty调用(类似于torch.empty())方便地在GPU上分配对等缓冲区。在缓冲区分配后,通过对symm_mem.rendezvous()的集体调用来交换内存句柄,将对等缓冲区映射到每个参与GPU的虚拟地址空间中。rendezvous之后,每个GPU可以使用标准的内存操作在Triton或CUDA内核中访问远程或多播指针,从而无需显式的NCCL调用,并大大简化了通信例程的实现。

硬件和软件进步带来的优势。这些硬件和软件的进步显著减少了执行通信原语所需的SM数量,并减轻了内存带宽压力。在我们的实验中,我们观察到,仅使用H100 GPU上6-8%的SM就足以饱和通信带宽,使得大部分SM可以用于与通信重叠的计算任务。这可以从图7中看到,该图显示了AllReduce内核的延迟与SM数量的关系。


图7. 基于Multimem的AllReduce实现需要非常少的SM。图中显示了在不同SM数量下,AllReduce multimem内核在不同序列长度(隐藏层大小8192,bf16)下的性能。在大多数情况下,4-8个SM就足够了。

A2 方法细节

我们现在描述TokenWeave的主要技术——一种粗粒度、智能的令牌分割(Token-Splitting)技术,残差加法-RMSNorm(ResidualAdd–RMSNorm)重排序,以及一个融合的AllReduce–RMSNorm实现。图8提供了我们方法与标准张量并行实现的高层示意图对比。

3.1 粗粒度令牌分割

基本分割策略。TokenWeave将传入的批次划分为两个子集,每个子集具有几乎相等的计算和通信需求。为了说明,考虑一个大小为1且包含n个令牌的序列的传入批次。这个批次被分为两个分割批次:一个包含长度为n1的初始子序列的前缀分割(prefix-split),和一个包含长度为n2的剩余子序列的后缀分割(suffix-split)(n = n1 + n2)。这两个分割批次以流水线方式分别处理。

处理注意力层的依赖性。除注意力外,所有Transformer操作都是令牌级别的,分割处理不会带来问题。然而,注意力操作引入了依赖性,因为后缀子序列中令牌的注意力计算依赖于前缀子序列中的令牌。

依赖性解决方案。为了处理这种依赖性,TokenWeave采用了一种分块注意力(chunked attention)实现【1, Taming {ThroughputLatency} tradeoff in {LLM} inference with {Sarathi-Serve}, 2024, 18th USENIX Symposium on Operating Systems Design and Implementation (OSDI 24)】,并确保前缀分割的操作先于后缀分割的操作。当处理大于1的批次时,分区可能包含完整或部分序列。TokenWeave确保部分序列的所有前缀都位于前缀分割内,从而保持必要的计算依赖性。

3.1.1 感知波次的智能分割

波次量化开销问题。将大型计算划分为较小单元会在GPU上因波次量化效应(wave quantization effects)而引入开销。考虑一个需要300个CTA(Cooperative Thread Arrays)的GEMM内核。在一台拥有132个SM(Streaming Multiprocessors)的NVIDIA H100 GPU上,假设每个CTA恰好占用一个SM,这个计算将跨越两个完整的波次和一个使用36个SM的部分波次。因此,总计算时间等于三个波次的执行时间。

朴素分割的弊端。如果这个计算被平均分成两个各150个CTA的批次,每个较小的批次现在需要两个波次:一个132个SM的完整波次,后面跟着一个18个SM的部分波次。因此,这两个较小的计算总共需要四个波次,从而增加了与原始未分割计算相比的执行时间。

智能分割策略。为了防止这种开销,智能分割(Smart-splitting)采用了一种感知波次的分割策略,确保两个分割所需的总波次数不超过原始未分割计算的波次数。在上面的例子中,智能分割策略性地将批次划分为一个包含132个CTA的分割(恰好一个完整波次)和另一个包含168个CTA的分割(一个完整波次和一个部分波次)。这种方法有效地保持了总计算波次数,最小化了因分割带来的任何波次量化开销。图9比较了有无智能分割的FFN层的延迟。如图所示,智能分割可以减少分割开销,特别是对于令牌较少的批次。

3.1.2 重叠执行

流水线式重叠执行。这两个批次的操作现在可以如图8所示进行重叠。如图所示,当第一个批次的AllReduce正在处理时,我们计算第二个批次的注意力。然后,第一个批次的FFN与第二个批次的AllReduce重叠,依此类推。我们通过CUDA流实现这种重叠执行。一个通信流处理所有的通信操作,而一个计算流运行计算操作。通过使用torch.cuda.stream_wait(stream, wait_stream)(它不涉及cpu)在计算流和通信流之间执行轻量级同步来处理数据依赖。


图8. TokenWeave概览。(a) 普通张量并行:所有计算和通信操作按顺序执行。(b) TokenWeave:输入批次被划分为两个部分。RMSNorm与通信融合,一个部分的通信与另一部分的计算重叠。独立的计算流和通信流交织在一起以协调重叠。


图9. 智能分割可以减少波次量化开销。图中显示了FFN层在不分割、均等分割和智能分割情况下的延迟。同时显示了相对于不分割情况的归一化时间。智能分割减少了分割开销,特别是对于较小的批次。

3.2 RMSNorm重排序

现有实现的冗余。如第2.1节所述,通过张量并行进行分布式推理,在Attention和FFN计算之后涉及一个AllReduce操作。在Transformer模型中,残差加法(Residual Addition)和RMSNorm操作通常紧随这些层之后,并由每个GPU独立计算。这种独立计算导致了冗余,因为在AllReduce之后所有GPU都拥有相同的令牌嵌入。

TokenWeave的重排序策略。为了解决这种低效率问题,TokenWeave策略性地在AllReduce过程中重排RMSNorm操作。具体来说,AllReduce操作可以分解为ReduceScatter和AllGather操作。在ReduceScatter步骤完成时,每个GPU拥有张量的1/p的完整和最终状态,其中p表示分布式副本中的GPU总数。因此,每个GPU可以独立地对其专用的1/p部分执行RMSNorm,而没有冗余。由于RMSNorm是令牌级别的操作,我们只需要确保ReduceScatter仅在令牌边界处分割张量。这确保了每个GPU拥有完整的令牌嵌入。随后的AllGather操作然后将这些经过RMSNorm处理后的值分发给所有GPU。

重排序的优势与挑战。通过这种重排序,TokenWeave将RMSNorm的计算量减少了p倍,从而消除了不必要的冗余。然而,如图4所示,这样简单的重排序实际上可能导致性能损失,因为将AllReduce分解为独立的ReduceScatter和AllGather的成本抵消了RMSNorm计算的收益。我们在下一节通过我们的融合内核来解决这种低效率问题。

3.3 融合的AllReduce–RMSNorm实现

利用硬件特性实现融合。TokenWeave利用NVIDIA H100的multimem能力,实现了一个高效、融合的ReduceScatter、RMSNorm和AllGather操作。具体来说,在ReduceScatter期间,每个GPU使用NVSHARP对其张量的1/p部分执行归约。然后我们立即在这个归约后的部分上执行RMSNorm计算,结果随后被写入multimem地址,用于AllGather分发给所有GPU。

融合内核的实现细节。我们在列表1中提供了我们融合内核的源代码。RMSNorm通常需要两次HBM读取——一次用于计算令牌嵌入的方差,另一次用于将值乘以计算出的方差进行缩放,以及最后一次HBM写入。相比之下,我们的融合实现通过直接在multimem ReduceScatter的结果(第23行)上计算方差(第25行)来优化内存访问,从而消除了初始HBM读取的需要。此外,我们通过直接将归一化后的值输出到multimem以进行AllGather操作(第36行),节省了一次额外的HBM写入。此外,请注意,残差操作也与RMSNorm融合在一起(第24行)。

template <typename scalar_t, int width>   
_global__ fused_rs_ln_ag_cta_kernel(...) {   
    const int vec_hidden_size = hidden_size / width;   
    int tokens_per_cta = (num_tokens + gridDim.x - 1) / gridDim.x;   
    sync_remote_blocks<MemOpSem::Relaxed>(signal_pads, rank, world_size);   
    __syncthreads();   
    
    for (int iter = 0; iter < tokens_per_cta; iter++) {   
        int token_id = blockIdx.x + iter * gridDim.x;   
        if (token_id >= num_tokens) continue;   
        
        float variance[1] = {0.0f};   
        shared__ float s_variance;   
        int offset = token_id * vec_hidden_size;   
        int offset_scalar = token_id * hidden_size;   
        auto input_o = input_v + offset;   
        auto residual_o = residual_v + offset;   
        
        for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {   
            auto multimem_temp = multimem_ld_reduce_add<16>(multimem_address_ptr +   
                offset_scalar + idx * width);   
            vec_t temp = *(reinterpret_cast<vec_t*>(&multimem_temp));   
            temp += residual_o[idx];   
            variance[0] += temp.sum_squares();   
            residual_o[idx] = temp;   
        }   
        
        blockReduceSum<float, 1>(variance);   
        if (threadIdx.x == 0)   
            s_variance = rsqrtf(variance[0] / hidden_size + epsilon);   
        __syncthreads();   
        
        for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {   
            vec_t temp = residual_o[idx] * s_variance * weight_v[idx];   
            multimem_st<16>(mcptr + offset + idx * width,   
                *(reinterpret_cast<Vec<16>*>(&temp)));   
        }   
    }   
    _syncthreads();   
    sync_remote_blocks<MemOpSem::AcqRel>(signal_pads, rank, world_size);   
}

融合内核的效率。内存访问的减少,加上冗余计算的消除,使得我们的融合AllReduce–RMSNorm内核非常高效。如图4所示,在所有序列长度上,该融合内核相比当前分离的AllReduce + RMSNorm计算方法提供了高达1.40倍的改进。

低SM占用与重叠优势。此外,计算和内存带宽需求的降低,使得这个融合内核可以用非常少的SM(在我们的实验中仅为2-8个)来执行,而不会产生太多开销——如图10所示,内核执行时间在8个SM之后没有太大改善,对于更长的序列,甚至4个SM就足够了。这使我们能够将一个分割批次的完整AllReduce–RMSNorm操作与另一个分割的计算重叠,不仅实现了通信的重叠,还实现了内存受限的RMSNorm的重叠。由于这个计算与另一个分割批次重叠,即使使用较少的SM(例如2个而不是8个)会导致一些开销,只要其执行时间不超过重叠的计算内核的执行时间,这些开销就会被隐藏。


图10. 融合的AllReduce–RMSNorm内核在很少的SM下表现最佳。图中显示了在不同SM数量下,AllReduce multimem内核在不同序列长度(隐藏层大小8192,bf16)下的性能。在大多数情况下,8个SM接近最佳。

A4 实验

实验环境

实验结果

延迟收益评估


图11. TokenWeave延迟增益。图中显示了在(a) 8xH100和(b) 4xH100上,不同模型处理不同序列长度的预填充请求的执行时间。在几乎所有情况下,TokenWeave都接近或优于理论上的vllm-nocomm基线(零通信开销),表明TokenWeave不仅恢复了所有通信开销,还因RMSNorm融合提供了额外增益。

吞吐量增益评估


图12. TokenWeave在端到端工作负载追踪中的吞吐量增益。图中显示了在(a) 8xH100和(b) 4xH100上,不同模型在固定(输入,输出)长度追踪以及ShareGPT追踪下的测量吞吐量。


图13. TokenWeave在不同块大小下的端到端工作负载追踪吞吐量增益。图中显示了在8xH100 DGX上,Llama-3.3-70B在固定(输入,输出)长度追踪以及ShareGPT追踪下的测量吞吐量。块大小从1024变化到8192。

与TileLink的比较


图14. Llama-3.3-70B在8xH100 DGX上的单层延迟。数字表示与vllm-multimem相比的归一化性能。在短序列长度下,TileLink最终会产生开销,而TokenWeave在整个序列长度范围内提供了一致的高增益。

与NanoFlow的比较


图15. NanoFlow端到端工作负载追踪的吞吐量评估。图中显示了在8xH100 DGX上,Llama-3.3-70B在固定(输入,输出)长度追踪下的测量吞吐量。nanoflow-full对应完整的NanoFlow实现,而nanoflow-frameworkonly禁用了NanoFlow(纳米批处理和重叠),但使用其自定义服务框架。

消融研究


图16. TokenWeave消融实验。图中显示了在(a) 8xH100和(b) 4xH100上,不同模型处理不同序列长度的预填充请求的执行时间。TokenWeave-fuseonly因消除了RMSNorm计算的冗余和中间内存访问而提供增益;而TokenWeave则通过计算-通信重叠提供了额外增益。

A5 结论

本文发现,尽管有NVLink和NVSHARP等高速硬件支持,在多GPU上服务的大型模型的通信成本仍然高达20%。此外,RMSNorm操作也带来了5-9%的显著开销。为了解决这些问题,本文提出了TokenWeave,它将模型输入分割为两个大致相等的批次,并将一个批次的计算与另一个批次的新颖的融合AllReduce-RMSNorm通信和归一化内核重叠。通过在4xH100和8xH100 DGX GPU上对多种模型进行广泛的实验评估,我们表明,在各种设置和工作负载下,与优化的基线相比,TokenWeave实现了高达1.29倍的延迟降低和1.26倍的吞吐量提升。