ACCELERATING BACKWARD DATA GRADIENT BY INCREASING TENSOR CORE UTILIZATION IN CUTLASS

Manish Gupta, March 21, 2021

目录


致谢

幻灯片中包含致谢页,感谢了CUTLASS GitHub社区、CUTLASS团队的开发者和产品管理人员,以及众多贡献者。

议程

本次演讲的议程包括:
* 概述:介绍CUTLASS 2.6 (GTC March 2021) 到 CUTLASS 2.9 (GTC March 2022) 的发展。
* 卷积性能 (CUTLASS 2.9):分析 Fprop, Dgrad 和 Wgrad 在 CUDA 11.6 上的性能。
* 跨步 Dgrad:探讨朴素的跨步 Dgrad(包含冗余 MMAs)和 CUTLASS 跨步 Dgrad(移除冗余 MMAs)。
* 隐式 Gemm 卷积:构建连贯完整的抽象,高效地组合复杂算法。
* 使用 CUTLASS 完成更多任务:介绍使用 Tensor Cores (3xTF32) 加速单精度算术运算,分组 GEMM,以及 CUTLASS Python 示例。

概述

CUTLASS 是用于深度学习和线性代数的 CUDA C++ 模板库。该库在不同GTC年份和CUDA版本下持续演进,并增加了新功能。

CUTLASS版本和功能时间线
Page 5

CUTLASS 的新进展 (自2021年7月起)

CUTLASS 卷积性能

CUTLASS 2.8 在 NVIDIA A100 (CUDA 11.6) 上混合精度训练 (F16 <- F16 * F16 + F32) 相对于 cuDNN 的性能表现:

CUTLASS 2.8 卷积性能与cuDNN的比较
Page 7

图表显示了 ResNet50 不同层(2到20)的 fpropdgradwgrad 相对于 cuDNN 的加速比。

CUTLASS 跨步反向数据梯度 (DGRAD)

CUTLASS 在处理跨步 Dgrad 方面取得了显著进展。CUTLASS 2.8 的跨步 Dgrad 性能相较于 CUTLASS 2.5 有了巨大提升。

CUTLASS 2.5 vs CUTLASS 2.8 跨步Dgrad性能对比
Page 9

该图比较了 CUTLASS 2.5 和 CUTLASS 2.8 在 NVIDIA GA100 (1005Mhz) 上,使用 TensorOp (F16 <- F16*F16+F32) 进行跨步 Dgrad (步长为 2x2) 相对于 cuDNN 8.3.21 (CUDA 11.6 Toolkit) 的性能。
* CUTLASS 2.5 (黄色条) 的性能远低于 cuDNN,在所有测试层中都表现不佳,最低只有 20% 左右。
* CUTLASS 2.8 (绿色条) 则显示了显著的性能提升,在许多层中超过了 cuDNN,最高达到约 135% 的加速比 (例如层10和层14)。
* 测试的跨步层包括:
* 层 1-6:ResNet50,1x1 滤波器
* 层 7-9:RNXT JoC,3x3 滤波器
* 层 10-15:MaskRCNN MLPERF,1x1 滤波器

跨步 DGRAD (1D 理解)

为了更好地理解跨步 DGRAD,我们首先回顾 1D 卷积的基本概念。

前向传播 | Y = FPROP(X, W)
对于步长为 1,滤波器大小为 3 的 1D 卷积:

1D 步长1,滤波器3的前向传播
Page 11

图示展示了激活 (x) 和滤波器 (w) 如何通过乘积累加操作生成输出 (y) 的过程。每个输出元素 (y) 由激活 (x) 和滤波器 (w) 的乘积累加贡献。

反向数据梯度 | DX = DGRAD(DY, W)
对于步长为 1,滤波器大小为 3 的 1D 卷积的反向数据梯度:

1D 步长1,滤波器3的反向数据梯度
Page 12

图示展示了输出梯度 (Dy) 和滤波器 (w) 如何通过乘积累加操作生成激活梯度 (Dx) 的过程。每个激活梯度元素 (Dx) 由输出梯度 (Dy) 和滤波器 (w) 的乘积累加贡献。

跨步前向传播 | Y = FPROP(X, W)
对于步长为 2,滤波器大小为 3 的 1D 卷积:

1D 步长2,滤波器3的跨步前向传播
Page 13

图示展示了当步长增加到 2 时,激活 (x) 和滤波器 (w) 如何通过乘积累加操作生成输出 (y)。此时,输出 (y) 中的元素间隔更远。

跨步反向数据梯度 | DX = DGRAD(DY, W)
对于步长为 2,滤波器大小为 3 的 1D 卷积的反向数据梯度:

