EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
文章标题:EAGLE: 推测性采样需要重新思考特征不确定性
作者/机构:Yuhui Li (北京大学), Fangyun Wei (微软研究院), Chao Zhang (北京大学), Hongyang Zhang (滑铁卢大学, Vector Institute)
代码链接: https://github.com/SafeAILab/EAGLE
A1 主要贡献
核心问题: 自回归解码是大型语言模型(LLMs)的通用标准,它逐个生成token,导致生成过程缓慢且成本高昂。基于推测性采样的方法通过将过程分为低成本的草稿阶段和并行化的验证阶段来解决此问题,但面临着寻找一个既能模拟原始LLM功能又延迟更低的合适草稿模型的挑战。例如,对于7B参数的模型,很难找到合适的草稿模型;而使用一个7B模型作为13B或70B模型的草稿,其自身的开销也会削弱加速效果。此外,现有的一些低开销草稿方法(如Lookahead和Medusa)虽然降低了延迟,但其草稿准确率较低,限制了性能。
研究目标: 本文旨在提出一种新的高效推测性采样框架EAGLE(Extrapolation Algorithm for Greater Language-model Efficiency),以克服现有方法的局限性,在不改变原始LLM输出分布的前提下,显著提升LLM的推理速度。
创新点:
本文基于两个关键观察提出EAGLE框架:
1. 特征层面的自回归比Token层面更简单:本文中的“特征”指LLM倒数第二层的隐藏状态(LM head之前)。与token序列相比,特征序列表现出更强的规律性。在特征层面进行自回归处理,然后使用原始LLM的LM head导出token,比直接在token层面进行自回归预测能取得更好的效果。
2. 采样过程中的不确定性限制了特征预测的性能:在文本生成中,目标LLM会预测一个token分布并据此进行采样,这引入了随机性。由于特征是高维连续的,无法像token一样处理。例如,在“I”之后可能采样到“am”或“always”,导致后续的特征序列出现分支,给特征级别的自回归带来歧义。
- EAGLE的解决方案:为了解决这种不确定性,EAGLE将一个时间步上提前的token序列(包含了采样结果)作为输入引入到草稿模型中。例如,在预测f_always时,输入为f_I和t_always。这一设计有效消除了歧义,使得模型能够以极小的开销精确预测倒数第二层的特征。
核心优势:
* 通用性: EAGLE原则上适用于任何自回归LLM,已在LLaMA2-Chat、Vicuna和Mixtral 8x7B Instruct等模型上进行了验证。它仅向LLM添加一个轻量级插件(单个Transformer解码器层),易于部署。
* 可靠性: EAGLE不微调原始LLM,并理论上保证在贪婪和非贪婪设置下都能维持输出分布不变,这与Lookahead(仅限贪婪设置)和Medusa(非贪婪设置不保证无损)形成对比。
* 低训练成本: 训练EAGLE的成本很低。例如,为LLaMA2-Chat 70B模型训练一个小于1B参数的解码器层,仅需7万条对话数据,在4块A100(40G)GPU上1-2天即可完成。
A3 背景知识
符号定义: 本文中,“目标LLM”指需要加速的LLM,“草稿模型”指用于生成草稿的模型。“特征”(feature)通常指LLM倒数第二层的特征,即LM head之前的隐藏状态。Token用小写t表示,其嵌入用e表示,特征用f表示,分布用p表示。序列用大写字母表示,例如T_i:j代表(t_i, t_i+1, ..., t_j)。在一个LLM中,输入T_1:j通过嵌入层转换为嵌入E_1:j,然后转换为特征F_1:j,LM Head将f_j映射到一个分布p_{j+1} = LM Head(f_j),并从中采样下一个token t_{j+1}。原始的token级自回归过程描述为T_1:j → E_1:j → f_j → p_{j+1} → t_{j+1}。
推测性采样: 推测性采样通过草稿和验证两个阶段进行操作。草稿阶段使用一个较小的模型生成γ个候选tokens T̂_{j+1:j+γ}及其分布P̂_{j+1:j+γ}。在验证阶段,目标LLM通过一次前向传播计算出概率P_{j+1:j+γ}。然后,tokens被逐个评估,一个token t̂_{j+i}的接受概率为min(1, p_{j+i}(t̂_{j+i}) / p̂_{j+i}(t̂_{j+i}))。当一个token t̂_{j+i}被拒绝时,其后的所有tokens都会被丢弃,并且这个token会基于一个调整后的分布norm(max(0, p_{j+i} - p̂_{j+i}))进行重采样。正如推测性采样论文【16, Fast inference from transformers via speculative decoding, 2023, ICML】的附录A.1所证明的,该方法等同于直接从目标LLM进行采样。EAGLE采纳了此方法,确保在贪婪和非贪婪设置下生成文本的分布保持不变。
A2 方法细节
EAGLE与其他基于推测性采样的方法一样,包含草稿阶段和验证阶段。
3.1 草稿阶段
EAGLE草稿阶段的核心机制: EAGLE与其他方法的主要区别在于草稿阶段。如图5所示,标准的推测性采样【16, Fast inference from transformers via speculative decoding, 2023, ICML; 2, Accelerating large language model decoding with speculative sampling, 2023, arXiv】和Lookahead【6, Breaking the sequential dependency of LLM inference using lookahead decoding, 2023, lmsys.org blog】是基于token预测token。Medusa【1, Medusa: Simple framework for accelerating LLM generation with multiple decoding heads, 2023, GitHub】则利用目标LLM的特征f2独立地预测t4和t5。相比之下,EAGLE使用特征序列(f1, f2)和一个提前一个时间步的token序列(t2, t3)来预测下一个特征f3。然后,从p4 = LM Head(f3)中采样出t4。接着,将f3和t4拼接到输入序列中,用于预测下一个特征f4并采样出随后的token t5。
EAGLE草稿模型的架构: 如图6所示,EAGLE的草稿模型由三个模块组成:嵌入层(Embedding layer)、LM Head和自回归头(Autoregression Head)。嵌入层和LM Head直接使用目标LLM的参数,无需额外训练。草稿模型接收一个形状为(bs, seq_len, hidden_dim)的特征序列和一个形状为(bs, seq_len)的提前token序列作为输入。它首先将token序列转换为形状为(bs, seq_len, hidden_dim)的token嵌入序列,然后将两者拼接成一个形状为(bs, seq_len, 2×hidden_dim)的融合序列。自回归头由一个FC层和一个解码器层(decoder layer)组成。FC层将融合序列的维度降至(bs, seq_len, hidden_dim),然后利用解码器层来预测下一个特征。LM Head根据这个预测的特征计算出分布,并从中采样下一个token。最后,预测出的特征和采样到的token被拼接到输入中,以继续自回归过程。EAGLE利用树状注意力(tree attention)来创建一个树状结构的草稿,通过m次前向传播生成一个深度为m且包含超过m个token的草稿树。例如,如图6所示,EAGLE仅用3次前向传播就生成了一个包含10个token的树。EAGLE使用的具体树结构在附录A.1中有详细说明。
3.2 草稿模型的训练
损失函数设计: 预测下一个特征是一个回归任务,为此我们使用Smooth L1损失(如图5 EAGLE所示):L_reg = Smooth L1(f_{i+1}, Draft Model(T_{2:i+1}, F_{1:i}))。预测特征是草稿模型的中间目标,最终目标是预测token以生成token序列。因此,我们也使用分类损失来直接优化这个最终目标。
通过结合回归损失和分类损失,我们使用组合损失函数$L = L_{reg} + w_{cls}L_{cls}$来训练自回归头。通常,分类损失在数值上比回归损失大一个数量级,因此我们将w_cls设置为0.1。
训练数据与策略: EAGLE的自回归头最理想的训练数据是目标LLM自回归生成的文本,但这种方法成本高昂。幸运的是,EAGLE对训练数据不敏感(见4.3.3节的消融研究)。我们使用一个固定的数据集代替目标LLM生成的文本,从而大大降低了开销。在草稿阶段,EAGLE自回归地处理特征。特征中的不准确性可能导致误差累积。为了缓解这个问题,我们采用了数据增强技术,在训练期间向目标LLM的特征添加从均匀分布$U(-0.1, 0.1)$中采样的随机噪声【12, NEFTune: Noisy embeddings improve instruction finetuning, 2023, arXiv】。
3.3 验证阶段
树状验证过程: 利用树状注意力,目标LLM通过单次前向传播计算出树状结构草稿中每个token的概率。在草稿树的每个节点上,我们递归地应用推测性采样算法来采样或调整分布(详见附录A.2),这与SpecInfer【18, SpecInfer: Accelerating generative LLM serving with speculative inference and token tree verification, 2023, arXiv】的做法一致,确保输出文本的分布与目标LLM的分布完全对齐。同时,我们会记录被接受的tokens及其对应的特征,以供下一轮草稿阶段使用。
A4 实验环境
- 模型:
- Vicuna 系列: 7B, 13B, 33B
- LLaMA2-Chat 系列: 7B, 13B, 70B
- MoE 模型: Mixtral 8x7B Instruct
- 数据集与任务:
- MT-bench【41, Judging llm-as-a-judge with mt-bench and chatbot arena, 2023, arXiv】: 用于多轮对话任务评估。
- HumanEval【3, Evaluating large language models trained on code, 2021, arXiv】: 用于代码生成任务评估。
- GSM8K【5, Training verifiers to solve math word problems, 2021, arXiv】: 用于数学推理任务评估。
- Alpaca【29, Stanford alpaca: An instruction-following llama model, 2023, GitHub】: 用于指令遵循任务评估。
- 硬件配置:
- GPU: A100 40G服务器(用于70B模型训练),单张RTX 3090(用于7B、13B、33B模型训练及与gpt-fast结合的测试)。LLaMA2-Chat 70B推理使用4块A100(40G)GPU。
- 软件与实验设置:
- 大多数实验的批量大小(batch size)为1。
- 所有评估均在FP16精度下进行。
- EAGLE在ShareGPT数据集上进行训练,使用68,000次对话迭代。
- 优化器: AdamW,beta值为
(β1, β2) = (0.9, 0.95)。 - 学习率: 3e-5,梯度裁剪阈值为0.5。
- EAGLE的可训练参数量:7B模型为0.24B,13B为0.37B,33B为0.56B,70B为0.99B,Mixtral 8x7B为0.28B。
A4 实验结果
效果评估
- 总体加速效果: 根据图1、图2和表1,EAGLE在temperature=0(贪婪解码)时的加速效果优于temperature=1。例如,LLaMA2-Chat 13B在T=0时的加速比为3.01x-3.76x,而在T=1时为2.66x-2.89x。在代码生成任务(HumanEval)上,EAGLE取得了最佳加速性能,这可能是因为代码中固定的模板更容易生成草稿。
- 与SOTA方法对比: 相比于Lookahead和Medusa,EAGLE的速度分别提升了1.70x-2.08x和1.47x-1.60x。与传统的推测性采样相比,EAGLE也表现出明显优势,后者在13B模型上无加速,在33B和70B模型上仅有1.12x和1.88x的加速。使用与EAGLE相同数据训练的DistillSpec虽然有提升,但效果有限,因为其性能瓶颈在于草稿模型的高开销。
- 接受长度与接受率: 表1和表2显示,EAGLE每次前向传播能接受3.2-4.5个token,远超原始解码的1个token。图2和附录B中的接受率
n-α表明,无错误的特征序列(0-α)的接受率远高于有错误的序列(1-α),但1-α到4-α之间的变化不大,说明EAGLE对特征错误具有鲁棒性,能有效处理误差累积。 - MoE模型效果: 在Mixtral 8x7B Instruct模型上,EAGLE实现了1.5倍的加速(表3)。加速效果相对温和的原因是平均接受长度较短,且MoE模型在验证阶段可能需要访问超过两个专家的权重,增加了计算开销。
表1:在HumanEval、GSM8K和Alpaca上的加速比和平均接受长度 τ。T表示温度,V表示Vicuna,LC表示LLaMA2-Chat。
表2:在MT-bench上的平均接受长度 τ 和接受率 α。T表示温度。
表3:在MT-bench上,temperature=0时,目标LLM为Mixtral 8x7B Instruct-v0.1的加速比、平均接受长度 τ 和接受率 α。
案例研究: EAGLE + gpt-fast
EAGLE可与其他加速技术(如量化和编译)兼容。实验将EAGLE与gpt-fast结合,在单张RTX 3090上将LLaMA2-Chat 7B的生成速度提升至160.4 tokens/s(表4)。
消融研究
- 树状注意力: 表5和图7显示,使用树状注意力的草稿和验证机制,相比链式结构,平均接受长度
τ增加了约0.6-0.8,加速比提升了约0.3-0.5。即使不使用树状注意力,EAGLE仍能达到2.3x-2.7x的显著加速效果。
图7:使用和不使用树状注意力时EAGLE的加速比。评估数据集为MT-bench,温度参数设置为0。
表5:使用和不使用树状注意力时EAGLE的平均接受长度 τ。评估数据集为MT-bench,温度参数设置为0。 - 草稿模型输入: 如图8所示,在Vicuna 7B上的消融实验对比了四种输入:
feature&shifted-token(EAGLE)、feature&unshifted-token、token和feature。- 在参数量有限时,使用
feature比token效果稍好。 - 结合
feature和token(feature&unshifted-token)性能有适度提升,主要是因为无错误的token缓解了特征误差累积。 - 最重要的提升来自于解决采样过程中的随机性。
feature&shifted-token(EAGLE)方案通过将token提前一个时间步,使草稿模型能够考虑采样的随机性,从而显著提升了性能,而没有增加任何复杂性。
图8:不同输入的草稿模型性能。目标LLM为Vicuna 7B,测试数据集为MT-bench。Speed指walltime加速比,τ指平均接受长度,0-α指输入完全精确时的接受率,1-α指输入包含一个不精确特征时的接受率,T指温度。
- 在参数量有限时,使用
- 训练数据: 表6的消融研究表明,使用目标LLM生成的数据训练相比使用固定的ShareGPT数据集,性能提升微乎其微。这证明EAGLE对训练数据不敏感,使用固定数据集来降低成本是合理的。
表6:使用不同训练数据集在MT-bench上评估的加速比和平均接受长度 τ,目标LLM为LLaMA2-Chat 7B,温度为0。“固定数据集”指问题和答案均来自ShareGPT数据集。“目标LLM生成的数据”指问题来自ShareGPT,但答案由目标LLM生成。
批量大小与吞吐量
- 加速比随批量大小的变化: LLM推理是内存密集型的,推测性采样利用了未被充分利用的GPU计算资源。如表7所示,随着批量大小增加,可用的计算资源减少,加速效果随之下降。
- 吞吐量: 尽管推测性采样主要关注延迟,本文也研究了吞吐量。EAGLE需要稍多的CUDA内存。在硬件限制下(Vicuna 7B在RTX 3090 24G,LLaMA2-Chat 70B在4xA100 160G),EAGLE的最大批量大小略小于原始解码。然而,通过在最大批量大小下运行,EAGLE实现了约2倍的吞吐量提升。
表7:不同批量大小下的加速比和EAGLE的吞吐量。评估数据集为MT-bench,温度参数设置为0。
A5 结论
本文介绍了EAGLE,一个高效的推测性采样框架。EAGLE的核心思想是在结构性更强的倒数第二层特征层面进行自回归草稿生成,并通过引入提前一个时间步的token来解决下一特征预测中的采样不确定性问题。EAGLE在显著提升生成速度的同时,保证了LLM输出分布的不变性。在MT-bench上,EAGLE比原始自回归解码快2.1x-3.8x,比Lookahead快1.7x-2.1x,比Medusa快1.5x-1.6x。
A6 附录
A.1 树结构
EAGLE的草稿树: EAGLE利用树状注意力生成树状结构的草稿。图9左侧展示了草稿的树结构,而右侧则描绘了不使用树状注意力时的对应链式结构(用于4.3.1节的消融研究)。在贪婪设置中,我们选择概率最高的top-k个token作为子节点。在非贪婪设置中,我们采样k个token。子节点的数量k可以从图9中推断;例如,根节点处的k=4。无论采用树状还是链式草稿,草稿模型在草稿阶段都进行5次前向传播。在验证阶段,目标LLM通过一次前向传播获得每个token的概率。
树结构选择的依据: 图9中所示的树结构的选择并非经过严格优化,而是基于直觉:概率更高的token所在的分支应该更深、更宽。在本文中,所有模型在所有实验中都使用了图9所示的草稿结构。然而,最优的树结构可能与上下文相关。例如,随着批量大小增加和冗余计算资源减少,一个更小的树可能更可取。调整草稿结构可能会带来性能提升。
A.2 多轮推测性采样
树状草稿的采样算法: 与推测性采样的链式草稿不同,EAGLE采用树状草稿,这需要对采样算法进行修改。推测性采样的采样算法A可以简述为:如果一个token被接受,则返回该token;否则,从调整后的分布中采样一个token。对于一个有k个候选token的树状草稿,多轮推测性采样(Multi-round speculative sampling)会递归地调用算法A。在拒绝一个token后,它不会直接从调整后的分布中采样,而是再次调用A。如果所有token都被拒绝,它才会直接从调整后的分布中采样。多轮推测性采样的伪代码在算法1中提供。
B. 详细实验结果
表8展示了EAGLE在HumanEval、GSM8K和Alpaca数据集上,在temperature = 0时的加速比、平均接受长度 τ 和接受率 α 的详细数据。
💬 评论讨论
欢迎在这里分享您的想法和见解!