EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees

A1 主要贡献

本文旨在解决现代大型语言模型(LLM)推理成本高昂且耗时的问题。投机采样(speculative sampling)是一种有效的解决方案,其通过快速生成草稿令牌(draft tokens)然后并行验证来加速推理。

A2 预备知识与关键观察

2. 预备知识

2.1 投机采样

投机采样的核心思想是“先草稿,后验证”。它首先快速生成一段可能正确的草稿,然后检查草稿中有哪些令牌可以被接受。我们使用 $t_i$ 表示第 $i$ 个令牌,用 $T_{a:b}$ 表示令牌序列 $t_a, t_{a+1}, \dots, t_b$。投机采样在草稿阶段和验证阶段之间交替进行。

草稿与验证过程。给定一个前缀 $T_{1:j}$,在草稿阶段,投机采样调用一个草稿模型(比原始LLM小的模型)自回归地生成一个草稿 $T_{\hat{j+1}:j+k}$,同时记录每个令牌的概率 $\hat{p}$。在验证阶段,投机采样调用原始LLM来检查草稿 $T_{\hat{j+1}:j+k}$ 并记录其概率 $p$。接着,投机采样从前到后依次决定草稿令牌的接受与否。对于令牌 $\hat{t}_{j+i}$,它被接受的概率是 $min(1, p_{j+i}(\hat{t}_{j+i}) / \hat{p}_{j+i}(\hat{t}_{j+i}))$。如果令牌被接受,则继续检查下一个;否则,从分布 $norm(max(0, p_{j+i} - \hat{p}_{j+i}))$ 中采样一个令牌来替换 $\hat{t}_{j+i}$,并丢弃草稿中剩余的令牌。文献【22, Fast inference from transformers via speculative decoding, 2023, ICML】的附录A.1证明了投机采样与原始自回归解码的分布是一致的。EAGLE和EAGLE-2都应用了这个框架。

2.2 EAGLE

EAGLE是对投机采样的改进。在本文提交时,EAGLE在Spec-Bench【44, Unlocking efficiency in large language model inference: A comprehensive survey of speculative decoding, 2024】上排名第一,这是一个为评估各种场景下投机解码方法而设计的综合基准。
- 草稿阶段(Drafting Stage)。与标准的投机采样自回归地预测令牌序列不同,EAGLE在结构性更强的特征(LM head之前)层面进行自回归,然后使用原始LLM的LM Head来获得草稿令牌。由于采样过程给特征序列带来了不确定性,为了解决这个问题,EAGLE还将提前一个时间步的令牌序列输入到草稿模型中,如图3a所示。
- 验证阶段(Verification Stage)。在标准的投机采样中,草稿是链式结构的,如果一个草稿令牌被拒绝,就需要丢弃所有后续的令牌。EAGLE使用树状结构的草稿,如果一个草稿令牌被拒绝,可以尝试其他分支。图3b展示了两者的区别。

图3:标准投机采样与EAGLE的比较。为简化起见,EAGLE的树状草稿仅在验证阶段展示,而草稿阶段的图示使用了链式结构的草稿。这里,ti表示第i个令牌嵌入,fi表示LLM倒数第二层在LM head之前的第i个特征向量。
图3:标准投机采样与EAGLE的比较。为简化起见,EAGLE的树状草稿仅在验证阶段展示,而草稿阶段的图示使用了链式结构的草稿。这里,ti表示第i个令牌嵌入,fi表示LLM倒数第二层在LM head之前的第i个特征向量。

EAGLE与EAGLE-2的区别。EAGLE的草稿树形状是固定的,草稿阶段只是填充相应的位置。EAGLE-2旨在通过引入一个可动态调整的草稿树来改进这一点。图4通过一个简单的例子说明了EAGLE和EAGLE-2之间的区别。

图4:EAGLE与EAGLE-2的区别。EAGLE总是使用固定的草稿形状。当查询是“10+2=”时,下一个令牌很可能被正确预测为“1”。然而,使用静态草稿树,EAGLE仍会添加两个候选词,即使另一个候选词“3”正确的概率非常低。而EAGLE-2则会根据上下文调整草稿树的形状。当查询是“10+2”时,下一个令牌很难预测,所以EAGLE-2添加了两个候选词。对于更简单的查询“10+2=”,EAGLE-2只添加了一个候选词“1”。
图4:EAGLE与EAGLE-2的区别。EAGLE总是使用固定的草稿形状。当查询是“10+2=”时,下一个令牌很可能被正确预测为“1”。然而,使用静态草稿树,EAGLE仍会添加两个候选词,即使另一个候选词“3”正确的概率非常低。而EAGLE-2则会根据上下文调整草稿树的形状。当查询是“10+2”时,下一个令牌很难预测,所以EAGLE-2添加了两个候选词。对于更简单的查询“10+2=”,EAGLE-2只添加了一个候选词“1”。

