AREAL-DTA: Dynamic Tree Attention for Efficient Reinforcement Learning of Large Language Models

A1 主要贡献

基于强化学习(RL)的大型语言模型(LLM)后训练计算成本高昂,因为它会生成大量可能频繁共享长token前缀的rollout序列。现有RL框架通常独立处理这些序列,在策略模型训练的前向和后向传播过程中重复计算相同的前缀,导致计算和内存使用效率低下。尽管前缀共享自然地在rollout上形成了树状结构,但先前基于树注意力的解决方案依赖于完全物化的注意力掩码,在RL场景中扩展性很差。

本文介绍的 AREAL-DTA 旨在高效利用RL训练中的前缀共享。AREAL-DTA 采用一种基于深度优先搜索(DFS)的执行策略,在前向和后向计算过程中动态遍历rollout前缀树,每次只物化一条从根到叶的路径。为了进一步提高可扩展性,AREAL-DTA 整合了一种负载均衡的分布式批处理机制,该机制可在多个GPU上动态构建和处理前缀树。

本文的核心贡献如下:
1. 创新的DFS遍历方法:设计并实现了一种创新的token前缀树DFS遍历方法。AREAL-DTA不构建完全物化的注意力掩码,而是将rollout集合视为前缀树,并使用DFS动态遍历。在DFS中推入(push)前缀树节点时,AREAL-DTA一次只探索前缀树的一个分支,重用共享前缀的计算,仅在分支发散时分配新的计算资源。在任何时刻,活跃的上下文仅限于从根(公共前缀)到叶的单条路径,这极大地减少了内存占用。在DFS中弹出(pop)节点时,AREAL-DTA在访问下一个兄弟节点之前,会立即沿该分支反向传播相应的梯度。这样,中间激活值无需为所有分支同时存储,共享前缀的梯度贡献也能在不重复计算的情况下被正确累积。这种DFS遍历确保每个前缀树节点的计算只进行一次并为其所有后续延续所重用,同时,每个前缀的梯度也能被正确聚合,而使用的内存与最长序列的深度成正比,而不是前缀树中token的总数。

  1. 负载均衡的分布式批处理策略:为进一步扩展RL训练,开发了一种负载均衡的分布式批处理策略。具体来说,AREAL-DTA在异步rollout生成阶段动态地对生成的rollout进行批处理,以构建多个前缀树,并将计算分布到多个训练GPU工作节点上,从而保持每个训练GPU的计算负载均衡。这种机制最大限度地减少了GPU的空闲时间,从而实现了一种可扩展的训练机制,能够高效地为RL rollout构建和遍历前缀树,即使在序列数量众多、轨迹很长的情况下也能有效工作。

  2. 显著的性能提升:在广泛的RL训练任务中实现了显著的性能提升。实验证明,AREAL-DTA在具有挑战性的RL微调基准测试中,在速度和内存效率方面都取得了实质性的提升。在流行的RL任务上评估AREAL-DTA,发现其始终优于标准的RL基线。例如,AREAL-DTA实现了显著的吞吐量提升,单个训练工作节点最高可达8.31倍,整个训练集群可达6.20倍,完整的端到端流水线可达2.28倍。这种加速可归因于其内存高效的设计。与传统方法相比,观察到内存使用量大幅减少(通常将峰值GPU内存削减超过50%),从而避免了使用辅助内存优化技术(如激活检查点或梯度累积)的需要,这些技术否则会引入不可忽略的计算开销。

图1 AREAL-DTA的可视化说明。AREAL-DTA将序列(s1-s4)组织成一个前缀树,并维护一个活动中的从根到叶的前缀堆栈(tokens + KV缓存),同时推入新的段并重用共享前缀(例如,seg1和seg2)。在每个叶子节点,AREAL-DTA立即计算损失,进行反向传播,并弹出该分支——丢弃仅属于叶子的激活值(例如,seg3),同时保留共享前缀的状态以供下一个兄弟节点使用。注意,每个节点代表序列段中的一组token。
图1 AREAL-DTA的可视化说明。AREAL-DTA将序列(s1-s4)组织成一个前缀树,并维护一个活动中的从根到叶的前缀堆栈(tokens + KV缓存),同时推入新的段并重用共享前缀(例如,seg1和seg2)。在每个叶子节点,AREAL-DTA立即计算损失,进行反向传播,并弹出该分支——丢弃仅属于叶子的激活值(例如,seg3),同时保留共享前缀的状态以供下一个兄弟节点使用。注意,每个节点代表序列段中的一组token。