1D 步长2,滤波器3的跨步反向数据梯度
Page 14

图示展示了当步长为 2 时,输出梯度 (Dy) 和滤波器 (w) 如何通过乘积累加操作生成激活梯度 (Dx)。

反向数据梯度 (DGRAD) [跨步] 朴素实现
对于步长为 2,滤波器大小为 3 的 1D 卷积的朴素实现:

1D 步长2,滤波器3的跨步反向数据梯度朴素实现
Page 15

朴素实现中,为了处理跨步情况,会在滤波器贡献中引入零填充 (用 'Z' 表示)。这可能导致冗余的乘积累加操作。每个激活梯度元素 (Dx) 仍由输出梯度 (Dy) 和滤波器 (w) 的乘积累加贡献。

跨步前向传播与跨步反向传播 (STRIDED FPROP VS. STRIDED DGRAD)

1D 卷积(步长 = 2,滤波器 = 3)

本节对比了步长为 2、滤波器大小为 3 的 1D 卷积中的跨步前向传播 (STRIDED FPROP) 和跨步梯度反向传播 (STRIDED DGRAD)。

STRIDED FPROP VS. STRIDED DGRAD
Page 16

其中 CTA 代表 Cooperative Thread Array 或 Thread block。

朴素跨步反向传播与 CUTLASS 跨步反向传播 (NAÏVE STRIDED DGRAD VS CUTLASS STRIDED DGRAD)

1D 卷积(步长 = 2,滤波器 = 3)

本节对比了朴素跨步反向传播(CUTLASS 2.5)和 CUTLASS 优化的跨步反向传播(CUTLASS 2.6)在 1D 卷积(步长 = 2,滤波器 = 3)中的实现。

NAÏVE STRIDED DGRAD VS CUTLASS STRIDED DGRAD
Page 17

CUTLASS 跨步反向传播 (CUTLASS STRIDED DGRAD)

1D 卷积(步长 = 2,滤波器 = 3)

CUTLASS 跨步反向传播方法通过以下三个阶段进行优化:

CUTLASS STRIDED DGRAD
Page 18

CUTLASS 跨步反向传播 [小滤波器] (CUTLASS STRIDED DGRAD [SMALL FILTER])

1D 卷积(步长 = 2,滤波器 = 1)| 泛化(步长 > 滤波器)

本节比较了在小滤波器情况下(步长 = 2,滤波器 = 1)以及步长大于滤波器大小的泛化场景下,朴素跨步反向传播 (CUTLASS 2.5) 和 CUTLASS 跨步反向传播 (CUTLASS 2.6) 的表现。

CUTLASS STRIDED DGRAD [SMALL FILTER]
Page 19

CUTLASS 跨步反向传播 (CUTLASS STRIDED DGRAD)

(步长 > 1)

当步长大于 1 时,CUTLASS 跨步反向传播的伪代码结构如下:

// Prologue
id -> mapped_id // 将 CTA 映射到计算有效的 Dx 元素

// Mainloop
accumulators = 0 // 清零累加器
if (isMainloopRequired(blockIdx.x)) {
    accumulators = Dy * w
}

// Epilogue
id -> mapped_id // 映射回目标 Dx 元素位置
Dx_source = ((beta==0) ? 0 : Dx[mapped_id])
Dx[mapped_id] = alpha * (accumulators) + beta * Dx_source // 存储 Dx

CUTLASS STRIDED DGRAD (Stride > 1)
Page 20

2D 中的跨步反向传播 (STRIDED DGRAD IN 2D)

(隐式 GEMM 卷积)

本部分将探讨 2D 卷积中跨步反向传播的实现,特别是通过隐式 GEMM (General Matrix Multiply) 卷积的方式。

STRIDED DGRAD IN 2D (IMPLICIT GEMM CONVOLUTIONS)
Page 21

4D 张量上的 2D 反向传播梯度 - 定义 (2D DGRAD ON 4D TENSORS - DEFINITION)

反向数据传播 | Dx = CONV (Dy, w)

2D 反向传播梯度(DGRAD)在 4D 张量上的定义如下:

$$ \mathbf{Dx}[n, h, w, c] = \sum_{k=0}^{K-1} \sum_{r=0}^{R-1} \sum_{s=0}^{S-1} (\mathbf{Dy}[n, \bar{p}(h,r), \bar{q}(w,s), k] * \mathbf{w}[k,r,s,c]) $$

其中:

$$ \bar{p}(h,r) = (h + \text{pad\_h} - r * \text{dilation\_h}) / \text{stride\_h} $$

$$ \bar{q}(w,s) = (w + \text{pad\_w} - s * \text{dilation\_w}) / \text{stride\_w} $$

