文章标题:EAGLE: 推测性采样需要重新思考特征不确定性
作者/机构:Yuhui Li (北京大学), Fangyun Wei (微软研究院), Chao Zhang (北京大学), Hongyang Zhang (滑铁卢大学, Vector Institute)
代码链接: https://github.com/SafeAILab/EAGLE
核心问题: 自回归解码是大型语言模型(LLMs)的通用标准,它逐个生成token,导致生成过程缓慢且成本高昂。基于推测性采样的方法通过将过程分为低成本的草稿阶段和并行化的验证阶段来解决此问题,但面临着寻找一个既能模拟原始LLM功能又延迟更低的合适草稿模型的挑战。例如,对于7B参数的模型,很难找到合适的草稿模型;而使用一个7B模型作为13B或70B模型的草稿,其自身的开销也会削弱加速效果。此外,现有的一些低开销草稿方法(如Lookahead和Medusa)虽然降低了延迟,但其草稿准确率较低,限制了性能。
研究目标: 本文旨在提出一种新的高效推测性采样框架EAGLE(Extrapolation Algorithm for Greater Language-model Efficiency),以克服现有方法的局限性,在不改变原始LLM输出分布的前提下,显著提升LLM的推理速度。
创新点:
本文基于两个关键观察提出EAGLE框架:
1. 特征层面的自回归比Token层面更简单:本文中的“特征”指LLM倒数第二层的隐藏状态(LM head之前)。与token序列相比,特征序列表现出更强的规律性。在特征层面进行自回归处理,然后使用原始LLM的LM head导出token,比直接在token层面进行自回归预测能取得更好的效果。
2. 采样过程中的不确定性限制了特征预测的性能:在文本生成中,目标LLM会预测一个token分布并据此进行采样,这引入了随机性。由于特征是高维连续的,无法像token一样处理。例如,在“I”之后可能采样到“am”或“always”,导致后续的特征序列出现分支,给特征级别的自回归带来歧义。
- EAGLE的解决方案:为了解决这种不确定性,EAGLE将一个时间步上提前的token序列(包含了采样结果)作为输入引入到草稿模型中。例如,在预测f_always时,输入为f_I和t_always。这一设计有效消除了歧义,使得模型能够以极小的开销精确预测倒数第二层的特征。
核心优势:
* 通用性: EAGLE原则上适用于任何自回归LLM,已在LLaMA2-Chat、Vicuna和Mixtral 8x7B Instruct等模型上进行了验证。它仅向LLM添加一个轻量级插件(单个Transformer解码器层),易于部署。
* 可靠性: EAGLE不微调原始LLM,并理论上保证在贪婪和非贪婪设置下都能维持输出分布不变,这与Lookahead(仅限贪婪设置)和Medusa(非贪婪设置不保证无损)形成对比。
* 低训练成本: 训练EAGLE的成本很低。例如,为LLaMA2-Chat 70B模型训练一个小于1B参数的解码器层,仅需7万条对话数据,在4块A100(40G)GPU上1-2天即可完成。
符号定义: 本文中,“目标LLM”指需要加速的LLM,“草稿模型”指用于生成草稿的模型。“特征”(feature)通常指LLM倒数第二层的特征,即LM head之前的隐藏状态。Token用小写t表示,其嵌入用e表示,特征用f表示,分布用p表示。序列用大写字母表示,例如T_i:j代表(t_i, t_i+1, ..., t_j)。在一个LLM中,输入T_1:j通过嵌入层转换为嵌入E_1:j,然后转换为特征F_1:j,LM Head将f_j映射到一个分布p_{j+1} = LM Head(f_j),并从中采样下一个token t_{j+1}。原始的token级自回归过程描述为T_1:j → E_1:j → f_j → p_{j+1} → t_{j+1}。
推测性采样: 推测性采样通过草稿和验证两个阶段进行操作。草稿阶段使用一个较小的模型生成γ个候选tokens T̂_{j+1:j+γ}及其分布P̂_{j+1:j+γ}。在验证阶段,目标LLM通过一次前向传播计算出概率P_{j+1:j+γ}。然后,tokens被逐个评估,一个token t̂_{j+i}的接受概率为min(1, p_{j+i}(t̂_{j+i}) / p̂_{j+i}(t̂_{j+i}))。当一个token t̂_{j+i}被拒绝时,其后的所有tokens都会被丢弃,并且这个token会基于一个调整后的分布norm(max(0, p_{j+i} - p̂_{j+i}))进行重采样。正如推测性采样论文【16, Fast inference from transformers via speculative decoding, 2023, ICML】的附录A.1所证明的,该方法等同于直接从目标LLM进行采样。EAGLE采纳了此方法,确保在贪婪和非贪婪设置下生成文本的分布保持不变。
EAGLE与其他基于推测性采样的方法一样,包含草稿阶段和验证阶段。
EAGLE草稿阶段的核心机制: EAGLE与其他方法的主要区别在于草稿阶段。如图5所示,标准的推测性采样【16, Fast inference from transformers via speculative decoding, 2023, ICML; 2, Accelerating large language model decoding with speculative sampling, 2023, arXiv】和Lookahead【6, Breaking the sequential dependency of LLM inference using lookahead decoding, 2023, lmsys.org blog】是基于token预测token。Medusa【1, Medusa: Simple framework for accelerating LLM generation with multiple decoding heads, 2023, GitHub】则利用目标LLM的特征f2独立地预测t4和t5。相比之下,EAGLE使用特征序列(f1, f2)和一个提前一个时间步的token序列(t2, t3)来预测下一个特征f3。然后,从p4 = LM Head(f3)中采样出t4。接着,将f3和t4拼接到输入序列中,用于预测下一个特征f4并采样出随后的token t5。
EAGLE草稿模型的架构: 如图6所示,EAGLE的草稿模型由三个模块组成:嵌入层(Embedding layer)、LM Head和自回归头(Autoregression Head)。嵌入层和LM Head直接使用目标LLM的参数,无需额外训练。草稿模型接收一个形状为(bs, seq_len, hidden_dim)的特征序列和一个形状为(bs, seq_len)的提前token序列作为输入。它首先将token序列转换为形状为(bs, seq_len, hidden_dim)的token嵌入序列,然后将两者拼接成一个形状为(bs, seq_len, 2×hidden_dim)的融合序列。自回归头由一个FC层和一个解码器层(decoder layer)组成。FC层将融合序列的维度降至(bs, seq_len, hidden_dim),然后利用解码器层来预测下一个特征。LM Head根据这个预测的特征计算出分布,并从中采样下一个token。最后,预测出的特征和采样到的token被拼接到输入中,以继续自回归过程。EAGLE利用树状注意力(tree attention)来创建一个树状结构的草稿,通过m次前向传播生成一个深度为m且包含超过m个token的草稿树。例如,如图6所示,EAGLE仅用3次前向传播就生成了一个包含10个token的树。EAGLE使用的具体树结构在附录A.1中有详细说明。
损失函数设计: 预测下一个特征是一个回归任务,为此我们使用Smooth L1损失(如图5 EAGLE所示):L_reg = Smooth L1(f_{i+1}, Draft Model(T_{2:i+1}, F_{1:i}))。预测特征是草稿模型的中间目标,最终目标是预测token以生成token序列。因此,我们也使用分类损失来直接优化这个最终目标。
通过结合回归损失和分类损失,我们使用组合损失函数$L = L_{reg} + w_{cls}L_{cls}$来训练自回归头。通常,分类损失在数值上比回归损失大一个数量级,因此我们将w_cls设置为0.1。
训练数据与策略: EAGLE的自回归头最理想的训练数据是目标LLM自回归生成的文本,但这种方法成本高昂。幸运的是,EAGLE对训练数据不敏感(见4.3.3节的消融研究)。我们使用一个固定的数据集代替目标LLM生成的文本,从而大大降低了开销。在草稿阶段,EAGLE自回归地处理特征。特征中的不准确性可能导致误差累积。为了缓解这个问题,我们采用了数据增强技术,在训练期间向目标LLM的特征添加从均匀分布$U(-0.1, 0.1)$中采样的随机噪声【12, NEFTune: Noisy embeddings improve instruction finetuning, 2023, arXiv】。
树状验证过程: 利用树状注意力,目标LLM通过单次前向传播计算出树状结构草稿中每个token的概率。在草稿树的每个节点上,我们递归地应用推测性采样算法来采样或调整分布(详见附录A.2),这与SpecInfer【18, SpecInfer: Accelerating generative LLM serving with speculative inference and token tree verification, 2023, arXiv】的做法一致,确保输出文本的分布与目标LLM的分布完全对齐。同时,我们会记录被接受的tokens及其对应的特征,以供下一轮草稿阶段使用。
(β1, β2) = (0.9, 0.95)。n-α表明,无错误的特征序列(0-α)的接受率远高于有错误的序列(1-α),但1-α到4-α之间的变化不大,说明EAGLE对特征错误具有鲁棒性,能有效处理误差累积。EAGLE可与其他加速技术(如量化和编译)兼容。实验将EAGLE与gpt-fast结合,在单张RTX 3090上将LLaMA2-Chat 7B的生成速度提升至160.4 tokens/s(表4)。
τ增加了约0.6-0.8,加速比提升了约0.3-0.5。即使不使用树状注意力,EAGLE仍能达到2.3x-2.7x的显著加速效果。feature&shifted-token (EAGLE)、feature&unshifted-token、token和feature。feature比token效果稍好。feature和token(feature&unshifted-token)性能有适度提升,主要是因为无错误的token缓解了特征误差累积。feature&shifted-token(EAGLE)方案通过将token提前一个时间步,使草稿模型能够考虑采样的随机性,从而显著提升了性能,而没有增加任何复杂性。本文介绍了EAGLE,一个高效的推测性采样框架。EAGLE的核心思想是在结构性更强的倒数第二层特征层面进行自回归草稿生成,并通过引入提前一个时间步的token来解决下一特征预测中的采样不确定性问题。EAGLE在显著提升生成速度的同时,保证了LLM输出分布的不变性。在MT-bench上,EAGLE比原始自回归解码快2.1x-3.8x,比Lookahead快1.7x-2.1x,比Medusa快1.5x-1.6x。
EAGLE的草稿树: EAGLE利用树状注意力生成树状结构的草稿。图9左侧展示了草稿的树结构,而右侧则描绘了不使用树状注意力时的对应链式结构(用于4.3.1节的消融研究)。在贪婪设置中,我们选择概率最高的top-k个token作为子节点。在非贪婪设置中,我们采样k个token。子节点的数量k可以从图9中推断;例如,根节点处的k=4。无论采用树状还是链式草稿,草稿模型在草稿阶段都进行5次前向传播。在验证阶段,目标LLM通过一次前向传播获得每个token的概率。
树结构选择的依据: 图9中所示的树结构的选择并非经过严格优化,而是基于直觉:概率更高的token所在的分支应该更深、更宽。在本文中,所有模型在所有实验中都使用了图9所示的草稿结构。然而,最优的树结构可能与上下文相关。例如,随着批量大小增加和冗余计算资源减少,一个更小的树可能更可取。调整草稿结构可能会带来性能提升。
树状草稿的采样算法: 与推测性采样的链式草稿不同,EAGLE采用树状草稿,这需要对采样算法进行修改。推测性采样的采样算法A可以简述为:如果一个token被接受,则返回该token;否则,从调整后的分布中采样一个token。对于一个有k个候选token的树状草稿,多轮推测性采样(Multi-round speculative sampling)会递归地调用算法A。在拒绝一个token后,它不会直接从调整后的分布中采样,而是再次调用A。如果所有token都被拒绝,它才会直接从调整后的分布中采样。多轮推测性采样的伪代码在算法1中提供。
表8展示了EAGLE在HumanEval、GSM8K和Alpaca数据集上,在temperature = 0时的加速比、平均接受长度 τ 和接受率 α 的详细数据。