Xin Yao, DevTech | AI Open Day/May 30, 2025
加速计算密集型操作
减少内存占用
加速通信
缩小训练与推理之间的差距
许多工作可以被复用
来自业界的成功案例
开放计算项目 8 位浮点规范 (Open Compute Project 8-bit Floating Point Specification)*
E4M3
torch.float8_e4m3fnE5M2
torch.float8_e5m2下图展示了标准的 BF16 混合精度训练流程,其中权重更新在 FP32 中进行以保持精度。
在 BF16 混合精度训练的基础上引入 FP8,以进一步加速计算密集型的 GEMM 操作。
缩放粒度 (Scaling granularity)
缩放方法 (Scaling method)
s * x 的操作将张量的值范围“移动”到与 E4M3 可表示范围有更好重叠的区域。scale = FP8_MAX / amax(window)Xq = static_cast<fp8_type>(x * scale)X' = Xq / scale什么是 FP8 方案?
使用哪种 FP8 格式?
使用何种缩放粒度?
模型的哪一部分使用 FP8?
模型的 FP8 部分:
当前/动态/在线 (Current/Dynamic/Online/Live) 缩放:
amax 值。延迟 (Delayed) 缩放:
amax 值(近似值)。TE (v2.2 及更高版本)
with fp8_autocast(fp8_recipe=Float8CurrentScaling()):
model()
--fp8-format hybrid
--fp8-recipe tensorwise
公式:
Y[m, n] = W[n, k] @ X[m, k]dX[m, k] = W^T[k, n] @ dY[m, n]dW[n, k] = X^T[k, m] @ dY^T[n, m]流程:
公式:
Y[m, n] = W[n, k] @ X[m, k]dX[m, k] = W^T[k, n] @ dY[m, n]dW[n, k] = X^T[k, m] @ dY^T[n, m]流程:
缩放粒度: 按分块/组/块 (Per-tile/group/block)。
FP8 格式: E4M3。
模型的 FP8 部分:
TE (v2.3 及更高版本)
with fp8_autocast(fp8_recipe=Float8BlockScaling()):
model()
--fp8-format e4m3 --fp8-recipe blockwise
公式:
Y[m, n][BF16] = W[n, k][128x128] @ X[m, k][1x128]dX[m, k][BF16] = W^T[k, n][128x128] @ dY[m, n][1x128]dW[n, k][FP32] = X^T[k, m][1x128] @ dY^T[n, m][1x128]流程:
MXFP8 是 OCP (Open Compute Project) 的微缩放格式 (Microscaling Formats) 之一,专为 Blackwell 架构设计。
模型中采用 FP8 的部分:
Transformer Engine (TE) 集成 (v2.0 及更高版本):
通过 fp8_autocast(fp8_recipe=MXFP8BlockScaling()) 与模型进行集成。
--fp8-format e4m3--fp8-recipe mxfp8下图展示了 MXFP8 的数据结构,其中 k 个标量元素共享一个缩放因子 X。每个元素 P_i 有 d 个比特。表格总结了 MXFP8 的格式细节。
在仅支持 Blackwell 架构的 MXFP8 训练中,激活值和权重在前向和后向传播中以不同方式进行量化。这要求数据同时支持行优先(rowwise)和列优先(colwise)的格式。
下图详细描述了 MXFP8 在前向传播、后向传播和优化器更新过程中的数据流:
为了优化内存使用,可以调整 FP8 的存储策略。
一个核心问题是:我们能否只保留 FP8 权重,以进一步节省内存?
下图展示了此存储策略下的数据流,其中虚线框部分不被存储。
仅存储 FP8 主权重会带来一些挑战,尤其是在与 MCore 的分布式优化器(如 ZeRO-1)和分布式检查点(checkpoint)协同工作时。
以下是跨 DP Rank(数据并行等级)同步权重的流程:
下图详细描绘了在 DP=2,Rank 0 和 Rank 1 之间的这个同步过程。
在获取全局 amax 值之后,流程继续进行:
此流程需要为每种 FP8 配方(recipe)进行单独的实现。在 MCore (v0.12 及更高版本) 中,可以通过 --fp8-param-gather 参数来启用此功能。
下图展示了从全局 amax 到最终 FP8 模型权重的转换和分发过程。
下表总结了不同训练精度和存储策略下的内存占用情况,此分析不包括激活值所占用的内存,这部分内存在 MoE 模型中尤为重要。
store_param_remainders 默认开启,FP32 主权重可以与模型权重共享前 16 个比特。因此,主权重只需要额外的 16 比特。从表中可以看出,使用 FP8 主权重(FP8 Primary Weights)可以显著减少内存占用,尤其是在 ZeRO-1 优化下(对比 BF16 的 6+10/d,FP8 Primary Weights 在 Blackwell 上可以达到 5+12/d)。
DP (数据并行):
TP (张量并行):
下图对比了 split_overlap_rs 和 atomic_gemm_overlap_rs (atomic) 两种方式。atomic_gemm 通过将多个小 GEMM 操作合并为一个原子操作,并使用计数器进行同步,减少了计算流(compute stream)和通信流(comm stream)之间的等待事件,从而改善了计算和通信的重叠效率。
下图展示了 FP8 A2A(All-to-All)的数据流:
Quantize into FP8 → FP8 A2A → FP8 Permute → [GroupedGEMM, BF16 output]
Quantize into FP8 → FP8 A2A → FP8 PermuteGroupedGEMM, BF16 output(FP8 Permute Kernels,GroupedLinear 的 FP8 输入)下图展示了在不同硬件平台(GB200, B200, H100)上,使用 FP8-CS (常规缩放) 和 FP8-MX (微缩放) 相对于 BF16 的训练加速比。
下表展示了 DeepSeek-V3 训练的详细配置和性能数据:
未来的性能优化方向包括:
下图的性能剖析图显示,在路由(router)层有“过多的 kernels”(Too many kernels),导致了显著的 TP2 开销和 TE 开销,凸显了 Kernel 融合的必要性。
下图比较了 BF16 和 FP8 训练的损失曲线。结果使用系数为 0.9 的指数移动平均(EMA)进行了平滑处理。从图中可以看出,FP8 的损失曲线与 BF16 高度吻合,表明其具有良好的收敛性。
图表来源:DeepSeek-V3 技术报告。
下图展示了在 MMLU 和平均常识理解任务上,FP8 与 BF16 的训练收敛情况对比。结果显示,两种精度格式下的模型准确率曲线非常接近。
我们的经验
没有统计标准可以表明收敛。
下图展示了在 MMLU 和平均常识理解任务上,BF16 与 FP8 的进一步比较,图中还区分了 "Per-tensor current scaling" 和 "Blockwise scaling" 两种缩放方式。实验结果表明,FP8 和 BF16 的性能表现依然非常接近。