A3 背景知识

2.1 用于LLM后训练的RL系统

RL系统用于LLM后训练的阶段。RL已被广泛用于提升大型语言模型(LLM)的推理能力【索引10, OpenAI. Learning to reason with llms. OpenAI Blog, 09 2024.】【索引11, OpenAI. Introducing openai o3 and o4-mini. OpenAI Blog, 04 2025.】。先前的研究表明,基于RL的后训练可以在包括数学推理、程序合成和多跳问答在内的广泛推理密集型任务上显著提升性能【索引3, Dan Hendrycks, et al. Measuring mathematical problem solving with the math dataset. arXiv preprint arXiv:2103.03874, 2021.】【索引4, Yiping Wang, et al. Reinforcement learning for reasoning in large language models with one training example. arXiv preprint arXiv:2504.20571, 2025.】【索引6, Mark Chen, et al. Evaluating large language models trained on code. arXiv preprint arXiv:2107.03374, 2021.】【索引7, Jiaxuan Gao, et al. Beyond ten turns: Unlocking long-horizon agentic search with large-scale asynchronous rl. arXiv preprint arXiv:2508.07976, 2025.】。从系统角度看,LLM的RL训练对资源要求极高,通常包括三个不同阶段:(i)rollout生成,在受HBM I/O限制的GPU上执行推理,为每个提示生成多个候选响应(rollouts);(ii)奖励评估,可能依赖于密集的CPU资源(如用于代码评估的沙箱执行或用于数学的基于规则的求解器)或在使用基于LLM的奖励或价值模型时需要额外的GPU资源;(iii)模型训练,通过随机梯度优化对策略(和可选的价值)模型进行计算密集型的GPU更新,有时会涉及一个参考模型以保证稳定性。

同步与异步RL训练范式。现有的LLM RL训练流水线通常分为同步或异步两种范式。在同步RL训练中,rollout生成和模型优化交替进行:首先使用当前的策略模型参数生成推理轨迹,然后利用生成的rollout来更新模型【索引1, Qinghao Hu, et al. Taming the long-tail: Efficient reasoning rl training with adaptive drafter. arXiv preprint arXiv:2511.16665, 2025.】【索引12, John Schulman, et al. Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347, 2017.】【索引13, Guangming Sheng, et al. Hybridflow: A flexible and efficient rlhf framework. In Proceedings of the Twentieth European Conference on Computer Systems, pages 1279–1297, 2025.】【索引14, Ruoyu Qin, et al. Seer: Online context learning for fast synchronous llm reinforcement learning. arXiv preprint arXiv:2511.14617, 2025.】。相比之下,异步RL训练允许这些阶段并行进行,其中rollout生成使用可能过时的参数持续产生轨迹,而训练过程则并行更新模型【索引2, Wei Fu, et al. Areal: A large-scale asynchronous reinforcement learning system for language reasoning. arXiv preprint arXiv:2505.24298, 2025.】【索引15, Volodymyr Mnih, et al. Asynchronous methods for deep reinforcement learning. In International conference on machine learning, pages 1928–1937. PMLR, 2016.】【索引16, Lasse Espeholt, et al. Impala: Scalable distributed deep-rl with importance weighted actor-learner architectures. In International conference on machine learning, pages 1407–1416. PMLR, 2018.】【索引17, Lasse Espeholt, et al. Seed rl: Scalable and efficient deep-rl with accelerated central inference. arXiv preprint arXiv:1910.06591, 2019.】【索引18, Zhiyu Mei, et al. Srl: Scaling distributed reinforcement learning to over ten thousand cores. arXiv preprint arXiv:2306.16688, 2023.】【索引19, Guangming Sheng, et al. Laminar: A scalable asynchronous rl post-training framework. arXiv preprint arXiv:2510.12633, 2025.】【索引20, Haoyang Li, et al. Unleashing efficient asynchronous rl post-training via staleness-constrained rollout coordination. arXiv preprint arXiv:2601.12784, 2026.】。在这些异步系统中,AReaL【索引2, Wei Fu, et al. Areal: A large-scale asynchronous reinforcement learning system for language reasoning. arXiv preprint arXiv:2505.24298, 2025.】通过完全异步的架构进一步将流式生成与训练解耦,并引入了诸如时延感知优化和解耦的RL目标等算法技术,以实现LLM推理工作流的高效稳定RL训练。