2D DGRAD ON 4D TENSORS - DEFINITION
Page 22

隐式 GEMM 卷积 (IMPLICIT GEMM CONVOLUTION)

反向数据梯度 (Dgrad)

隐式 GEMM 卷积通过将 4D 张量映射到 2D 矩阵来执行 Dgrad 计算:

IMPLICIT GEMM CONVOLUTION
Page 23

朴素跨步反向传播 (隐式 GEMM) - CUTLASS 2.5 (NAÏVE STRIDED DGRAD (IMPLICIT GEMM) - CUTLASS 2.5)

为了覆盖整个 GEMM-K (KRS) 维度:

计算 gemm_k_iterations 的公式为:
gemm_k_iterations = R * S * ((K + Tile_K - 1) / Tile_K)

示例尺寸:
- Tile_M = 128
- Tile_N = 128
- Tile_K = 32
- R-by-S = 3-by-3
- Stride = 2-by-2

NAÏVE STRIDED DGRAD (IMPLICIT GEMM) - CUTLASS 2.5
Page 24

CUTLASS 跨步反向传播 (隐式 GEMM) - CUTLASS 2.6+ (CUTLASS STRIDED DGRAD (IMPLICIT GEMM) - CUTLASS 2.6+)

CUTLASS 2.6+ 版本的隐式 GEMM 跨步反向传播优化如下:

示例尺寸:
- Tile_M = 128
- Tile_N = 128
- Tile_K = 32
- R-by-S = 3-by-3
- Stride = 2-by-2

CUTLASS STRIDED DGRAD (IMPLICIT GEMM) - CUTLASS 2.6+
Page 25

实现细节与算法步骤

本节继续深入探讨隐式 GEMM 卷积的实现细节。

隐式 GEMM 卷积算法的主要步骤如下:

  1. Dyw 矩阵的瓦片加载到共享内存中。
  2. 在共享内存的操作数上计算矩阵乘加 (mma)。
  3. 迭代遍历 KRS 维度。

IMPLICIT GEMM CONVOLUTION
Page 27

加载输出梯度矩阵 (Dy) 和滤波器矩阵 (w) 的瓦片

算法的第一步是加载 Dyw 矩阵的瓦片到共享内存中。

示例:
- Tile_M = 128
- Tile_N = 128
- Tile_K = 32
- 输入类型 = F16

IMPLICIT GEMM CONVOLUTION
Page 28

详细加载过程:

示例尺寸与 Page 28 相同。

IMPLICIT GEMM CONVOLUTION
Page 29

计算矩阵乘积累加

在隐式GEMM卷积中,warp级的矩阵乘积累加(MMA)操作器直接使用 cutlass::gemm::warp 模块来利用NVIDIA Tensor Cores。

隐式GEMM卷积的MMA操作示意图
隐式GEMM卷积的MMA操作示意图

上图展示了示例瓦片大小:Tile_M = 128Tile_N = 128Tile_K = 32。数据从全局内存(Global Memory)通过共享内存(Shared Memory)和寄存器文件(Register Files)流向Tensor Cores,使用 cutlass::gemm::warp::MmaTensorOp 进行处理。

迭代 KRS 维度

隐式GEMM卷积算法包含以下步骤:

  1. 将Dy(输出梯度矩阵)和滤波器w的瓦片加载到共享内存中。
  2. 在共享内存中的操作数上计算矩阵乘积累加(mma)。
  3. 遍历KRS维度:
    a) 推进以在共享内存中加载下一个瓦片。
    b) 确保所有滤波器位置 (r,s) 和输出通道K的累加。

示例瓦片大小:Tile_M = 128Tile_N = 128Tile_K = 32。输入类型为F16。

隐式GEMM卷积迭代KRS维度的示意图
隐式GEMM卷积迭代KRS维度的示意图

上图说明了滤波器矩阵 (w)、输出梯度矩阵 (Dy) 和激活梯度矩阵 (Dx) 的尺寸以及瓦片大小。

为了覆盖整个GEMM-K (KRS) 维度:
通过遍历以下内容,在KRS维度中处理瓦片:
* 滤波器 s 位置
* 滤波器 r 位置
* Tile_K 滤波器 k 元素

启动足够的迭代次数以覆盖所有输出通道元素 (K) 和滤波器位置 (R-by-S):
num_tiled_iterations = R * S * ((K + Tile_K – 1)/Tile_K)

示例大小:Tile_M = 128Tile_N = 128Tile_K = 32R-by-S = 3-by-3Stride = 1-by-1

覆盖GEMM-K维度的迭代策略
覆盖GEMM-K维度的迭代策略

上图展示了如何通过瓦片迭代来覆盖3x3滤波器位置和0..K-1通道元素。

