Enable Tensor Core Programming in Python with CUTLASS 4.0

Albert Di, Vincent Zhang | 2025-05-30

目录

  1. Python DSL 概述

    • 为何转向 Python
    • DSL 基础架构
  2. Hopper GEMM 实现

    • GEMM 概述
    • MMA/Copy 原子操作
    • Hopper GEMM 代码演示
    • 性能比较
  3. 结论

Python DSL 概述

为何转向 Python

使用 C++ 的主要痛点

C++ 模板及其带来的不良后果:

Page 4
Page 4

推出 CUTLASS 4.0

在 Python 中进行张量核心(Tensor core)编程。

Page 5
Page 5

CUTLASS in Python

初始版本包含 CuTe。

CUTLASS in Python 架构 (Page 6)
CUTLASS in Python 架构 (Page 6)

架构特点:

为什么选择 Python DSL?

CuTe DSL 的主要优势。

使用 Python DSL 的优势 (Page 7)
使用 Python DSL 的优势 (Page 7)

从上图可以看出,Python 的编译时间(241 ms)远低于 C++(27997 ms),提升超过100倍。

CUTLASS in Python 入门

如何开始使用:

CUTLASS in Python 入门示例 (Page 8)
CUTLASS in Python 入门示例 (Page 8)
  1. 通过 pip install nvidia-cutlass-dsl 安装。
  2. 编写 Python 内核代码,使用 @cute.kernel@cute.jit 装饰器。
  3. 通过 python3 hello_world.py 运行,代码将被即时编译(JIT)并执行。
import cutlass
import cutlass.cute as cute

@cute.kernel
def kernel():
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == 0:
        cute.printf("Hello world")

@cute.jit
def hello_world():
    cutlass.cuda.initialize_cuda_context()

    # Launch kernel
    kernel().launch(
        grid=(1, 1, 1),
        block=(32, 1, 1)
    )

# Just-In-Time (JIT) compilation
print("Running hello_world()...")
hello_world()

DSL 基础架构

CUTLASS Python 架构

CUTLASS Python 架构 (Page 10)
CUTLASS Python 架构 (Page 10)

上图展示了从 Python 端代码到最终在 GPU 上执行的 CUBIN 文件的编译流程:
1. 用户使用 @cute.kernel 编写 Python 内核。
2. 代码通过 CUTLASS DSL 栈(CuTe DSL, DSL 编译器)转换成中间表示(IR program)。
3. IR 程序进入 CUDA 编译器栈,依次通过 NVVM/LLVM、PTX 和 SASS 编译。
4. 最终生成的 CUBIN 由 JIT Executor 加载并执行。

在 Python 中编写内核

通过 @cute.jit@cute.kernel 装饰器,可以将 C++ 中复杂的模板元编程内核,用更简洁的 Python 代码来表达。

在 Python 中编写内核 (Page 11)
在 Python 中编写内核 (Page 11)

左侧是 C++ CUTLASS 内核的定义,使用了大量模板参数。右侧是等效的 Python 实现,通过函数参数和类型注解来定义,更加清晰直观。

与 PyTorch 的轻松集成

支持 DLPack 协议

可以直接将 torch.tensor 作为输入传递给 JIT 编译的 CUTLASS 内核函数,无需手动数据转换。

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import torch

@cute.kernel
def jit_kernel(A: cute.Tensor):
    ...

@cute.jit
def jit_func(A: cute.Tensor):
    jit_kernel(
        A, config=cutlass.LaunchConfig(grid=[1, 1, 1], block=[1, 1, 1])
    )

# Create a torch tensor
A_tensor = torch.tensor([0, 0], dtype=torch.int32).cuda()

# Call the JIT function with the torch tensor
jit_func(A_tensor)

# Or explicitly convert using from_dlpack
# Or jit_func(from_dlpack(A_tensor).mark_layout_dynamic())

简化测试流程

与 Python 生态(如 PyTorch)的深度集成,使得验证和测试变得非常简单。可以直接使用 PyTorch 实现一个参考内核,并用 torch.testing.assert_close 来验证 CUTLASS 内核的正确性。

与 PyTorch 集成进行测试 (Page 13)
与 PyTorch 集成进行测试 (Page 13)