2.2 用于LLM推理和训练的树注意力机制

用于LLM推理和训练的树注意力机制。将多个序列组织成前缀树的方法最初是为了通过并行化推测性解码和重用候选输出间的共享前缀来加速LLM推理【索引21, Ziteng Sun, et al. Spectr: Fast speculative decoding via optimal transport. Advances in Neural Information Processing Systems, 36:30222–30242, 2023.】。例如,Specinfer【索引9, Xupeng Miao, et al. Specinfer: Accelerating large language model serving with tree-based speculative inference and verification. In Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 3, pages 932–949, 2024.】首次提出将候选token组织成一个token树,并通过单次模型传递并行验证它们,从而能够在每次迭代中验证更多token。基于类似视角,Medusa【索引8, Tianle Cai, et al. Medusa: Simple llm inference acceleration framework with multiple decoding heads. In International Conference on Machine Learning, pages 5209–5235. PMLR, 2024.】通过增加额外的解码头来增强原始LLM,以一步预测多个后续token,并采用树状结构的注意力掩码来在每一步同时构建和验证多个延续分支。除了这些草稿-验证框架,近期的研究也优化了树解码图本身以提高效率。例如,Sequoia【索引22, Zhuoming Chen, et al. Sequoia: Scalable and robust speculative decoding. Advances in Neural Information Processing Systems, 37:129531–129563, 2024.】应用一种基于搜索的策略,在树的深度和宽度上分配固定的token预算,以在给定成本约束下最大化前缀重用。Yggdrasil【索引23, Yue Guan, et al. Yggdrasil: Bridging dynamic speculation and static runtime for latency-optimal tree-based llm decoding. arXiv preprint arXiv:2512.23858, 2025.】将动态推测与静态运行时优化相结合,动态选择树的宽度(即并行分支)和每个查询的深度,同时使用“等增长”树结构和分阶段调度来保持高硬件利用率。互补的优化也针对树注意力的内存和计算开销【索引24, Jinwei Yao, et al. Deft: Decoding with flash tree-attention for efficient tree-structured llm inference. In The Thirteenth International Conference on Learning Representations, 2025.】【索引25, Zaifeng Pan, et al. Fasttree: Optimizing attention kernel and runtime for tree-structured llm inference. Proceedings of Machine Learning and Systems, 7, 2025.】。例如,FastTree【索引25, Zaifeng Pan, et al. Fasttree: Optimizing attention kernel and runtime for tree-structured llm inference. Proceedings of Machine Learning and Systems, 7, 2025.】引入了专门的注意力内核,对共享公共前缀的分支进行计算打包和分块,减少了冗余的键/值加载和内存访问。综合来看,共享前缀状态的重用、并行解码头或分支探索多个token延续——结合通过自定义注意力掩码、内核和调度的优化执行——可以最小化运行时间和内存开销。