3. 观察

3.1 接受率依赖于上下文

首先评估使用动态草稿树的必要性。这取决于草稿令牌的接受率是否仅与它们的位置有关。我们在Alpaca数据集和Vicuna 7B上测试了草稿树中不同位置令牌的接受率。结果如图5所示。总体而言,草稿令牌的接受率与位置相关,位置P1的接受率最高,P6的最低。草稿树左上侧的草稿令牌(如位置P1)接受率较高,而右下侧的(如位置P6)接受率较低。这支持了像EAGLE和Medusa等方法中使用的静态草稿树在左上侧节点更多、右下侧节点更少的合理性。然而,我们同时也观察到在同一位置,接受率存在显著的方差,这表明一个草稿令牌被接受的概率不仅取决于其位置,还取决于上下文。这说明一个上下文感知的动态草稿树比静态草稿树具有更大的潜力。

图5:草稿令牌在不同位置的接受率。左图中,P1-P6表示令牌树中的位置,对应右图横轴上的位置1-6。右图显示了P1-P6位置草稿令牌的接受率。(a) 草稿树结构。(b) 不同位置令牌的接受率,每个点代表一个查询。
图5:草稿令牌在不同位置的接受率。左图中,P1-P6表示令牌树中的位置,对应右图横轴上的位置1-6。右图显示了P1-P6位置草稿令牌的接受率。(a) 草稿树结构。(b) 不同位置令牌的接受率,每个点代表一个查询。

3.2 良好校准的草稿模型

要应用动态草稿树,需要一种低成本的方法来估计草稿令牌的接受率,而无需调用原始LLM。我们在Alpaca数据集上进行了实验,探索草稿模型的置信度分数(LLM对每个令牌输出的概率)与接受率之间的关系。如图6所示,草稿模型的置信度分数与令牌的接受率之间存在很强的正相关关系。置信度分数低于0.05的草稿令牌,其接受率约为0.04;而置信度分数高于0.95的草稿令牌,其接受率约为0.98。因此,我们可以使用草稿模型的置信度分数来估计接受率,而无需额外开销,从而实现对草稿树的动态调整。在其他方法中,如GLIDE和CAPE【9, Glide with a cape: A low-hassle method to accelerate speculative decoding, 2024】,也观察到了类似的现象。

图6:草稿模型不同置信度分数区间的平均接受率。红色虚线连接(0,0)和(1,1)以辅助视觉评估。原始LLM为Vicuna 7B。
图6:草稿模型不同置信度分数区间的平均接受率。红色虚线连接(0,0)和(1,1)以辅助视觉评估。原始LLM为Vicuna 7B。

A3 方法细节

4. 上下文感知的动态草稿树

基于上述观察,我们引入了EAGLE-2。这是一种用于LLM推理的加速算法,它能动态地调整草稿树。EAGLE-2不改变草稿模型的训练和推理,也不影响验证阶段。它的改进集中在两个方面:如何扩展草稿树(第4.1节)以及如何对草稿令牌进行重排序(第4.2节)。在扩展阶段,我们将草稿树最新一层中最有希望的节点输入到草稿模型中以形成下一层。在重排序阶段,我们选择接受概率更高的令牌来构成验证阶段原始LLM的输入。在草稿树中,一个节点代表一个令牌。在下文中,我们交替使用“节点”和“令牌”。

4.1 扩展阶段

利用树注意力机制并行扩展节点。得益于树注意力(tree attention),草稿模型可以同时输入当前层的所有令牌,并在一次前向传播中计算出下一个令牌的概率,从而扩展当前层的所有令牌。然而,一次性输入过多的令牌会减慢草稿模型的前向传播速度,并且草稿树每层令牌的数量呈指数级增长。因此,我们需要有选择地扩展草稿树。

依据全局接受概率选择扩展节点。我们选择当前层中全局接受概率最高的top-k个令牌进行扩展。在投机采样中,拒绝一个草稿令牌会导致丢弃所有后续令牌;一个令牌最终被接受,当且仅当它的所有前缀都被接受。一个令牌 $t_i$ 的全局接受率是其从根节点到 $t_i$ 路径上所有令牌接受率的乘积。我们将其定义为值 $V_i$:

其中,$Path(root, t_i)$ 表示草稿树中从根节点到节点 $t_i$ 的路径,$p_j$ 表示节点 $t_j$ 的接受率,$c_j$ 表示草稿模型对 $t_j$ 的置信度分数。第3.2节的实验表明,置信度分数与接受率高度正相关。我们利用这种关系来近似该值。