上图左侧是 C++ 中复杂的测试流程,包括设备内存分配、内核启动、同步、数据拷贝回主机以及手动比较。右侧则展示了等效的 Python 测试流程,仅需两行代码即可完成计算和验证。

从静态布局到动态布局

静态布局 (Static Layout):

当 JIT 函数接收到不同形状(layout)的张量时,会为每种形状编译一个专门的内核。

静态布局处理 (Page 14)
静态布局处理 (Page 14)

如上图所示,当 jit_func 分别接收形状为 (3:1)A_tensor 和形状为 (5:1)B_tensor 时,会编译两个独立的 JIT 函数,每个函数对应一种静态布局。

动态布局 (Dynamic Layout):

通过使用 mark_layout_dynamic(),可以生成一个通用的内核,处理动态布局的张量,从而避免为每种形状都重新编译。

动态布局处理 (Page 15)
动态布局处理 (Page 15)

如上图所示,通过 mark_layout_dynamic(mode=0),即使输入不同形状的张量,也只会编译一套 JIT 函数。

与 Pytorch 的轻松集成:LLaMA 8b 集成示例

Page 16: LLaMA MLP模块代码修改示例
Page 16: LLaMA MLP模块代码修改示例
Page 17: 自定义CUTLASS线性模块实现代码
Page 17: 自定义CUTLASS线性模块实现代码

更好的代码表达性和可读性

通过 CuTe,可以实现更简洁和可读性更高的代码,尤其是在进行操作融合(如 GEMM 与激活函数融合)时。

Page 18: 代码可读性对比与TensorSSA介绍
Page 18: 代码可读性对比与TensorSSA介绍

自定义类 C 结构体的数据类型

@cute.struct 装饰器允许开发者像在 C 语言中一样定义具有精确内存布局的数据结构。

Page 19: @cute.struct 示例与内存布局
Page 19: @cute.struct 示例与内存布局

以面向对象编程(OOP)方式编写核函数

@cute.struct 同样有助于以更结构化和面向对象的方式组织核函数代码,特别是共享内存的管理。

Page 20: 传统方式与OOP方式编写核函数的对比
Page 20: 传统方式与OOP方式编写核函数的对比

通过缓存降低核函数启动延迟

在不使用缓存的情况下,即时编译(JIT)会带来显著的开销。

Page 21: 无缓存时JIT编译带来的开销
Page 21: 无缓存时JIT编译带来的开销

通过缓存降低核函数启动延迟:零编译方案

通过缓存已编译的 CUBIN(CUDA Binary),可以实现 "零编译"(Zero Compile),从而消除 JIT 开销。

Page 22: 使用缓存避免重复编译的代码实现
Page 22: 使用缓存避免重复编译的代码实现

生成、编译和启动开销对比

下图展示了在 Blackwell B100 GPU 上运行 FP16 GEMM (M=N=K=8K) 时,不同缓存策略下的开销对比。

Page 23: 不同缓存策略的开销对比图
Page 23: 不同缓存策略的开销对比图

如果缺少功能怎么办?无缝集成原生 Op 构建器

当上层 API 缺少特定功能时,可以直接使用底层的原生操作(Op)构建器。

Page 24: 集成原生Op构建器的代码示例
Page 24: 集成原生Op构建器的代码示例

Hopper GEMM 实现

GEMM 概述

Page 26: GEMM分块示意图
Page 26: GEMM分块示意图

Hopper GEMM 流水线:单个分块的视角

对于 C 矩阵的每个分块,GEMM 的计算遵循一个三阶段流水线:

  1. 序言 (Prologue): 使用 TMA (Tensor Memory Accelerator) Load 指令,将 A 和 B 矩阵的第一个分块从全局内存(GMEM)预取到共享内存(SMEM)中。
  2. 主循环 (Mainloop): 使用 Tensor Core 计算 C = A*B + C,同时使用 TMA Load 指令加载下一个 K 维度的 A/B 矩阵分块,实现了计算与数据加载的重叠。
  3. 尾声 (Epilogue): 使用 TMA Store 指令,将最终计算完成的 C 矩阵分块从共享内存(SMEM)写回到全局内存(GMEM)。
Page 27: Hopper GEMM 单分块流水线示意图
Page 27: Hopper GEMM 单分块流水线示意图

MMA/Copy 原子操作

CuTe 的设计理念

