Petrick Liu, Jiang Shao, NVIDIA DevTech Team | AI Open Day / 2025.05.30
在Hopper架构之前,典型的分块通用矩阵乘法(Blocked GEMM)通过将数据在不同层级内存(全局内存、共享内存、寄存器文件)之间移动,并利用CUDA核心/张量核心(Tensor Cores)进行计算。其基本流程是通过多级分块(Thread Block Tile, Warp Tile, Thread Tile)来管理数据。
如上图所示,基本的计算循环包含:
1. 从共享内存加载数据A (load_A_tile) 和数据B (load_B_tile) 到寄存器文件。
2. 线程同步 (__syncthreads())。
3. 从寄存器文件加载片段 (load_A_frag, load_B_frag)。
4. 执行矩阵乘加运算 (mma)。
5. 再次线程同步。
这种模式中,数据加载(LDGSTS,从全局内存到共享内存)和计算(MMA)交替进行,导致张量核心(TC Active)存在空闲时间,因为计算必须等待数据加载完成。
为了建立一个稳定的计算流水线,需要一个“序言”(Prologue)阶段来预加载数据。如下图所示,在计算开始前,需要执行一系列的全局内存加载指令(LDGSTS Ktile0 至 Ktile4)来填充数据缓冲区。这个过程会引入启动延迟(Gmem Latency),在此期间张量核心处于非活动状态。
为了隐藏和重叠异步操作的延迟,需要精心设计的流水线。关键的优化思路是:
- 采用高度流水线化的设计来最大化吞吐量。
- "序言"阶段是实现高吞吐量所需付出的代价。
- 在主循环(Mainloop)中,通过RF双缓冲(RF double buffer)等技术,实现全局内存加载(Gmem Loading)、共享内存加载(Smem Loading)与张量核心计算(TC computing)的完全重叠。
不同架构在处理计算任务时采用了不同的调度和资源分配策略,这直接影响了效率和延迟隐藏。
Ampere架构:
Hopper架构:
Hopper架构引入了Mbarrier,这是一种用于Warp间通信和同步的强大机制,是实现高效异步编程的关键。它支持创建生产者-消费者(Producer-Consumer)模型,特别是在Warp专业化的GEMM中。
Hopper Warp专业化GEMM流程:
生产者Warp (TMA Warps):
CollectiveMma::load_a等接口。Smem_empty 屏障,确保消费者已使用完上一批数据。Smem_full 屏障。Smem_full 屏障,通知消费者数据已准备好。消费者Warp (TC Warps):
CollectiveMma::mma等接口,并且是持久化的。Smem_full 屏障,直到生产者准备好数据。Smem_empty 屏障,通知生产者可以加载新数据。共享内存作为数据缓冲区,通过Mbarrier对象(包含Smem_empty和Smem_full状态)进行同步,实现了数据加载和计算的高度流水化。
以下通过一个逐步示例解释Mbarrier的工作机制。
1. 初始化
Mbarrier对象。Init_mbarrier(&bar, 1)进行初始化,设置预期到达数量为1。while (try_wait(&bar, phase)) {}循环,等待phase翻转。Phase = 0, Expect Arrv_Cnt = 1, Actual Arrv_Cnt = 0, Expect Trans_Bytes = 0, Actual Trans_Bytes = 0。2. TMA发出加载和到达指令
- TMA Warp中的一个线程 (if(tma_thread)) 发出TMA_bulk_load指令,请求加载16KB数据。
- 接着发出mbarrier_arrive_expect(&bar, 16KB),通知Mbarrier本次事务预期传输16KB数据。
- Mbarrier状态更新:Expect Trans_Bytes变为16KB。Actual Arrv Cnt根据执行线程数更新。
3. TMA数据传输
- TMA开始将数据从全局内存加载到共享内存。
- Mbarrier会追踪实际已传输的字节数。图中显示已传输1KB,Actual Trans_Bytes更新为1KB。TC Warp仍然在try_wait处阻塞。
Actual Trans_Bytes更新为4KB。4. 事务完成
- 当全部16KB数据传输完成时,Mbarrier记录整个事务完成。Actual Trans_Bytes变为16KB。
5. Phase翻转与消费者唤醒
- 由于事务完成,Mbarrier的Phase从0翻转到1。
- TC Warp的try_wait(&bar, phase)条件满足,跳出循环("Pass here!")。
- TC Warp开始使用共享内存中的数据进行WGMMA计算。
- Mbarrier状态被重置,为下一轮数据传输做准备(例如,Expect Trans_Bytes重置为0)。
该编程模式涉及一个TMA Warp(生产者)和一个或多个TC WarpGroup(消费者),通过两个mbarrier对象(bar_full 和 bar_empty)进行同步。
初始状态与第一阶段 (Page 16)
- 初始化:
- Init_mbarrier(&bar_full, 1): 初始化bar_full,期望到达数为1(来自TMA Warp)。
- Init_mbarrier(&bar_empty, 128): 初始化bar_empty,期望到达数为128(来自TC WarpGroups)。
- mbarrier_fence(): 确保初始化完成。
TMA Warp (生产者):
issue_TMA_bulk_load(...): 发出异步数据加载指令,目标是SMEM。mbarrier_arrive_expect(...): 到达bar_full屏障,表示数据正在传输中。TC WarpGroups (消费者):
while (try_wait(&bar_full, phase)) {}: 等待bar_full的阶段(phase)翻转。此时会阻塞,因为TMA的数据加载尚未完成。Mbarrier状态:
Mbarrier Smem_Full: 阶段为1,期望到达数为1,实际到达数为0。Mbarrier Smem_Empty: 阶段为0,期望到达数为128,实际到达数为0。数据就绪与消费 (Page 17)
bar_full的实际到达数变为1,满足期望值,于是bar_full的phase翻转。TC WarpGroups: 检测到bar_full的phase翻转,try_wait成功,循环退出。
WGMMA(...): 执行矩阵乘累加计算,消费SMEM中的数据。WAIT_WGMMAs(): 等待WGMMA计算完成。mbarrier_arrive(&bar_empty): 所有消费者Warp完成计算后,到达bar_empty屏障。Mbarrier状态:
Mbarrier Smem_Full: phase翻转,实际到达数与期望数均为1。Mbarrier Smem_Empty: TC WarpGroups到达后,实际到达数变为128。SMEM释放与循环 (Page 18)
bar_empty后,满足其期望到达数,bar_empty的phase翻转。TMA Warp:
while (try_wait(&bar_empty, empty_phase)) {}: 等待bar_empty的phase翻转。bar_empty的phase翻转后,try_wait成功,TMA Warp被唤醒。这意味着SMEM中的数据已被消费,可以安全地加载新数据。Mbarrier状态:
Mbarrier Smem_Empty: phase翻转,实际到达数与期望数均为128。第二次迭代 (Page 19)
phase变量递增(例如,int empty_phase = 1;)。bar_full翻转 -> TC消费数据 -> bar_empty翻转。这是一种利用Hopper架构特性(如TMA、WGMMA、Mbarrier)实现的高效GEMM计算模型,其核心思想是生产者和消费者的解耦。
生产者 (TMA Warps):
Smeme_empty 屏障,确保有空的缓冲区。Smeme_full 屏障。Smeme_full 屏障。消费者 (TC Warps):
Smeme_full 屏障,确保数据已加载到SMEM。Smeme_empty 屏障,表示该缓冲区已可重用。关键特性:
思考题: 如何初始化屏障,使得生产者可以跳过第一次对empty缓冲区的检查?
empty屏障的阶段来实现,使其看起来好像消费者已经“完成”了第一轮消费,从而立即释放第一个缓冲区给生产者。CUTLASS库提供了用于实现这种复杂流水线模型的原生组件。
cutlass/pipeline/sm90_pipeline.hpp:PipelineState结构体管理流水线的状态,包括当前阶段索引index_、阶段标志phase_和计数器count_。++运算符来推进流水线阶段,当索引绕回时翻转phase_。这与前面mbarrier的phase机制相对应。producer_try_acquire, producer_acquire (获取空闲缓冲区), producer_commit (提交数据)。consumer_try_wait, consumer_wait (等待数据就绪), consumer_release (释放已用缓冲区)。mbarrier操作,简化了异步流水线的编程。Hopper架构引入了Warp Group MMA (WGMMA),以 warp group(128个线程)为单位进行矩阵乘法。
64xNx256bit,其中N在[8, 256]之间,步长为8。4个Warp分布在M维度,每个Warp执行16xNx256bit的计算。NO_SWIZZLE, SWIZZLE_32B, SWIZZLE_64B, SWIZZLE_128B等多种内存排布模式,且与TMA的排布模式兼容。Group Commit和Wait来跟踪完成情况。一个典型的WGMMA计算流程如下:
wgmma.fence.sync.aligned;
wgmma.mma_async.aligned.m64n128k16.f32.f16.f16; ...
wgmma指令。这些指令被分组执行。wgmma.commit_group.sync.aligned;
wgmma.wait_group.sync.aligned 0;
wgmma.wait_group来同步和跟踪完成状态。Wait Smem0 Arrv)时,Tensor Core单元处于空闲(Idle)状态。Smem1的等待(Wait Smem1 Arrv)与前一个WGMMA执行重叠,消除了空闲周期。CUTLASS提供了WGMMA多级流水线的实现。代码逻辑分为两个主要部分:
MMA多级流水线序言 (Prologue):
pipeline.consumer_wait),然后执行WGMMA计算。MMA多级流水线主循环 (Mainloop):
consumer_wait) -> 对当前数据进行计算 (tiled_mma) -> 释放之前用过的缓冲区 (consumer_release) -> 推进流水线 (smem_pipe.release)。Keep MMAs in flight)。为了错开两个Warp组的执行,可以使用 OrderedSequenceBarrier。以下代码片段展示了如何命令两个数学Warp组(Math WG)的MMA(矩阵乘法累加)操作,这有助于隐藏尾声(epilogue)的开销。
代码逻辑如下:
math_wg_order.barrier.wait() 来命令两个数学Warp组的MMA操作。collective_mainloop.mma,处理主循环流水线、消费者状态、累加器等。math_wg_order.barrier.arrive() 为下一个数学Warp组的MMA做准备。Tensor Core: tcgen05.mma 家族
TMA:
im2col::w 模式加载/存储tcgen05 mma 设计Persistent:
所有以上功能均基于 mBarrier 编程
Tensor Core: Ampere 风格的 mma 家族
TMA:
im2col::w 模式加载/存储Persistent:
除Tensor Cores外,所有以上功能均基于 mBarrier 编程
tcgen05.commit 指令将 Tensor Core 的完成状态与 mBarrier 连接起来。下图展示了Warp调度器如何处理 tcgen05.mma 指令。硬件(HW)会异步跟踪TC(Tensor Core)的完成情况,并通过 mBarrier 进行更新。Wait Smem0 表示等待共享内存数据,Commit 表示提交任务,Issue 表示分发MMA指令。
CUTLASS Hopper Persistent Scheduling:
静态调度的问题:
如下图所示,当SM 2被另一个网格(Other Grid)占用时,原先分配给它的分块(如Tile 102, 202, 302)必须等待SM 2空闲后才能被处理,这造成了整体执行时间的延长。
clusterlaunchcontrol PTX指令以编程方式在SM上获取线程块/集群(Thread Block/Clusters)。clusterlaunchcontrol 实现动态持久化调度器。如下图所示,当SM 2被占用时,动态调度器会将原本分配给它的任务(如Tile 301)重新分配给其他空闲的SM(如SM 1),从而避免了执行延迟,优化了资源利用率。
这是一个SM100内核的视图,展示了不同类型的Warp(线程束)如何协作:
- Sch Warp (WarpId = 3):调度Warp,负责管理工作负载(Workld)。当工作负载为空(Workld_Empty)时,需要从外部获取输入偏移(input offset);当工作负载满(Workld_Full)时,通知工作负载流水线消费者。它也负责在需要时停止其他Warp。
- TC Warp (WarpId = 2):Tensor Core Warp,执行核心计算任务。它是TMA主循环流水线的消费者和尾声流水线的生产者。
- TMA Warps (WarpId = 0, 1):负责数据加载。它们是TMA主循环流水线的生产者。
- EpilogueWarps (WarpId = 4,5,6,7):负责尾声处理。它们是尾声流水线的消费者。
这些Warp通过共享内存(SMEM)和张量内存(TMEM)的状态(Full/Empty)以及工作负载队列进行同步和通信,形成一个高效的生产者-消费者流水线模型。
此图比较了不同架构和调度策略下的执行时间线,其中不同的Warp被分配了不同的任务(prolog, mainloop, epilog)。
Blackwell架构通过动态调度和专门的Warp任务分配,有效缩短了计算时间。
GEMM 输入:
计算流程:
scaling_B & add操作,最后再进行scaling_A & convert操作得到FP16的输出矩阵C。这种方法显著减少了缩放操作的次数。CUTLASS 示例 69: Hopper Mixed Dtype Grouped GEMM
https://github.com/NVIDIA/cutlass/tree/v3.9.2/examples/69_hopper_mixed_dtype_grouped_gemmGEMM 输入 (与前页相同):
lookup_table_convert获得)核心代码:
cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hppCUTLASS_DEVICE void mma(...)下图展示了将INT4的矩阵B先convert为FP16,再进行scaling_B,最后convert为FP8,然后与FP8的矩阵A进行MMA运算。
此页展示了对前一页流程的优化。通过使用查找表转换(lookup_table_convert),将INT4矩阵B的转换和缩放操作合并为一步(convert & scaling_B),直接生成FP8格式的数据。这减少了中间步骤和数据移动,提高了效率。
TileShape对每个组的输出矩阵C进行分区。下图展示了如何将两个组(Group 0和Group 1)的计算任务进行划分。每个组的输出矩阵C被划分为多个Block(Block 0-3等),每个Block由一个CUDA线程块负责计算。输入矩阵A和B也相应地按TileShape进行划分。
for (stage=0; ...))用于多级流水线(multistage),以及一个内层循环(for (k_block=0; ...))。下图左侧展示了线程块如何分阶段(Stage 0, Stage 1)处理K维度的不同块(k_block 0, k_block 1)。右侧图则展示了用于计算的TileShape。
为了高效执行混合精度矩阵乘法,设计了一个流水线(Pipeline)机制。其核心思想是重叠数据加载、转换和计算操作。
cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gemma_rs_warpspecialized_mixed_input.hppTileShape::K = 128和GMMA_K = 32为例,每个K Tile被划分为4个kblocks。下图展示了一个流水线执行流程。矩阵B(INT4)通过lookup_table_convert进行转换和缩放,变为FP8格式。然后与FP8格式的矩阵A进行MMA操作,累加结果为FP32。最后,对累加结果进行缩放和转换,得到FP16格式的输出矩阵C。整个过程在多个kblock上以流水线方式执行,并使用双缓冲(double buffer)来隐藏数据传输延迟。
该流程分为两个主要阶段(Stage 0 和 Stage 1),并由copy_tensors_MK、dequantize_A_kblock、cute::gemm等核函数以及warpgroup_wait同步原语协调。
Hopper架构引入了特定的Tensor Core指令 wgmma.mma_async,用于实现异步的矩阵乘法与累加操作。
mma_async允许计算与数据移动重叠。.m64n16k32。在混合专家(MoE)模型推理中,有两个关键点:
下图右侧展示了.m64n16k32指令如何将128个线程(T0-T127)组织成4个warp来执行计算。
CuTe是一个基于C++的库,用于描述和操作张量在GPU内存中的布局,是CUTLASS 3.x的核心组件。
cute::print函数可以打印出几乎所有CuTe类型的布局,包括指针、整数、步长(Strides)、布局(Layouts)和张量(Tensors)。示例定义:
TileShape -> (_:128, _:16, _:128).m64n16k32共享内存布局:
SmemLayoutA (Tensor sA):
(BLK_M, BLK_K, STAGE) -> ((_8, _16), (_128, _1), (_19))((_128, _1024), (_1, _0), (_0, _16384))SmemLayoutB (Tensor sB):
(BLK_N, BLK_K, STAGE) -> ((_8, _2), (_128, _1), (_19))((_128, _1024), (_1, _0), (_0, _2048))下图直观地展示了A和B矩阵在逻辑上的Tile划分。
使用CuTe,可以将全局的Tile划分为每个线程和warpgroup负责处理的数据分片。
// 获取当前线程/warpgroup在MMA操作中的数据分片
auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx);
auto mma_warpgroup_slice = tiled_mma.get_warpgroup_slice(warp_group_idx);
// 为A和B分配张量片段和描述符
Tensor tCrA_mma = mma_thread_slice.partition_fragment_A(sA(_,_,Int<0>{})); // MMA_M, MMA_N, MMA_K, PIPE
Tensor tCsA = mma_thread_slice.partition_A(sA);
...
Tensor tCrB_mma = mma_warpgroup_slice.make_fragment_B(tCsB); // MMA_M, MMA_N, MMA_K, PIPE
Tensor tCsB = mma_warpgroup_slice.partition_B(sB);
tCsA: tC代表“分区模式”,sA表示该模式应用于张量sA。sA, sB -> SMEM中的张量。rA, rF, rB -> 只是张量sB的一个视图(GmmatDescriptor)。tCsA (SMEM, 块局部):
mma_thread_slice。((_4, _2, _2), _1, (_2, _2), _1, _19)((_1, _1024, _16), _0, (32, 64), _0, _16384))tCrA_mma (RF, 线程局部):
((_4, _2, _2), _1, (_2, _2))((_1, _4, _8), _0, (_16, _32))tCrA_load (加载形状):
((_4, _2, _2), _1, (_2, _2))((_1, _4, _8), _0, (_16, _32))下图展示了数据如何从32位格式(包含4个FP8值)加载到Warp0中各个线程的寄存器中。
下图更直观地展示了数据从SMEM中的A Tile到线程寄存器(RF)的映射过程。编号的箭头表示数据加载和处理的逻辑步骤。
GEMM操作的核心是cute::gemm,它将A和B的分片相乘并累加到累加器(accum)中。
cute::gemm(tiled_mma, tCrA_mma, tCrB_(_:_,_:k_block), accum);
tCrA_mma: 位于RF中,是线程本地的。accum: 同样位于RF中,是线程本地的,精度为FP32。((_2, _2, _2), _1, _1)((_2, _2, _2), _1)((_1, _2, _4), _0)下图展示了Warp0的累加器寄存器布局,其中每个64位寄存器存储2个FP32值。
scaling_B & add)。这种方法显著减少了缩放操作的次数。Weight Scaling Group Size 设置为 128,这等于 4 x GMMA_K。scale 0, scale 1)可以应用于累加器中的不同区域。基于上述优化方案,流水线被重新设计。
Scaling阶段: 在流水线中显式加入一个阶段,对FP32中间结果应用B的缩放因子并累加。这个新的流水线通过在计算过程中间引入缩放步骤,提高了效率。warpgroup_wait的依赖关系也相应调整为3 -> 0。