树注意力在LLM训练中的应用。与在LLM推理中的广泛应用相比,基于树注意力的LLM训练研究相对较少——我们发现唯一相关的尝试是Tree Training【索引5, Shaojie Wang, et al. Tree training: Accelerating agentic llms training via shared prefix reuse. arXiv preprint arXiv:2511.00413, 2025.】,它也旨在通过重用分支轨迹间的计算来微调LLM。具体来说,Tree Training通过专门的轨迹树打包机制和梯度校正机制实现。然而,由于其在内存占用、计算效率和可扩展性方面的限制,它不足以用于实际的大规模RL——Tree Training通过将每个打包的轨迹树存储在GPU内存中,显著增加了内存使用量,并且其带有自定义内核的静态前缀打包方案在没有有效并行训练支持的情况下不易应用。相比之下,AREAL-DTA使用动态DFS遍历和负载均衡的分布式调度来解决RL训练中的这些基本挑战。

A2 方法细节

3 动态树注意力

问题形式化。我们将RL计算的策略训练形式化如下。令 $s_1, s_2, ..., s_N$ 为用于当前策略模型训练迭代的N个rollout序列集合。每个序列 $s_i$ 都有一个相关的损失信号 $L(s_i)$,例如负对数似然或RL策略梯度信号,总的训练目标定义为:

$$\mathcal{L}=\sum_{i=1}^{N} \mathcal{L}\left(\mathbf{s}_{i}\right)$$

为了利用这些rollout中的共享前缀,我们构建一个表示为T的前缀树,它紧凑地表示了所有序列及其共享前缀。T中的每个节点对应于部分序列共享的token段,每个从根到叶的路径对应一个完整的序列,即$s_i$。通过遍历这个前缀树,我们可以重用不同rollout间的公共前缀计算。值得注意的是,任何多个序列共有的前缀将在前向传播中只处理一次,并且其来自所有后代序列的梯度应在一次反向传播中正确累积,而不是为每个序列重复计算。

主要挑战。关键挑战在于如何在不引入显著内存开销或损害梯度计算正确性的情况下实现这种重用。AREAL-DTA通过以深度优先方式动态遍历所有序列的前缀树,并交错进行前向和后向传播来重用计算,从而解决了这一挑战(§3.1)。我们还进行了一系列系统优化,以提高计算效率并减少内存开销(§3.2)。

3.1 前缀树的深度优先搜索遍历

深度优先搜索遍历的实现。为了在AREAL-DTA中实现DFS遍历算法,我们维护一个堆栈,该堆栈表示从前缀树的根到我们当前正在访问的节点的当前路径(前缀)。在任何时候,这个堆栈都持有:(i)当前前缀中的token序列;(ii)由当前策略模型为该前缀生成的相应中间状态,即Transformer的这些token的键/值缓存。给定前缀树T,训练计算(即计算策略梯度的前向和后向传播)按如下方式进行:

推入前缀树中间节点。我们首先对T执行深度优先搜索(DFS),即在沿着每个前缀树分支向下走时,将前缀树的中间节点推入堆栈。注意,推入当前前缀树中间节点意味着扩展当前前缀,并从先前前缀的模型状态继续,为该中间节点中的新token执行前向计算。具体来说,当我们从一个节点(前缀)移动到它的一个子节点时,使用前缀的缓存键/值(KV)状态将该子节点中的token输入策略模型。这个过程计算新token的对数概率,并更新扩展后前缀的KV缓存。然后,我们将新token的KV缓存附加到堆栈上。这个过程有效地重用了前缀token的计算,因此我们不必重新计算它们。

访问前缀树叶子节点。当DFS遍历到达一个对应于完整序列si的叶子节点时,堆栈现在包含了完整的token序列si及其前向传播结果(例如,对数概率、熵)。此时,我们可以利用堆栈中的信息计算损失$L(s_i)$(即通过对正确token的负对数似然求和或在轨迹上应用RL奖励)。然后,我们立即在该序列的输出端注入损失梯度——有效地开始了分支si的反向传播过程。通过在一个序列的前向传播完成后立即执行此操作(而不是在处理完所有序列之后),我们确保一旦该序列的梯度被反向传播,其计算图就不再需要保留在内存中。