对于已经使用 CUTLASS-C++ 的开发者来说,CuTe 的概念会非常熟悉。

Page 29: CuTe的抽象层次和设计理念
Page 29: CuTe的抽象层次和设计理念

cute.gemm

cute.gemm 是执行矩阵乘法的高级接口。

Page 30: cute.gemm 使用流程和图示
Page 30: cute.gemm 使用流程和图示

下图展示了使用 cute.copycute.gemm 的通用矩阵乘法(GEMM)的数据流。

Page 32
Page 32

cute.copy

CuTe TMA Atoms 为 TMA (Tensor Memory Accelerator) 提供 PTX 和元数据。

TMA Copy 代码示例:

op = cute.nvgpu.cpasync.CopyBulkTensorTileG2S0p()

tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tma_tile_atom(
    op,
    gmem_tensor,
    smem_layout,
    cta_tiler,
)

cute.copy(
    tma_atom_a,
    gmem_tensor,
    smem_tensor,
    tma_bar_ptr,
    mcast_mask,
)
Page 31
Page 31

Hopper GEMM 代码演示

Hopper 中的流水线 (Pipeline)

采用非 Warp 专用核(Non Warp Specialized Kernel)风格。

生产者:TMA 加载

消费者:数学计算

Page 34
Page 34

加载 A/B (从 GMEM 复制到 SMEM)

加载 A/B 的数据流和代码示例
加载 A/B 的数据流和代码示例

代码逻辑:
在 K 维度上循环,每个循环:
1. mainloop_pipeline.producer_acquire: 等待一个空的共享内存缓冲区。
2. cute.copy: 使用 TMA atom 将数据从全局内存复制到共享内存。
3. mainloop_pipeline.producer_commit: 提交写入操作。
4. mainloop_pipeline.producer_state.advance(): 更新流水线状态。

计算 C = A * B (GEMM)

计算 C = A * B 的数据流和代码示例
计算 C = A * B 的数据流和代码示例

代码逻辑:
在 K 维度上循环,每个循环:
1. mainloop_pipeline.consumer_wait: 等待 TMA 加载数据完成。
2. cute.nvgpu.wgmma.fence(): WGMMA (Warp Group MMA) 栅栏同步。
3. 在 k_blocks 上循环,调用 cute.gemm() 执行矩阵乘法。
4. cute.nvgpu.wgmma.commit_group()wait_group(): 提交并等待 WGMMA 操作完成。
5. mainloop_pipeline.consumer_release: 释放共享内存缓冲区。
6. mainloop_pipeline.consumer_read_state.advance(): 更新流水线状态。

存储 C (从 SMEM 复制到 GMEM)

存储 C 的数据流和代码示例
存储 C 的数据流和代码示例

代码逻辑:
在 Epilogue 阶段:
1. 将累加器中的数据复制到寄存器 tRS_rD
2. cute.make_fragment_like: 进行类型转换。
3. cute.copy: 将寄存器数据复制到共享内存 tRS_sD
4. cute.arch.barrier: 同步。
5. cute.copy: 使用 TMA atom 将共享内存中的数据异步写回全局内存。
6. cute.arch.cp_async_bulk_commit_group()wait_group(): 提交并等待异步复制操作完成。

性能比较

Hopper 性能:Python vs. C++ (FP16 I/O GEMM,M=N=8192)

下图比较了在 M=N=8192 的 GEMM-K 尺寸下,Python 和 C++ 实现的数学计算吞吐量(Math SOL%)。结果显示,C++ 的性能略高于 Python。

测试规格: H100* 80GB HBM3, 132SM GPC-1500MHz/DRAM 2619MHz 700W, 128x256x64 cooperative size, Swizzle size = 8.

Page 39
Page 39

Hopper 性能:Python vs. C++ (FP16 I/O GEMM,M=N=2048)

下图比较了在 M=N=2048 的 GEMM-K 尺寸下,Python 和 C++ 实现的性能。

测试规格: H100* 80GB HBM3, 132SM GPC-1500MHz/DRAM 2619MHz 700W, 128x256x64 cooperative size, Swizzle size = 8.

Page 40
Page 40

结论

CuTe DSL:一种基于 Python 的编程语言,使用 CuTe 语义对 Tensor Core 进行编程,以实现最佳性能。

Page 41
Page 41