扩展流程。从值更高的令牌开始的分支更有可能被接受。因此,我们选择最后一层中值最高的top-k个节点作为草稿模型的输入,并根据输出扩展草稿树。图7的上半部分展示了扩展阶段。

4.2 重排序阶段

重排序以选择全局最优的草稿令牌。扩展阶段的目的是加深草稿树。由于接受率在0到1之间,更深的令牌其值会更低。一些未被扩展的浅层节点可能比更深的已扩展节点具有更高的值。因此,我们不直接使用扩展阶段选择的令牌作为草稿,而是对所有草稿令牌进行重排序,并选择值最高的top-m个令牌。一个节点的值总是小于或等于其父节点的值。对于值相同的节点,我们优先选择更浅的节点。这确保了重排序后选择的top-m个令牌仍然构成一个连通的树。

构造验证阶段的输入与注意力掩码。之后,我们将选定的令牌展平为一维序列,作为验证阶段的输入。为了确保与原始自回归解码的一致性,我们还需要调整注意力掩码。在原始自回归解码中,每个令牌可以看到所有前面的令牌,从而形成一个下三角注意力矩阵。当使用草稿树时,来自不同分支的令牌不应该互相看到。因此,必须根据树结构调整注意力掩码,以确保每个令牌只能看到其祖先节点。图7的下半部分展示了重排序阶段。

图7:EAGLE-2示意图。边旁边的数字代表草稿模型的置信度分数,块内括号中的数字代表节点的值。在扩展阶段,我们从当前层中选择值最高的2个节点(橙色块)作为草稿模型的输入,并将生成的令牌(绿色块)连接到草稿树上。在重排序阶段,我们从所有节点中选择值最高的8个节点(蓝色块),将它们展平为一维序列以形成最终草稿。然后我们根据树结构构造注意力掩码,确保每个令牌只能看到它的祖先节点。
图7:EAGLE-2示意图。边旁边的数字代表草稿模型的置信度分数,块内括号中的数字代表节点的值。在扩展阶段,我们从当前层中选择值最高的2个节点(橙色块)作为草稿模型的输入,并将生成的令牌(绿色块)连接到草稿树上。在重排序阶段,我们从所有节点中选择值最高的8个节点(蓝色块),将它们展平为一维序列以形成最终草稿。然后我们根据树结构构造注意力掩码,确保每个令牌只能看到它的祖先节点。

A4 实验

实验环境

实验结果

5.1 有效性

表1:不同方法的加速比和平均接受长度τ。V代表Vicuna,L2代表LLaMA2-Chat。SpS表示标准投机采样,其草稿模型为Vicuna-68M。像Medusa这样的方法在非贪婪设置下放宽了接受条件,这不能保证无损加速。因此,我们不将EAGLE-2与这些方法进行比较。
表1:不同方法的加速比和平均接受长度τ。V代表Vicuna,L2代表LLaMA2-Chat。SpS表示标准投机采样,其草稿模型为Vicuna-68M。像Medusa这样的方法在非贪婪设置下放宽了接受条件,这不能保证无损加速。因此,我们不将EAGLE-2与这些方法进行比较。

表2:以LLaMA2-Chat 70B、LLaMA3-Instruct 70B和LLaMA3-Instruct 8B为原始LLM,在MT-bench数据集上,温度设置为0时的加速比和平均接受长度τ。
表2:以LLaMA2-Chat 70B、LLaMA3-Instruct 70B和LLaMA3-Instruct 8B为原始LLM,在MT-bench数据集上,温度设置为0时的加速比和平均接受长度τ。

5.2 消融研究

表3:在Vicuna 7B上,温度设置为0时的消融实验结果。“w/o value”表示不使用价值而直接使用置信度,“w/o reranking”表示不进行重排序,“w/o both”表示既不使用价值也不进行重排序。
表3:在Vicuna 7B上,温度设置为0时的消融实验结果。“w/o value”表示不使用价值而直接使用置信度,“w/o reranking”表示不进行重排序,“w/o both”表示既不使用价值也不进行重排序。

A5 结论

本文介绍了EAGLE-2,一种高效且无损的投机采样方法。我们发现EAGLE的草稿模型置信度能够很好地近似草稿令牌的接受率。基于此,EAGLE-2采用了一种依赖于上下文的草稿树结构,显著增加了被接受的草稿令牌数量,从而带来了更优的加速比。EAGLE-2确保了生成结果与原始LLM一致,并且不需要额外的训练。我们在多种LLM和多个数据集上进行了广泛的评估,并将EAGLE-2与几种最先进的投机采样方法进行了比较。在所有的实验中,EAGLE-2都取得了最高的加速比。

A6 附录

A. 实现细节