弹出前缀树中间节点。处理完一个叶子节点后,DFS遍历会沿着该分支进行反向传播,然后从前缀树中向上弹出。通过这个弹出操作,我们获取前一个$L(s_i)$的梯度,并将其通过当前分支的策略模型计算进行传播。至关重要的是,当我们反向传播通过一个作为前缀树中分支节点的节点(对应于多个序列共享的token段)时,来自所有这些序列损失的梯度将累积在该前缀节点和模型的参数中。我们通过按DFS遍历顺序对每个分支依次执行反向传播来正确处理这个问题——共享前缀节点将多次接收梯度贡献(每个后代叶子节点对应一个原始序列$s_i$),我们在进行DFS遍历时将这些贡献相加。一旦我们反向传播了当前前缀树叶子节点计算的梯度,我们就会将其对应的token和相关激活值从堆栈中弹出,恢复到父前缀状态。此时,与被弹出token对应的节点在计算图中变得不再需要:我们知道未来的计算将不再依赖于它们(因为由于DFS遍历,我们已经完成了该前缀下的所有分支),并且我们已经为它们注入了所有必要的梯度。因此,这些激活值和任何临时梯度都可以安全地删除。然后,我们使用仍然保留的前缀状态,继续DFS到下一个兄弟分支。

空间复杂度分析。通过使用这种DFS遍历机制,AREAL-DTA仅存储从根到叶的当前路径的KV缓存(以及堆栈的少量状态)。我们永远不需要同时持有整个树的激活值。这意味着峰值内存使用量与最长序列的长度(沿着前缀树的最长路径)成正比,而不是与所有序列中的token总数成正比。在流行的RL工作流中,数百个序列一起处理,这是一个关键的改进——AREAL-DTA避免了朴素的树注意力实现会带来的二次方内存爆炸。

3.2 系统优化

系统优化概述。我们为DFS遍历实现了一系列系统优化,以进一步减少内存使用并提高计算效率。

内存高效的梯度计算。AREAL-DTA中的一个关键优化是在DFS遍历期间动态执行反向传播,以限制内存增长。AREAL-DTA不是为所有序列构建一个巨大的计算图,然后进行一次全局反向传播,而是交错进行前向和后向步骤,以便部分计算图被增量地构建和释放——我们动态地构建计算图,并在子部分不再需要时立即销毁它们。具体来说,考虑前缀树中的一个中间节点:在标准的基于批处理或朴素的基于树的方法中,计算图可能会将中间激活值(例如,隐藏状态、KV缓存、logits等)保留在内存中,直到所有序列的损失L都已处理并反向传播;在AREAL-DTA的DFS方案中,一旦我们在前缀树中弹回该中间节点,这意味着我们已经遍历了该中间节点下的所有叶子节点,因此可以确保该中间节点不会影响任何来自其他序列片段的进一步梯度计算。因此,这样的节点可以立即进行反向传播并释放。这确保我们不保留整个前缀树的计算图。相反,在任何时候,未释放的图部分仅用于我们当前正在下降的路径。该策略将理论内存复杂度从与总树大小(所有序列长度之和)成正比降低到仅与最长序列长度成正比。

针对长序列的分块反向传播。我们受ChunkFlow【索引26, Xiulong Yuan, et al. Efficient long context fine-tuning with chunk flow. In Forty-second International Conference on Machine Learning, 2025.】中分块调度算法的启发,为前缀树实现了长rollout序列的分块反向传播机制。虽然DFS方法将活动计算图限制在一个分支上,但单个序列仍可能非常长(数万个token),这可能使其前向/后向图变得太大而无法放入内存。为了处理极长的rollout轨迹,AREAL-DTA进一步将这种序列的反向传播分块。我们选择一个最大块长度(例如2048个token),并将长序列的前向/后向计算分割成该长度的块。AREAL-DTA不会等到长序列结束才进行反向传播,而是在处理完一块token后执行部分反向传播,在继续进行进一步计算之前释放这些激活值。具体来说,如果一个前缀路径超过了块长度,我们将进行部分反向传播:使用存储的前缀KV缓存和输出,我们重新计算该块的前向激活值(一个小的额外前向传播)以重建其计算图,然后立即反向传播并释放该块的图。然后我们继续该块之后的正向传播,并根据需要重复该过程。请注意,这种优化的成本是每个块的额外前向计算(因为我们为反向传播重新计算块激活值),但通过选择一个合理大的块长度,我们可以分摊额外内存IO的开销,在保持内存使用受控的同时增加最小的运行时间。