构建连贯完整的抽象

cutlass::conv::threadblock::Iterators
CUTLASS卷积迭代器实现了以下抽象:
* advance(): 移动到GEMM-K中的下一个瓦片迭代。
* operator++(): 移动到线程的下一个加载位置。
* at(): 应用 p (h,r)q (w,s) 函数将 hw 映射到 pq,并返回Dy张量 [n,p,q,k] 中的坐标。
* valid(): 检查全局内存中张量的越界访问。
* get(): 根据张量坐标从全局内存中获取指针。

CUTLASS卷积迭代器的抽象示意图
CUTLASS卷积迭代器的抽象示意图

上图详细说明了这些迭代器函数在遍历输出梯度矩阵 (Dy) 和滤波器矩阵 (w) 时的作用。

高效地组合复杂算法的抽象

幻灯片强调了抽象在高效组合复杂算法中的重要性。

隐式GEMM卷积 - CUTLASS 2.5 组件

CUTLASS 2.5 隐式GEMM卷积的组件包括 cutlass::conv::threadblockcutlass::gemm::warpcutlass::epilogue
数据流经全局内存 -> 共享内存 -> 寄存器文件 / CUDA/Tensor Cores -> SMEM -> CUDA Cores -> 全局内存。
每个阶段都定义了其布局(Layout)、迭代器(IteratorsA/B)和矩阵乘法累加(Mma)操作。

CUTLASS 2.5 隐式GEMM卷积组件架构
CUTLASS 2.5 隐式GEMM卷积组件架构

上图展示了CUTLASS 2.5的组件架构及其数据流,以及一个 cutlass::conv::kernel::ImplicitGemmConvolution 的代码片段。

隐式GEMM卷积 - STRIDED DGRAD (CUTLASS 2.6+ 组件)

CUTLASS 2.6+ 对步进(strided)DGRAD隐式GEMM卷积的组件进行了更新。主要更新体现在 IteratorsA/B 中使用了 PredicatedTileIteratorStridedDgrad,并且内核名称变更为 cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad

CUTLASS 2.6+ 步进DGRAD隐式GEMM卷积组件架构
CUTLASS 2.6+ 步进DGRAD隐式GEMM卷积组件架构

上图展示了CUTLASS 2.6+的更新组件架构和对应的代码片段。

更多地利用 CUTLASS

幻灯片鼓励用户更深入地利用CUTLASS库。

使用 Tensor Cores 加速单精度计算

3xTF32 在A100上以48 TFLOPs加速单精度算术运算:
* 性能比峰值单精度高出2倍以上。
* 精度优于单精度(不符合IEEE标准)。
* 支持实数和复数数据类型。
* GEMM和卷积的示例实现。

3xTF32在NVIDIA A100上的性能和精度表现
3xTF32在NVIDIA A100上的性能和精度表现

上图显示了3xTF32在NVIDIA A100上的性能(TFLOPs)和不同数据类型(实数、复数)的相对误差,与GEMM-K维度的关系。

分组 GEMM (GROUPED GEMM)

具有独特问题大小的批处理GEMM。
CUTLASS支持在NVIDIA Volta、Turing和Ampere架构的Tensor Cores上高效实现。
* “持久化内核”启动足够的线程块以完美填充GPU。
* 外部循环和“调度器”将线程块映射到输出问题的一个瓦片。
* 性能超越批处理GEMM,并且不需要预处理:
* 混合专家(Mixture of Experts)自然语言模型的端到端性能提升1.75倍。
* 对于32到4096之间随机大小的几何平均加速比为1.39倍。

CUTLASS分组GEMM与批处理GEMM的加速比
CUTLASS分组GEMM与批处理GEMM的加速比

上图展示了CUTLASS分组GEMM相对于批处理GEMM在A100 Tensor Cores (F16 * F16 + F32) 上的加速比,与计算强度(flops/byte)的关系。

幻灯片还提供了一个关于使用 problem_visitor 迭代瓦片的CUDA C++代码片段。

CUTLASS Python

使用CUDA Python动态编译CUTLASS。
CUDA Python向Python程序员公开了CUDA驱动API和NVRTC编译器。
CUTLASS使用基于Python的IR来生成设备端的CONV和GEMM操作符。
一个新的主机端运行时组件支持从Python JIT编译和启动CUTLASS GEMM内核。

CUTLASS Python API示例
CUTLASS Python API示例

上图展示了如何使用CUTLASS Python API进行操作清单、构建SGEMM操作、初始化SGEMM对象、打包参数、规划CUDA网格启动以及启动内核的代码示例。
CUTLASS 2.9即将发布。

结论

最终结论总结了CUTLASS的各项优势和进展:

参考文献