Jet-RL: Enabling On-Policy FP8 Reinforcement Learning with Unified Training and Rollout Precision Flow
Jet-RL: Enabling On-Policy FP8 Reinforcement Learning with Unified Training and Rollout Precision Flow
作者/机构: Haocheng Xi (NVIDIA, UC Berkeley), Charlie Ruan (UC Berkeley), Peiyuan Liao (Stanford University), Yujun Lin (NVIDIA), Han Cai (NVIDIA), Yilong Zhao (UC Berkeley), Shuo Yang (UC Berkeley), Kurt Keutzer (UC Berkeley), Song Han (NVIDIA, MIT), Ligeng Zhu (NVIDIA)
A1 主要贡献
本文旨在解决强化学习(RL)训练流程中计算效率低下和资源密集的问题,特别是 rollout 阶段消耗了超过70%的总训练时间。FP8 量化提供了一个有前景的解决方案,通常采用的方法是在 rollout 阶段使用 FP8 精度以缓解瓶颈,同时在训练阶段保留 BF16 精度。
核心问题与研究目标:
本文首次对 FP8 RL 训练进行了全面研究,并指出被广泛采用的 BF16 训练 + FP8 rollout 策略存在严重问题。在长序列 rollout 生成和挑战性任务中,该策略会导致严重的训练不稳定性和灾难性的准确率崩溃。分析表明,这些问题源于该方法的“离策略(off-policy)”性质,它在训练和推理之间引入了显著的数值不匹配。
创新点与主要贡献 (Jet-RL):
为了解决上述局限性,本文提出了 Jet-RL,一个旨在实现稳健和稳定 RL 训练的 FP8 RL 训练框架。
1. 统一的 FP8 精度流:Jet-RL 的核心思想是为训练和 rollout 采用统一的 FP8 精度流。这消除了训练策略与 rollout 策略之间的不匹配(policy mismatch),确保了训练过程的“在策略(on-policy)”性质,从而保证了优化的稳定性,并避免了低效的步间校准(inter-step calibration)。
2. 性能与效率:Jet-RL 在各种模型和任务中实现了显著的 rollout 和端到端加速,同时保持了与 BF16 基线相当的收敛性和准确率。相比于 BF16-train-FP8-rollout 方法通常会带来超过5%的性能下降,Jet-RL 将这一差距缩小到了约1%。同时,Jet-RL 实现了高达33%的 rollout 阶段加速,41%的训练阶段加速,以及16%的端到端训练加速。
图表支持:
- 图1 展示了 Jet-RL 与其他方法在 RL 训练中的差异。Jet-RL 提出的统一精度流旨在同时兼顾性能和吞吐量。
- 图2 表明 Rollout 生成主导了 RL 训练的延迟。当 rollout 长度大于 8k 时,rollout 耗时将超过总延迟的75%,成为主要瓶颈。
- 图3 显示,简单的 BF16 训练 + FP8 Rollout 策略在 rollout 上下文长度增加时会失效。尽管在短序列上性能相似,但当 rollout 长度超过 8k 时,其性能迅速下降。
- 图4 指出,当模型已通过预训练获得强大的任务先验知识时,BF16-train-FP8-rollout 策略没有退化。然而,当应用于更难的推理任务或使用较弱的基础模型进行训练时,BF16 和 FP8 rollout 的训练曲线会迅速分歧。
A3 背景知识
2.1. 量化基础
量化定义。量化是将高精度张量映射到低精度张量的过程,旨在加速计算并减少内存占用。本文考虑将张量 $X$ 量化为目标数据格式,其最大可表示值为 $\Delta_{max}$。量化过程可定义为:$\hat{X}, S_X = Q(X)$,其中 $Q(\cdot)$ 是量化器。
$\hat{X} = \left\lfloor \frac{X}{S_X} \right\rfloor, \quad S_X = \frac{\max(|X|)}{\Delta_{\text{max}}}.$
此处,$\hat{X}$ 是量化后得到的低精度张量,$S_X$ 是缩放因子。本文重点关注使用 E4M3 格式【索引22, Fp8 formats for deep learning, 2022, arXiv】【索引23, Ocp 8-bit floating point specification (ofp8), 2023, Open Compute Project】的 FP8 量化,其最大可表示值为 $\Delta_{max} = 448$。
2.2. 线性层的量化
线性层及其计算。在量化大型语言模型时,线性层是主要的关注对象,因为它们是计算密集型算子。遵循以往的量化训练工作【索引24, Jetfire: Efficient and accurate transformer pretraining with int8 data flow and per-block quantization, 2024, arXiv】,线性层的输入和输出表示如下:
$\boldsymbol{Y} \in \mathbb{R}^{N \times D}, \quad \boldsymbol{X} \in \mathbb{R}^{N \times C}, \quad \boldsymbol{W} \in \mathbb{R}^{D \times C},$
其中,$Y$ 是输出,$X$ 是激活,$W$ 是权重。$N$ 是 token 数量,$D$ 是输出通道数,$C$ 是输入通道数。它们对应的梯度分别表示为 $\nabla Y$, $\nabla X$ 和 $\nabla W$,形状均相同。每个线性层包含三个通用矩阵乘法(GEMM)操作:前向传播(FProp)、计算权重梯度的反向传播(WGrad)以及计算激活梯度的反向传播(DGrad)。它们可以形式化地表示为:
* 前向传播 (FProp)
$\boldsymbol{Y} = \boldsymbol{X} \times \boldsymbol{W}^\top, \quad \text{where} \quad \mathbb{R}^{N \times D} = \mathbb{R}^{N \times C} \times \mathbb{R}^{C \times D}.$
- 反向传播 – 权重梯度 (WGrad)
$\nabla_{\boldsymbol{w}} = \nabla_{\boldsymbol{Y}}^\top \times \boldsymbol{X}, \quad \text{where} \quad \mathbb{R}^{D \times C} = \mathbb{R}^{D \times N} \times \mathbb{R}^{N \times C}.$
- 反向传播 – 输入梯度 (DGrad)
$\nabla_{\boldsymbol{x}} = \nabla_{\boldsymbol{Y}} \times \boldsymbol{W}, \quad \text{where} \quad \mathbb{R}^{N \times C} = \mathbb{R}^{N \times D} \times \mathbb{R}^{D \times C}.$
FP8 GEMM 的布局要求。为了进行 FP8 计算而量化这三个 GEMM 操作有特定的矩阵布局要求,因为当前的 FP8 TensorCore 硬件要求第一个操作数为行主序(Row-wise, R)存储,第二个操作数为列主序(Column-wise, C)存储。具体的布局要求总结在表1中。
Table 1 | 线性层中 FP8 GEMM 的布局要求。
2.3. 强化学习的工作负载
RL 训练中的四种模型。一个标准的强化学习(RL)训练流程,例如基于近端策略优化(PPO)【索引25, Proximal policy optimization algorithms, 2017】,通常包含四种不同的模型。Actor 模型是我们旨在训练的主要大型语言模型(LLM)。Reference 模型通常是 actor 的一个初始副本,用于计算 KL 散度,以正则化 actor 的更新并确保训练稳定性。Reward 模型提供标量奖励信号,通常通过学习来反映人类偏好或任务目标。最后,Critic 模型估计生成响应的价值(或质量),预测预期的未来奖励。
RL 训练的三个阶段。每个 RL 训练步骤可以分解为三个不同的阶段,每个阶段都是由三种基本的 LLM 工作负载复合而成:解码(decode)、预填充(prefill)和训练(training)。首先,在 Rollout 阶段,Actor 模型执行自回归解码过程,为给定的提示生成一个或多个响应。我们使用术语“rollout”来区分此生成阶段与简单的推理。其次,在 Evaluation 阶段,生成的响应被输入到其他模型中。Reference、Reward 和 Critic 模型各自执行一次前向传播,这构成了一个预填充工作负载,以计算它们各自的输出。第三,在 Update 阶段,这些指标被收集起来(通常通过广义优势估计(GAE)【索引26, Highdimensional continuous control using generalized advantage estimation, 2018】处理),然后 Actor 模型执行一个训练步骤,包括一次前向和一次反向传播,以更新其权重。
实现策略与瓶颈。这些阶段具有不同的计算特性,因此通常采用两种不同类型的系统来实现。Rollout 阶段通常由优化的推理引擎如 vLLM【索引27, Efficient Memory Management for Large Language Model Serving with PagedAttention, 2023】或 SGLang【索引28, Sglang: Efficient execution of structured language model programs, 2023】管理,因为它们对自回归工作负载进行了高度优化。相反,Evaluation 和 Update 阶段由 FSDP【索引29, Pytorch fsdp: Experiences on scaling fully sharded data parallel, 2023, Proc. VLDB Endow.】、MegatronLM【索引30, Megatronlm: Training multi-billion parameter language models using model parallelism, 2019, ArXiv】或 DeepSpeed【索引31, Deepspeed: System optimizations enable training deep learning models with over 100 billion parameters, 2020】等训练框架处理,这些框架在并行性方面提供了更大的灵活性。为了保持 RL 训练的“在策略(on-policy)”性质,更新后的 actor 权重必须在每个步骤后从训练框架传输到推理引擎。
Rollout 阶段是性能瓶颈。尽管推理引擎进行了优化,rollout 阶段仍然是一个关键的性能瓶颈。正如我们在 § 3.1 中所示,rollout 阶段的延迟随着生成响应的长度而扩展。因此,这个阶段逐渐主导了端到端的训练延迟,使其成为整个 RL 训练流程中最昂贵的部分。
A3 关键观察与动机
3.1. Rollout 是 RL 训练的瓶颈
Rollout 耗时占比分析。在图2中,我们分析了 RL 训练中每个组件的性能,并使用 Qwen3-8B-Base【索引32, Qwen3 technical report, 2025, arXiv】模型在 GSM8K【索引33, Training verifiers to solve math word problems, 2021, ArXiv】和 MATH【索引34, Measuring mathematical problem solving with the math dataset, 2021】数据集上进行了实验。为了理解训练时间如何随 rollout 长度扩展,我们将最大 rollout 长度从 1K 变化到 16K,并测量每个训练组件的延迟比例变化。我们观察到,当 rollout 长度超过 8K 时,仅 rollout 生成就占了总训练时间的70%以上。这一趋势与先前研究【索引9, Dapo: An open-source llm reinforcement learning system at scale, 2025】中报告的观察结果一致,凸显了加速 rollout 的迫切需求。
FP8 量化作为解决方案。FP8 量化是加速 rollout 的一个有前景的解决方案,原因有二。首先,它相对于 BF16 能提供理想的2倍加速。其次,大量研究表明 FP8 推理不会降低下游任务的性能。此外,它可以轻松集成到现有的 RL 训练流程中。现有研究提出,只需将 BF16 权重转换为 FP8,同时在训练阶段保留 BF16。在本文的其余部分,我们将这种简单策略称为 BF16-train-FP8-rollout,并讨论这种设计的局限性。
3.2. 带校准的 BF16-Train-FP8-Rollout 速度慢
校准开销问题。虽然大多数训练后量化(PTQ)方法【索引35, Awq: Activationaware weight quantization for on-device llm compression and acceleration, 2023, GetMobile Mob. Comput. Commun.】【索引36, Gptq: Accurate post-training quantization for generative pre-trained transformers, 2022, ArXiv】【索引37, Smoothquant: Accurate and efficient post-training quantization for large language models, 2022, ArXiv】是为离线部署设计的,但 RL 训练需要在训练 actor 和 rollout actor 之间频繁进行权重同步。昂贵且依赖数据的校准过程,即使对于小型的 8B LLM 也可能需要数十分钟,因此如果每一步同步都重复进行,开销将无法承受,这与加速的目标相冲突。一些框架如 SLIME【索引38, slime: An llm post-training framework for rl scaling, 2025】和 NeMo-RL【索引39, Fp8 accuracy nemorl documentation, 2024】提出直接将 BF16 权重转换为 FP8 而不进行校准,并声称准确性不受影响。然而,我们发现这个结论实际上非常脆弱。
3.3. 不带校准的 BF16-Train-FP8-Rollout 不稳定
破坏“在策略”假设。RL 训练的质量高度依赖于“在策略(on-policy)”假设,即智能体从遵循其当前策略收集的数据和经验中学习——在 LLM 训练中,这要求 rollout 和训练在给定相同提示时应产生相同的 logits。不满足这种一致性的 RL 训练在极端情况下可能导致发散【索引40, Defeating nondeterminism in llm inference, 2025】【索引41, Understanding and mitigating numerical sources of nondeterminism in llm inference, 2025】【索引42, Towards deterministic inference in sglang and reproducible rl training, 2025, LMSYS Org Blog】。直接将 rollout actor 量化为 FP8,同时在 BF16 中更新全精度 actor,显然破坏了这种一致性,并引入了显著的策略不匹配【索引43, Your efficient rl framework secretly brings you off-policy rl training, 2025】【索引44, Prosperity before collapse: How far can off-policy rl reach with stale data on llms?, 2025】【索引45, Stabilizing off-policy deep reinforcement learning from pixels, 2022】。
结论的脆弱性。尽管在某些情况和一些研究(如截断重要性采样 Truncated Importance Sampling (TIS))中,BF16-train-FP8-rollout 并未导致太多性能下降,但我们发现这一结论非常脆弱,并依赖于特定的数据集和模型设置。具体来说,我们观察到 BF16-train-FP8-rollout 在两种情况下更容易失败:长序列 rollout 生成和挑战性任务。
在长序列 rollout 生成中失败。我们观察到 BF16-train-FP8-rollout 的性能与 rollout 长度密切相关。我们在 MATH 数据集上使用 GRPO 训练和评估 Qwen2.5-7B【索引46, Qwen2.5 technical report, 2025】,将生成长度从 4K 变化到 16K,然后比较 BF16 训练和 BF16-train-FP8-rollout 的性能。如图3所示,虽然当 rollout 长度较小(<4K)时,BF16-train-FP8-rollout 的性能与 BF16 训练相当,但当 rollout 长度扩展到 8K 时,FP8 rollout 迅速与 BF16 训练产生分歧。当我们进一步将 rollout 长度增加到 16K 时,FP8 rollout 的准确率在仅仅训练20步后就崩溃了。
失败原因的假设(长序列)。我们推测这种性能下降源于在每个解码步骤中,rollout 和训练分布之间的差异累积。对于较短的生成长度,这些差异相对较小,甚至可能对训练有部分益处。然而,随着 rollout 长度的增加,rollout 和训练分布之间的累积差异变得巨大。这加剧了 FP8 rollout 下强化学习的“离策略”问题,导致显著的不稳定性和训练性能下降。
在挑战性任务中失败。我们观察到,当基础模型在目标任务上能力不强时,BF16-train-FP8-rollout 往往会失败。例如,如图4所示,在 GSM8K 上训练 Qwen3-8B 时,BF16-train-FP8-rollout 的训练曲线与 BF16 训练曲线紧密匹配,甚至收敛更快。然而,切换到 Qwen3-8B-Base 模型后——该模型未经过指令微调,需要同时学习推理模式和解题能力——BF16-train-FP8-rollout 很快就落后于 BF16 基线,并导致性能较差。在这种设置下,评估准确率在仅仅训练20步后就与 BF16 基线产生分歧。
失败原因的假设(挑战性任务)。我们推测,当模型强大且任务相对简单时,BF16-train-FP8-rollout 表现良好。在这种情况下,模型对其响应表现出高置信度,使其对 FP8 量化带来的微小数值扰动不那么敏感。相反,随着任务难度的增加和模型置信度的降低,量化引起的误差会显著扭曲 rollout 轨迹。这导致 FP8 rollout 和 BF16 训练之间的不匹配日益加剧,从而导致不稳定的优化和性能下降。本质上,BF16-train-FP8-rollout 仅对较容易的任务稳健,但在扩展到更难的任务时会面临收敛问题。考虑到 RL 的目标是帮助模型获得它们尚不具备的能力,BF16-train-FP8-rollout 不太可能足够,因为它在更难任务上的不稳定性阻碍了有效的学习和可扩展性。
核心问题。这些观察引出了一个关键问题:我们如何才能减轻 BF16-train-FP8-rollout 中训练与推理的差异,以在不同设置下实现具有竞争力和稳定性的性能?
A2 方法细节
4. Jet-RL:实现“在策略”的 FP8 RL 训练
问题根源与解决方案。我们发现 BF16-train-FP8-rollout 训练失败的根本原因是训练和 rollout 之间的精度流不一致,这使得 RL 训练实际上是“离策略”的。鉴于在 RL 中保持“在策略”一致性的至关重要性【索引47, Flashrl: 8bit rollouts, full power rl, 2025】,解决这个问题至关重要。为了缓解此问题,我们提出在训练和 rollout 之间强制执行一个统一的 FP8 精度流,确保 rollout 中的响应与训练一致,从而使其成为一个“在策略”过程。
精度传播的图模型。我们形式化地将量化精度在模型中的传播建模为一个有向图 $G = (V, \mathcal{E})$,如图5所示。一个节点 $v_i \in V$ 表示模型中的一个算子或权重,我们称之为算子节点或权重节点。当 $v$ 的输出作为 $v'$ 的输入时,连接一条有向边 $(v, v') \in \mathcal{E}$,以描述两个连接算子之间的张量传播。一条边表示了传输张量的精度以及张量的量化粒度(如果被量化)。
训练图与推理图。在训练中,我们有图 $G_{train}$,它可以进一步分为两个子图 $G_{fwd_{train}}$ 和 $G_{bwd_{train}}$。在 $G_{fwd_{train}}$ 中,边代表激活值;而在 $G_{bwd_{train}}$ 中,边代表梯度。它们共享相同的节点和拓扑结构,因为它们有相同的算子套件,但边的方向是相反的。这两个图通过代表为反向传播保存的激活值的边连接起来。对于推理引擎,我们有图 $G_{infer}$,在使用 BF16 推理时,其拓扑结构也与 $G_{fwd_{train}}$ 相同。在 FP8 rollout 场景中,权重节点及其边为 FP8 精度。
BF16-train-FP8-rollout 的不匹配问题。对于 BF16-train-FP8-rollout,训练图中的所有权重和激活都是 BF16 格式。然而,在其推理图 $G_{infer}$ 中,所有输入到线性层的边都被量化为 FP8。这使得 BF16-train-FP8-rollout 有两个不同的前向传播量化图:$G_{fwd_{train}}$ 和 $G_{infer}$。这导致了训练和 rollout 之间的不匹配,使得前向传播过程偏离了 actor 在 rollout 期间实际会生成的内容。因此,RL 训练是“离策略”的、不稳定的,并且不太可能取得满意的结果。
4.1. 训练和 Rollout 之间统一的 FP8 精度流
Jet-RL 的核心设计。在 Jet-RL 中,我们提出通过强制 $G_{infer}$ 成为 $G_{fwd_{train}}$ 的一个子图来解决这个问题。边的所有其他属性(精度和粒度)都保持相同。唯一的区别是 $G_{fwd_{train}}$ 有一个更高精度的主副本以稳定训练,因为 $G_{train}$ 的主权重需要以 BF16 存储。我们确保训练和推理框架的前向传播共享一致的量化行为,从而缓解了在先前方法中观察到的精度传播不匹配问题。这在图5中得到了展示。
反向传播中的精度处理。考虑前向传播中的一条边 $(v, v')$,其中张量为 FP8 精度,且 $v'$ 是一个 GEMM 算子。$v'$ 算子只能访问 FP8 张量,因为量化通常与前一步骤融合。因此,我们选择将为反向传播保存的激活值也以 FP8 精度存储。在预训练和监督微调的先前工作中,这种策略已被证明可以保持训练稳定性【索引14, Coat: Compressing optimizer states and activation for memory-efficient fp8 training, 2024, arXiv】。
梯度精度。我们保留在反向传播过程中算子之间传输的梯度为 BF16 精度,以保持模型准确性。尽管量化它们可以进一步减少通信开销,但它们常常会引入梯度下溢或量化噪声,从而降低收敛性。反向传播中的 GEMM(DGrad 和 WGrad)也为了加速而被量化为 FP8 精度。我们将在下一节中详细阐述。
4.2. GEMM 量化的粒度
量化方案概览。我们量化了训练和推理中线性层的所有 GEMM 算子以加速计算。如图6所示,前向传播中的 FProp 算子以及反向传播中的 WGrad 和 DGrad 算子都以 FP8 张量作为输入,输出 BF16 张量。在本节中,我们讨论为这些 FP8 GEMM 采用的量化方案。
采用细粒度量化。FP8 的逐张量(per-tensor)量化已被证明在训练大型语言模型时是不稳定的。因此,在将激活、权重和梯度量化为 FP8 精度时,我们采用了更细粒度的量化方案。具体来说,我们使用 128 × 128 的逐块(per-block)量化来量化权重,并使用 1 × 128 的逐组(per-group)量化来量化激活和梯度。下面我们讨论每个 GEMM 的策略。
FProp 算子。对于 FProp 算子,输入激活使用 1 × 128 的逐组量化,而权重使用 128 × 128 的逐块量化。如 § 2.2 所述,硬件核需要是行主序 × 列主序,因此激活和权重都以行主序布局存储。激活的量化可以与其前一个算子融合以减少开销,而权重的量化需要在训练期间显式进行。对于推理中的权重 量化,我们在参数更新阶段进行权重 量化,这会产生可忽略的开销,并且与张量并行完全兼容。
DGrad 和 WGrad 算子。对于 DGrad 算子,其工作负载等同于 FProp 算子,即一个 1 × 128 量化矩阵乘以一个 128 × 128 量化矩阵。因此,它可以直接重用相同的核配置。对于 WGrad 算子,我们遵循 DeepSeek-V3 的做法,将第一个矩阵以 1 × 128 进行量化,而将第二个矩阵以 128 × 1 进行量化。这种更细粒度的设计有助于稳定训练。
反向传播中的量化细节。这两个算子都需要量化的梯度,但一个要求是 1 × 128 量化,另一个要求是 128 × 1 量化,因此我们融合了这些量化过程。由于权重的量化方案在通道和行轴上是对称的,其值在反向传播中不会改变,所以我们只需要在反向传播中执行一次转置。对于激活,在前向传播中它们被量化为 1 × 128,但在反向传播中,它们需要被量化为 128 × 1。这种差异迫使我们在反向传播中再次对它们进行量化。这对量化训练甚至是有益的。
4.3. 实现
技术栈。我们使用 vLLM【索引27, Efficient Memory Management for Large Language Model Serving with PagedAttention, 2023】作为推理引擎,VeRL【索引16, Hybridflow: A flexible and efficient rlhf framework, 2024, arXiv】作为我们的 RL 训练框架。对于量化的 GEMM,我们参考了 DeepGEMM【索引21, Deepgemm: Clean and efficient fp8 gemm kernels with fine-grained scaling, 2025, GitHub repository】中的核函数。我们使用 Triton【索引48, Triton: an intermediate language and compiler for tiled neural network computations, 2019】来实现量化、转置以及融合的激活或 RMSNorm 核函数。
A4 实验环境
-
模型:
- Llama3.1-8B, Qwen2.5-7B, Qwen3-8B-Base
- Rollout 长度设置为 8K 和 16K。
-
数据集:
- GSM8K + MATH 混合数据集: GSM8K【索引33, Training verifiers to solve math word problems, 2021, ArXiv】包含8,500个小学数学应用题,MATH【索引34, Measuring mathematical problem solving with the math dataset, 2021】包含12,500个复杂的数学竞赛题。此设置下,rollout 生成数量为4。
- DeepMATH 数据集: DeepMATH【索引49, Deepmath-103k: A large-scale, challenging, decontaminated, and verifiable mathematical dataset for advancing reasoning, 2025】包含103,000个高难度数学问题。此设置下,rollout 生成数量为16。
-
硬件配置:
- NVIDIA H100 GPUs。
-
软件配置与超参数:
- RL 训练框架: VeRL
- 推理引擎: vLLM
- 量化库: DeepGEMM, Triton
- 学习率 (Learning Rate): 10⁻⁶
- 批次大小 (Batch Size): 256
- KL 损失系数: 10⁻³
- 评估频率: 每5个训练步骤进行一次检查点评估。
A4 实验结果
5.2. 准确率评估
8K Rollout 长度下的性能。如表2所示,当在 GSM8K + MATH 数据集上使用 8K rollout 长度训练模型时,BF16-train-FP8-rollout 方法表现出显著的不稳定性。最值得注意的是,它在 Qwen2.5-7B 模型上完全无法收敛。在能够收敛的模型上,与 BF16 基线相比,它也带来了巨大的性能下降。例如,在 Llama3.1-8B 上,平均得分下降了 9.8%,在 Qwen3-8B-Base 上下降了 2.9%。相比之下,我们的 Jet-RL 方法证明是稳健且有效的。它不仅在所有场景中都能收敛,而且大大缩小了与 BF16 训练的差距。在 Llama3.1-8B 上,Jet-RL 甚至比 BF16 基线高出 2.0%。在 Qwen2.5-7B 上,性能下降仅为 1.0%(56.9% vs 55.9%),在 Qwen3-8B-Base 上仅下降 1.1%(63.8% vs 62.7%)。
Table 2 | 当 rollout 长度设置为 8k 时,Jet-RL 和基线的性能比较。在所有模型设置中,Jet-RL 极大地减少了 BF16-train-FP8-rollout 引入的性能下降,并实现了接近 BF16 RL 训练的性能。绿色(红色)微数字表示在每个模型块内相对于 BF16 的绝对增益(下降);空白表示基线未收敛。
更具挑战性配置下的性能。如表3所示,我们在更具挑战性的配置下分析了性能,包括 16K rollout 长度和 DeepMATH 数据集。这进一步凸显了 BF16-train-FP8-rollout 方法的不稳定性,特别是在更具挑战性的配置下,例如 16K rollout 长度或不同的训练数据集(DeepMATH)。BF16-train-FP8-rollout 方法在 Qwen3-8B-Base 模型上使用 16K rollout 无法收敛。在 Qwen3-8B-Base (DeepMATH) 实验中,与 BF16 基线相比,它遭受了 10.3% 的严重性能下降。在 Qwen2.5-7B (16K) 模型上,它也显示出 5.0% 的显著下降。相比之下,Jet-RL 成功解决了这些问题。它在 Qwen3-8B 模型上收敛,将差距缩小到仅 2.7%。在 Qwen3-8B-Base (DeepMATH) 实验中,Jet-RL 将差距缩小到 0.9%(54.6% vs 53.7%)。在 Qwen2.5-7B 模型上,性能下降也从 5.0% 减少到 3.0%。结果表明,与 BF16-train-FP8-rollout 相比,Jet-RL 框架实现了更稳定的收敛和更好的性能,并且 Jet-RL 的性能与 BF16 基线紧密对齐。
Table 3 | 当 rollout 长度为 16k 或在 DeepMATH 上训练时,Jet-RL 和基线的性能比较。在所有模型和数据集的设置中,Jet-RL 极大地减少了由 BF16-trainFP8-rollout 引入的性能下降,并实现了接近 BF16 RL 训练的性能。
5.3. 效率评估
Rollout 效率。我们通过在 vLLM【索引27, Efficient Memory Management for Large Language Model Serving with PagedAttention, 2023】中进行离线生成基准测试来量化 FP8 带来的 rollout 效率增益,报告了 FP8 带来的吞吐量加速。我们改变了模型大小从 8B 到 32B,并在多种张量并行设置下进行了测试。如表4所总结,FP8 实现了持续的加速,其加速比 BF16 高出 1.07× 到 1.33× 不等。
Table 4 | 在不同输出长度下,FP8 相对于 BF16 在 rollout 中的加速比(tokens/s)。使用了 512 个提示,最大并发请求为 128,输入长度为 512。在 H100s 上测量。
加速比趋势分析。我们观察到加速比有两个趋势。首先,FP8 量化的好处随着模型大小的增加而增加。较大的模型,如 32B 配置,实现了最显著的加速,高达 1.33 倍,因为它更具计算密集性。这使得优化的 FP8 张量核(例如 DeepGEMM)能够更有效地加速计算。在较小的模型(例如 8B)中,内存访问开销占延迟的较大部分,限制了 FP8 推理的整体效益。其次,更高程度的张量并行(TP)降低了观察到的加速比。当模型分布在更多 GPU 上时,通信开销变得更加突出。例如,TP=4 的 32B 模型仅显示出 1.1 倍的改进,而 TP=2 时为 1.3 倍。这一趋势表明,在较低的张量并行度下,可以加速 FP8 rollout。
端到端效率。对于在 Qwen3-8B 上进行 8K rollout 长度的端到端 RL 训练,FP8 量化在多个计算阶段提供了持续的加速。具体来说,FP8 在 actor 更新阶段实现了 1.54× 的加速,在 reference 模型推理中实现了 1.80× 的加速,这共同促成了训练阶段吞吐量 1.41× 的整体提升。结合 rollout 的加速,这使得 8B 模型训练的端到端单步时间加速达到 1.16×。我们预计对于更大的模型尺寸,加速效果会更加显著。由于资源限制,对 14–32B 模型的全面扩展研究留待未来工作。
A7 补充细节
6.1. LLM 的低精度训练与推理
低精度技术的动机。大型语言模型参数数量的指数级增长使得全精度训练和推理变得极其昂贵,在内存、计算和能耗方面造成了巨大障碍【索引52, Jet-nemotron: Efficient language model with post neural architecture search, 2025, arXiv】【索引53, Hymba: A hybrid-head architecture for small language models, 2024, arXiv】【索引54, Nemotron-h: A family of accurate and efficient hybrid mamba-transformer models, 2025, arXiv】。这一根本性挑战催生了低精度技术的发展,这些技术将模型权重、激活和梯度的数值精度降低到较低的格式,以利用 NVIDIA Tensor Cores 等硬件加速器,这些加速器为低精度算术提供了显著更高的吞吐量。
训练后量化(PTQ)。训练后量化(PTQ)已成为部署 LLM 的主流方法,它以“一次性”的方式压缩预训练模型而无需微调。SmoothQuant【索引37, Smoothquant: Accurate and efficient post-training quantization for large language models, 2022, ArXiv】通过将量化难度从激活迁移到权重来解决激活离群值问题。GPTQ 将 1750 亿参数的模型压缩到 3 或 4 位,对困惑度的影响微乎其微。AWQ 通过在激活感知、无需重构的方法中应用逐通道缩放因子来保护显著权重。
全量化训练(FQT)。另一方面,全量化训练(FQT)通过在低精度下执行计算来加速 LLM 的训练阶段。SwitchBack 和 Jetfire【索引24, Jetfire: Efficient and accurate transformer pretraining with int8 data flow and per-block quantization, 2024, arXiv】通过 INT8 精度流来优化内存访问和逐块量化方法来保持准确性,为 Transformer 提出了先进的 INT8 训练方法。NVIDIA Transformer Engine【索引55, Transformer engine: A library for accelerating transformer models on nvidia gpus, 2025】使用 FP8 加速 Transformer 模型训练。在此基础上,像 COAT 这样的框架将 FP8 量化扩展到线性层之外,包括优化器状态和激活。与此同时,像 QLoRA 这样的方法专注于使用 4 位量化冻结的预训练模型和低秩自适应(LoRA)进行内存高效的微调。
6.2. 大型语言模型的强化学习
RL 在 LLM 中的演变。强化学习(RL)在大型语言模型(LLM)的演变中扮演了核心角色,从对齐发展到推理。早期的对齐工作采用了来自人类反馈的强化学习(RLHF),其中在人类偏好数据上训练的奖励模型通过近端策略优化(PPO)等算法指导策略优化。后续的方法,如直接偏好优化(DPO),去除了显式的奖励模型,实现了更简单、更稳定的基于偏好的微调。
面向推理的 RL。近期的努力已转向面向推理的 RL,其中模型输出的正确性可以自动验证。这一转变推动了为推理任务量身定制的新 RL 算法的开发。例如,DeepSeek-R1 模型是使用 GRPO(组相对策略优化)开发的,这是一种无评论家(critic-free)的 RL 算法,通过比较一个响应的奖励与从同一提示生成的响应组的平均奖励来估计优势。GSPO 被引入以通过在序列级别执行优化来增强训练稳定性。DAPO 引入了几种技术,包括“Clip-Higher”以避免熵崩溃和“Dynamic Sampling”以过滤掉无信息的训练样本,以获得更好的稳定性和效率。
6.3. 大型语言模型的高效推理
RLVR 的计算密集性。虽然 RLVR 是训练推理模型的强大范式,但由于需要生成长的自回归 rollout 以及 RL 的“在策略”性质,它在计算上是密集的,可能导致硬件利用率低。这推动了对系统级和算法级优化的研究,以使推理更高效。
系统级优化。从系统角度看,已经提出了像 AReaL 这样的异步 RL 框架,以打破 rollout 和训练之间的同步依赖。这些异步系统允许 rollout 工作者持续生成新数据,而训练工作者则更新模型,从而提高了 GPU 利用率并显著增加了训练吞吐量。ReaLHF 专注于改进 RL 训练的并行策略。通过自动化并行化搜索过程,ReaLHF 提高了 GPU 利用率并减少了 GPU 空闲时间。
算法级优化。从算法角度看,一个关键挑战是“过度思考”现象,即模型生成过长和冗余的推理轨迹。NoThinking【索引56, Reasoning models can be effective without thinking, 2025】提出修剪推理轨迹,并通过识别冗余步骤和生成鼓励更早终止的训练信号,教导模型自我调节其推理过程。
协同设计。最后,协同设计为解决这些效率挑战提供了另一个角度。QeRL 框架【索引57, Qerl: Beyond efficiency – quantization-enhanced reinforcement learning for llms, 2025】通过将 NVFP4 量化与低秩自适应(LoRA)相结合,解决了 RL rollout 的高成本问题。这种协同作用极大地减少了内存占用并加速了生成阶段,使得在单个 GPU 上高效训练 32B 模型成为可能。截断重要性采样(TIS)【索引47, Flashrl: 8bit rollouts, full power rl, 2025】提出通过在模型更新中添加重要性比率并在推理概率较小时截断它来缓解“离策略”问题。
A5 结论
本文解决了 RL 训练中 rollout 阶段的关键性能瓶颈。我们证明了朴素的 BF16 训练 + FP8 rollout 策略存在根本性缺陷,因为它引入了训练与 rollout 的不匹配,导致训练不稳定和灾难性的性能崩溃。为了解决这个问题,我们提出了 Jet-RL,一个通过为训练前向传播和推理 rollout 阶段采用相同的 FP8 精度流,从而实现稳健的“在策略”FP8 RL 训练的框架。我们的综合评估表明,Jet-RL 在所有模型和基准测试设置中都能稳健收敛。我们的方法还保持了与 BF16 RL 基线相当的竞争性能,通常性能下降不到 1%。通过提供高达 1.33 倍的 rollout 阶段加速、高达 1.41 倍的训练阶段加速以及 1.16 倍的端到端加速,且不牺牲模型准确性,Jet-RL 为应用 FP8 计算来加速大规模 RL 训练建立了一条可靠而高效的前进道路。
💬 评论讨论
欢迎在这里分享您的想法和见解!