避免叶子节点的KV缓存计算。在DFS遍历中,AREAL-DTA避免在堆栈中存储前缀树叶子节点的KV缓存。由于只有前缀节点可以在分块反向传播执行期间作为计算图重建的锚点,叶子节点的KV缓存永远不会被任何其他计算重用。通过跳过这些终端节点的KV缓存,我们消除了仅用于填充缓存的不必要的前向计算。这显著减少了分块反向传播引入的开销,使其额外成本几乎可以忽略不计。

确定最优DFS遍历顺序。请注意,对前缀树的DFS遍历顺序会显著影响整体系统效率。在AREAL-DTA中,我们设计了一个贪心算法来优化DFS序列。这个贪心启发式算法简单而有效:(i)通过最大化共享前缀的重用来最小化前向传播的次数;(ii)平衡反向传播段的长度以避免频繁的短后向步骤,后者可能会受限于内存。通过优先处理导致更长不间断推入-弹出段的分支,并更均匀地平衡梯度累积,我们的贪心算法在不同形状的rollout树上提高了内存效率和运行时稳定性。

4 负载均衡并行化

负载均衡并行化概述。尽管上述动态树注意力算法优化了单工作节点的效率,但大规模RL训练还需要跨多个GPU进行分布式执行。在异步RL框架(即AREAL)中,rollout被持续生成并将被分派到多个训练GPU上进行并行训练。对于每个策略模型训练迭代,AREAL-DTA利用一个负载均衡的调度器,将传入的rollout分配给不同的GPU,使得每个GPU执行的工作量均衡,从而最大化整体吞吐量并避免空闲。

问题形式化。我们将调度问题形式化如下。假设在每个策略模型训练迭代中,我们从rollout推理工作节点收集N个新的rollout来训练策略模型。我们希望将这N个序列分成K个不相交的组(其中K是训练GPU的数量),并让每个GPU为其组中的序列构建和处理一个前缀树,记为$T_j, j = 1, 2, ..., K$。一个组的成本,记为$C(T_j)$,定义为处理其序列作为前缀树的估计时间。我们的目标是将n个序列划分为k个组,使得各组中的最大成本最小化:

$$\mathcal{C}=\min \max_{j=1}^K \mathcal{C}(\mathcal{T}_j)$$

在实践中,我们将$C(T_i)$定义为$T_i$中所有节点的token长度之和,即树中token的总数。

优化目标。这个优化目标确保没有单个GPU过载而其他GPU利用不足,即所有GPU将大致在同一时间完成工作,实现良好的扩展性。注意,这样的问题可以看作是平衡划分问题的一个变体,通常是NP-hard的。为了解决这个优化问题,我们引入了下面的负载均衡划分算法。

负载均衡划分算法。我们利用前缀树的特性和单调成本模型来高效地找到一个接近最优的划分。首先,我们将N个序列排列成一个单一的前缀树(就好像我们要在单个GPU上处理它们一样)。为此,我们简单地采用字典序排序,这显然是前缀树的DFS顺序。在实践中,当跨GPU分割序列时,我们考虑两种类型的开销:(i)序列组之间的负载不平衡;(ii)跨组的前缀计算重复(即,前缀共享的减少超过了组合前缀树的大小)。一种与结构无关、纯粹按token数量平衡序列的划分方法,会因将共享前缀分割到不同GPU而急剧增加开销(ii)。相反,通过DFS遍历顺序(它将具有公共前缀的序列相邻地聚类)对序列进行排序,然后将这个有序列表划分为K个连续的段,可以产生最小的重复。直观地说,连续的DFS段保留了大部分的前缀共享:分割成K个段最多引入(K-1) × max(len(si))个额外的token,超出了单个组合前缀树的范围(其中max(len(si))是最长序列的长度),这相对于总token数来说是一个可以忽略的开销。因此,我们可以有效地将前缀树划分问题简化为一个更简单的序列划分问题,而前缀重用的减少很小。即使有这种DFS顺序约束,序列也可以被分成大小几乎相等的连续块,从而在所有GPU上保持负载均衡。

二分搜索求解。基于此,我们只需在有序列表中找到K个段之间的边界,以最小化最大成本。我们通过对最大允许成本进行二分搜索并结合贪心检查来解决这个问题,这是划分问题的常用方法。具体来说,我们二分搜索最小的阈值τ,使得我们可以将序列列表切割成至多K个段,每个段的成本≤ τ。对于一个候选τ,我们按顺序扫描序列,并在添加下一个序列会超过成本τ时贪婪地开始一个新的段。如果我们能以这种方式形成≤ K个段,则阈值τ是可行的;否则,τ设置得太低。二分搜索τ(其范围从最大单个序列的成本到所有序列的总成本)能在O(N log C(T))时间内找到最优的最大成本,这在实践中非常快。我们注意到,在贪心扫描期间可以增量地计算一个段(一组序列)的成本:我们可以在添加每个序列时维护前缀树的成本,更新诸如共享前缀长度之类的信息。由于单调成本属性,这个贪心算法能够为给定的顺序找到一个最优的连续划分。

A4 实验环境与结果

5.1 实验设置

5.2 端到端RL训练性能(RQ1)

由于AREAL-DTA提高了策略模型的训练吞吐量,rollout GPU工作节点和训练GPU工作节点的最佳分配会有所不同。表1展示了手动调优后的最佳并行RL训练配置。

最佳并行配置

表1 AREAL-DTA与AREAL在PPO算法下训练1.7B和8B参数模型时的最佳异步RL训练配置比较。我们使用dpx-ppy-tpz格式表示rollout生成和模型训练的并行策略,其中x, y, z分别代表数据并行、流水线并行和张量模型并行的度数。在rollout生成上下文中,数据并行度为x表示将相同的rollout工作节点复制x次。
表1 AREAL-DTA与AREAL在PPO算法下训练1.7B和8B参数模型时的最佳异步RL训练配置比较。我们使用dpx-ppy-tpz格式表示rollout生成和模型训练的并行策略,其中x, y, z分别代表数据并行、流水线并行和张量模型并行的度数。在rollout生成上下文中,数据并行度为x表示将相同的rollout工作节点复制x次。

RL奖励曲线

图2展示了AREAL-DTA和AREAL在TauBench数据集上,使用1.7B和8B模型时,以RL训练步数为x轴的奖励曲线。由于异步RL训练固有的不确定性,对于相同模型大小,图2中的奖励曲线在AREAL-DTA和AREAL基线之间相似但不完全相同。这些结果表明AREAL-DTA的设计没有损害异步RL训练的稳定性。

图2 跨训练步骤,我们展示了AREAL-DTA和AREAL在TauBench数据集和1.7B/8B模型上的奖励比较。
图2 跨训练步骤,我们展示了AREAL-DTA和AREAL在TauBench数据集和1.7B/8B模型上的奖励比较。

端到端RL训练吞吐量

图3展示了使用累计真实世界RL训练时间作为x轴时,AREAL-DTA和AREAL的奖励曲线。图3表明,AREAL-DTA显著提高了RL训练效率:在图3的最终RL训练步骤中,AREAL-DTA相比AREAL在1.7B和8B模型上分别实现了1.28倍和2.28倍的端到端训练吞吐量提升。如表1中的训练配置所示,由于模型训练性能更优,AREAL-DTA可以将更多GPU分配给rollout生成阶段,从而提高端到端加速。另一方面,AREAL中训练GPU的较低训练效率需要将更多GPU分配给模型训练,导致整体性能欠佳。

图3 跨累计RL训练时间,我们展示了AREAL-DTA和AREAL在TauBench数据集和1.7B/8B模型上的奖励比较。
图3 跨累计RL训练时间,我们展示了AREAL-DTA和AREAL在TauBench数据集和1.7B/8B模型上的奖励比较。

5.3 消融研究(RQ2)

反向传播优化的消融研究

图4(吞吐量评估)和图5(GPU内存评估)展示了AREAL-DTA前向传播优化的消融研究。我们计算了梯度计算步骤的时间和内存消耗,不包括优化器状态更新。

图4 AREAL-DTA在τ2-Bench数据集和1.7B/4B/8B/14B模型上的训练吞吐量消融研究。
图4 AREAL-DTA在τ2-Bench数据集和1.7B/4B/8B/14B模型上的训练吞吐量消融研究。
图5 AREAL-DTA在TauBench数据集和1.7B/4B/8B/14B模型上的反向传播内存利用率消融研究。
图5 AREAL-DTA在TauBench数据集和1.7B/4B/8B/14B模型上的反向传播内存利用率消融研究。

数据并行均衡算法的消融研究

图6、7、8展示了我们精心设计的负载均衡算法的消融研究。我们将我们的方法与先前工作中常用的贪心数据并行均衡算法【索引2, Wei Fu, et al. Areal: A large-scale asynchronous reinforcement learning system for language reasoning. arXiv preprint arXiv:2505.24298, 2025.】进行了比较。这个贪心基线迭代地将训练数据分配给当前计算工作量最小的工作节点。我们进行了两组实验,分别使用(i)常规的总序列token数和(ii)为AREAL-DTA量身定制的C(T)来衡量计算工作量。相比之下,我们的数据并行均衡算法根据前缀树T的DFS顺序对序列进行排序,然后基于C(T)进行均匀的连续划分。与基线相比,这减少了由划分产生的额外树token数量,从而降低了计算开销和墙钟时间。实验结果显示,禁用AREAL-DTA的数据并行均衡算法会导致系统性能下降11.93%。

图6 AREAL-DTA在TauBench数据集和1.7B/4B/8B/14B模型上使用2个GPU的反向传播吞吐量消融研究。
图6 AREAL-DTA在TauBench数据集和1.7B/4B/8B/14B模型上使用2个GPU的反向传播吞吐量消融研究。
图7 AREAL-DTA在TauBench数据集和1.7B/4B/8B/14B模型上使用4个GPU的反向传播吞吐量消融研究。
图7 AREAL-DTA在TauBench数据集和1.7B/4B/8B/14B模型上使用4个GPU的反向传播吞吐量消融研究。
图8 AREAL-DTA在TauBench数据集和1.7B/4B/8B/14B模型上使用8个GPU的反向传播吞吐量消融研究。
图8 AREAL-DTA在TauBench数据集和1.7B/4B/8B/14B模型上使用8个GPU的反向传播吞吐量消融研究。

A5 结论

RL后训练的效率可以通过减少冗余计算来提高,因为流行的工作流可能会生成许多共享长前缀的rollout轨迹。在本文中,我们提出了AREAL-DTA,一个利用前缀共享来提高RL训练效率和可扩展性的系统。AREAL-DTA将rollout组织成一个前缀树,并执行动态DFS遍历,重用共享前缀的计算,同时交错进行前向和后向传播,使活动计算状态的大小受限于最长的根到叶路径,而不是总的树大小。为了在异步RL框架中扩展到多GPU训练,AREAL-DTA进一步引入了一种负载均衡的调度策略,该策略将rollout分批成多个前缀树,并将它们分配到训练GPU上,以在保持前缀重用的同时最小化GPU空闲时间。经验证,AREAL-DTA带来了显著的端到端收益,将训练吞吐量提高了高达8.31倍,使得在相同的硬件预算下,每个提示可以有更大的组大小和更多的rollout。