Accelerating Large Language Model Decoding with Speculative Sampling

作者/机构: Charlie Chen, Sebastian Borgeaud, Geoffrey Irving, Jean-Baptiste Lespiau, Laurent Sifre and John Jumper (All authors from DeepMind)

A1 主要贡献

本文提出了一种名为推测性采样(speculative sampling, SpS)的算法,旨在通过单次Transformer调用生成多个token来加速解码过程。该算法的核心观察是,使用一个速度更快但能力较弱的“草稿模型”(draft model)生成简短的续写,并由更大、更强的“目标模型”(target model)并行地对这些续写进行评分,其延迟与目标模型采样单个token的延迟相当。该方法结合了一种新颖的改进版拒绝采样方案,能够在硬件数值精度范围内精确地保持目标模型的输出分布。

该研究的核心问题是大型Transformer模型(参数量超过5000亿)的自回归解码过程成本高昂且效率低下,其性能主要受内存带宽限制,导致生成单个token的时间与模型参数和内存大小成正比。此外,模型并行化还带来了额外的通信开销和资源需求。

本文的研究目标是提出一种能够有效降低延迟关键型应用中Transformer采样延迟的算法。具体实现步骤如下:
1. 生成草稿:使用一个速度更快的自回归模型(草稿模型)生成一个长度为 K 的简短草稿序列。
2. 并行评分:使用更大、更强的目标模型对这个草稿序列进行并行评分,获取每个token的logits。
3. 拒绝采样:采用一种改进的拒绝采样方案,从左到右逐个接受草稿中的token,从而恢复目标模型的概率分布。

该方法的直觉在于,在许多情况下,序列中的下一个token可能是“显而易见”的。当草稿模型和目标模型的分布在某个token或子序列上高度一致时,该设置允许在每次调用目标模型时生成多个token。

本文的主要创新点在于:
* 提出了一种实用的推测性采样算法,能够在不修改目标模型或影响样本分布的情况下,有效降低大型语言模型的采样延迟。
* 通过实验证明,对于700亿参数的Chinchilla语言模型,在分布式设置下,推测性采样实现了2至2.5倍的解码加速。
* 在某些情况下,使用推测性采样的平均每秒生成token数(tokens per second)甚至超过了由内存带宽限制的自回归采样速度的理论上限。

相关工作

已有大量工作致力于改善大型Transformer和其他自回归模型的采样延迟。

由于采样性能与模型在内存中的大小密切相关,量化到int8甚至int4(【Dettmers et al., 2022, LLM. int8 (): 8-bit matrix multiplication for transformers at scale】;【Yao et al., 2022, Zeroquant: Efficient and affordable post-training quantization for large-scale transformers】)和蒸馏(【Jiao et al., 2020, TinyBERT: Distilling BERT for natural language understanding】;【Sanh et al., 2019, Distilbert, a distilled version of bert: smaller, faster, cheaper and lighter】)是降低采样延迟而性能损失很小或没有损失的有效技术。同时,研究发现模型大小对最终性能的贡献低于预期(【Hoffmann et al., 2022, Training compute-optimal large language models】),这也鼓励了更小语言模型的普遍发展。

在采样过程中,每个注意力层都会维护一个键(keys)和值(values)的缓存,随着批量大小的增加,这可能成为内存带宽的瓶颈。像多查询注意力(multi-query attention)(【Shazeer, 2019, Fast transformer decoding: One write-head is all you need】)这样的方法旨在通过缩小这个缓存来提高采样性能。然而,这些技术在最大化吞吐量(在较大批量下)方面最有效,而不是延迟,特别是对于大部分内存带宽预算被参数消耗的更大型号。

结合上述技术以及对TPU的一系列底层优化,Pope等人(【Pope et al., 2022, Efficiently scaling transformer inference】)极大地改善了PaLM 540B的服务延迟和效率。

已有类似的工作利用了Transformer和序列模型并行操作的效率。这包括块状并行采样(block parallel sampling)(【Stern et al., 2018, Blockwise parallel decoding for deep autoregressive models】)、激进解码(aggressive decoding)(【Ge et al., 2022, Lossless acceleration for seq2seq generation with aggressive decoding】),以及一些在图像领域并行化自回归模型的工作(【Song et al., 2021, Accelerating feedforward computation via parallel nonlinear equation solving】;【Wiggers and Hoogeboom, 2020, Predictive sampling with forecasting autoregressive models】)。这些方法尚未适用于典型的语言模型用例,因为它们要么只适用于贪婪采样,要么会偏离结果,要么专注于其他模态。此外,据我们所知,这些技术都未被扩展到分布式设置,而这对于拥有数百亿或数千亿参数的最昂贵解码器是必需的。

巧合的是,本手稿中的工作与Leviathan等人(【Leviathan et al., 2022, Fast inference from transformers via speculative decoding】)关于推测解码的工作是同期且独立进行的。我们更侧重于大型模型的分布式服务设置,并提供了一些增量优化,但核心的基本思想是相同的。

A3 自回归采样

尽管Transformer可以在TPU和GPU上高效地并行训练,但样本通常是自回归生成的(见算法1)。对于大多数应用,自回归采样(ArS)是高度受内存带宽限制的,因此无法有效利用现代加速器硬件(【Shazeer, 2019, Fast transformer decoding: One write-head is all you need】)。一次受内存限制的模型调用仅为批次中的每个序列生成一个token,因此,生成多个token会在任何使用它的系统中引入大量的延迟。

随着模型参数数量的增加,这个问题尤其严重。由于所有模型参数都需要至少通过一个加速器芯片,模型大小除以所有芯片的总内存带宽,为我们提供了一个自回归采样速度的硬性上限。更大的模型还需要在多个加速器上提供服务,由于设备间的通信开销,这引入了另一个延迟源。

算法1 使用自回归模型进行自回归采样(ArS)

给定自回归目标模型 $q(.|.)$、初始提示序列 $x_1, . . . , x_t$ 和目标序列长度 $T$。
初始化 $n \leftarrow t$。
while $n < T$ do
Sample $x_{n+1} \sim q(x |x_1, . . . , x_n)$
$n \leftarrow n + 1$
end while

算法2 使用自回归目标和草稿模型进行推测性采样(SpS)

A2 方法细节

条件评分

推测性采样的核心观察。对于推测性采样(见算法2),我们首先观察到,并行计算一个长度为 K 的短续写的logits与采样单个token的延迟非常相似。我们关注的是大型Transformer,它们以Megatron风格(【Shoeybi et al., 2019, Megatron-lm: Training multi-billion parameter language models using model parallelism】)进行分片。对于这些模型,大部分采样时间可归因于三个组成部分:

线性层。对于小批量大小,每个线性层只处理少量的嵌入。这导致前馈层、查询、键、值计算以及最终的注意力投影中的密集矩阵乘法变得受内存限制。对于较小的K,这将继续是内存限制的,因此花费的时间相似。

注意力机制。注意力机制也受内存限制。在采样期间,我们为序列中所有先前的token维护一个键和值的缓存(KV-cache),以避免重新计算。这些KV-cache很大,并占据了注意力机制大部分的内存带宽利用。然而,由于KV-cache的大小随着我们增加K而不会改变,因此该组件几乎没有时间增量。

All-reduces通信。随着模型规模的增长,其参数需要被划分到多个加速器上,导致通信开销。使用Megatron时,这表现为在每个前馈和注意力层之后进行一次all-reduce操作。由于只传输少量token的激活值,对于采样和评分(对于较小的K),该操作通常是受延迟限制而非吞吐量限制。同样,这导致两种情况下花费的时间相似。

其他开销。根据具体的Transformer实现,可能存在其他开销来源。因此,位置编码的选择、解码方法(例如,核采样可能需要排序)、硬件限制等仍可能在评分和采样之间引入一些差异。然而,如果满足上述组件占主导地位的条件,那么对于较小的K,评分应该不会明显慢于采样。

改进的拒绝采样

恢复目标分布的机制。我们需要一种方法,从草稿模型的样本以及两个模型对这些token的logits中恢复目标模型的分布。

接受概率。为此,我们引入了以下对草稿token的拒绝采样方案。给定一个token序列 $x_1, . . . , x_n$,以及从 $p(.|.)$ 生成的K个草稿token $\tilde{x}_{n+1}, . . . , \tilde{x}_{n+K}$,我们以以下概率接受 $\tilde{x}_{n+1}$:

$$\min \left(1, \frac{q\left(\tilde{x}_{n+1} \mid x_1, \ldots, x_n\right)}{p\left(\tilde{x}_{n+1} \mid x_1, \ldots, x_n\right)}\right)$$

其中 $q(\tilde{x}_{n+1}|x_1, . . . , x_n)$ 和 $p(\tilde{x}_{n+1}|x_1, . . . , x_n)$ 分别是根据目标模型和草稿模型,在给定当前上下文的情况下,token $\tilde{x}_{n+1}$ 的概率。

接受后的处理。如果token被接受,我们设置 $x_{n+1} \leftarrow \tilde{x}_{n+1}$ 并对 $\tilde{x}_{n+2}$ 重复此过程,直到某个token被拒绝或所有token都被接受。

拒绝后的处理。如果 $\tilde{x}_{n+1}$ 被拒绝,我们从以下分布中重新采样 $x_{n+1}$:

$$x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+$$

其中 $(.)_+$ 表示:

$$ (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))} $$

方法的关键属性。通过顺序应用此方法,我们在硬件数值精度范围内恢复了被接受token的目标模型分布(证明见定理1)。注意:
* 一个草稿-接受循环中总会生成至少一个token——如果第一个token被拒绝,会重新采样一个有效的token。
* 由于草稿的最后一个token为我们提供了下一个token的logits,如果所有草稿token都被接受,我们可以从中正常采样。这使得每个循环最多可以生成 K + 1 个token,优于仅返回K个token的朴素实现。

与标准采样方法的兼容性。对于核采样、top-k采样和调整温度等标准采样方法,我们可以在应用此拒绝采样方案之前相应地修改概率。我们观察到,总体的接受率对所使用的具体参数是稳健的。

与其他优化技术的结合。因为我们不与Transformer的主体部分交互,所以该方法可以与许多其他加速或优化采样内存使用的技术结合使用,例如量化和多查询注意力。

草稿模型的选择

选择草稿模型的灵活性。由于接受准则保证了我们样本中目标模型的分布,只要有足够高的接受率和/或足够低的延迟来达到收支平衡,我们可以自由选择生成续写草稿的方法,只要它能提供logits。这里存在几种方法:

集成草稿生成到目标模型。将草稿生成功能整合到目标模型中,并从头开始训练模型。这是Stern等人(【Stern et al., 2018, Blockwise parallel decoding for deep autoregressive models】)使用的策略,他们在Transformer中添加了多个头以生成多个token。

使用序列级蒸馏。使用序列级蒸馏(【Kim and Rush, 2016, Sequence-level knowledge distillation】)来生成第二个模型,该模型能并行预测K个token。Ge等人(【Ge et al., 2022, Lossless acceleration for seq2seq generation with aggressive decoding】)采用了这种策略。

利用目标模型的激活值。将目标模型的一部分激活值作为草稿模型的输入,并使用此输入训练草稿模型。

方法的实用性考量。尽管这些方法可能会产生强大的草稿,但它们需要大量由目标模型生成的数据或对目标模型的更改。特别是序列级蒸馏需要大量的计算预算。这使得它们在大型应用中不太实用。

本文选择的实用方案。虽然大型语言模型能产生更好的样本,但直观上,总有一些“更容易”预测的token,对于这些token,较小的模型可能就足够了。因此,我们可以简单地使用一个较小版本的目标语言模型作为草稿模型,并获得较高的接受率。从工程和工作流程的角度来看,这也将非常方便,因为用于训练目标模型的稳健工具应该已经存在。

A4 实验环境

  • 模型架构:

    • 目标模型: Chinchilla,一个700亿参数的语言模型。
    • 草稿模型: 一个为采样延迟优化的40亿参数模型。该模型与Chinchilla使用相同的分词器和数据集进行训练,但宽度稍小(d_model=6144),且只有8个层。这种“浅而宽”的设计旨在最小化通信开销,使其能在与目标模型相同的硬件上快速采样。
  • 硬件配置:

    • 平台: 16个TPU v4。Chinchilla和草稿模型都在此硬件配置上进行服务。
    • 性能: 在16个TPU v4上,Chinchilla的采样速度为14.1ms/token,而专门设计的草稿模型采样速度为1.8ms/token。
  • 软件配置:

    • 分布式设置: 论文特别指出,为分布式设置选择草稿模型不能简单地选择一个通用的小模型。例如,一个对Chinchilla最优的7B模型在4个TPU v4上延迟最低(5ms/token),但在16个TPU上运行反而会因通信开销增加而变慢。因此,本文训练了一个专门的、能够在16个TPU上高效运行的草稿模型。
  • 数据集与任务:

    • XSum: 一个自然语言摘要任务,使用1-shot提示,共采样11,305个序列,最大序列长度为128。解码参数为核采样$p = 0.8$。
    • HumanEval: 一个代码生成任务,使用100-shot设置,共生成16,400个样本,最大序列长度为512。解码参数为核采样$p = 0.95$,温度为0.8。

A4 实验结果

在XSum和HumanEval上的评估

本文使用Chinchilla模型在XSum和HumanEval两个任务上评估了推测性采样(SpS)的效果,并将结果与标准的自回归采样(ArS)进行了比较。由于数值精度和伪随机种子处理方式的不同,SpS和ArS无法产生完全相同的输出序列,但理论上样本应来自同一分布。通过在这些基准测试中评估,作者从经验上验证了这一点。

核心结果 (Table 1):

  • XSum任务:

    • 使用核采样(Nucleus)时,SpS实现了1.92倍的加速(7.52ms/token vs 14.1ms/token),ROUGE-2得分与ArS相当(0.114 vs 0.112)。
    • 使用贪婪采样(Greedy)时,SpS实现了2.01倍的加速(7.00ms/token vs 14.1ms/token),ROUGE-2得分也与ArS相当(0.156 vs 0.157)。
  • HumanEval任务:

    • 使用核采样时,SpS实现了近2.46倍的显著加速(5.73ms/token vs 14.1ms/token),任务性能指标(pass rate)与ArS相当(47.0% vs 45.1%)。

结论:

  • SpS在两个任务中都获得了显著的加速,同时在基准测试指标上保持了与ArS的同等水平,经验性地验证了该方法不会引入偏差。
  • 特别是在HumanEval和贪婪XSum的情况下,加速后的采样速度超过了硬件为自回归采样设定的理论内存带宽上限(模型大小除以总内存带宽)。

Table 1 | Chinchilla在XSum和HumanEval上使用朴素采样和推测性采样的性能和速度,批量大小为1,K = 4。XSum使用核参数p = 0.8执行,HumanEval使用p = 0.95和温度0.8执行。

不同领域的接受率变化

实验结果表明,接受率与应用领域和解码方法有关。HumanEval任务取得了明显更大的加速。作者推测这是由以下几个因素共同导致的:
1. 代码的重复性: 代码中包含大量常见的子序列(例如 for i in range(len(arr)):),这些序列对于草稿模型来说相对容易猜测。
2. Token的特性: 代码通常被分解为一组更短的token。
3. 温度参数的影响: 较高的温度值会使草稿模型和目标模型的logits分布都变得更尖锐,从而增强了两者之间的一致性。

更长草稿与更频繁评分之间的权衡

作者通过可视化分析了增加草稿token数量K所带来的权衡(见图1)。
* 理论优势与实际限制: 增加K意味着生成相同长度序列所需的目標模型评分调用次数减少,可能带来更大的加速。然而,总循环时间会因草稿模型调用次数的增加和评分时间的微小增加而近似线性增长。
* 接受效率下降: 随着K的增加,接受token的总体比例效率会下降,因为后面的token是否被接受依赖于前面token的接受情况。
* 实际加速效果: 这导致平均加速比会趋于平稳甚至下降。例如,在XSum任务中使用核采样时,当$K=3$时延迟最低。
* 延迟方差: 即使在某些情况下较大的K值能带来稍高的平均加速,它也会增加生成完整序列时间的方差。这对于关注P90、P99延迟的场景可能是有问题的。


图 1 | 左:生成128个token的平均时间及标准差。注意,随着K的增加,总体加速效果趋于平稳甚至倒退,XSum在K=3时达到最优。方差随K的增加而持续增大。中:接受的token平均数除以K+1——这是衡量改进的拒绝方案总体效率的指标,该效率随预读长度的增加而降低。右:由于模型调用次数增加,每个循环的平均时间随K近似线性增长。注意,斜率略高于草稿模型的采样速度,这是由于核采样中的额外开销所致。

A5 结论

本文展示了一种用于加速语言模型解码的新算法和工作流程。推测性采样不需要对目标语言模型的参数或架构进行任何修改,在数值精度范围内可证明是无损的,通过合适的草稿模型可以很好地扩展,并且补充了许多现有的在小批量设置下减少延迟的技术。

我们将该技术优化并扩展到Chinchilla 70B模型,使用了一个易于通过现有基础设施训练的草稿模型,在此过程中证明了它在基准任务和常见解码方法上均能产生巨大的加速。我们通过下游任务经验性地验证了它确实是无损的。

A6 附录

超参数

Table 2 | 草稿模型的超参数

证明

定理1(改进的拒绝采样恢复了目标分布)。给定离散分布p、q和单个草稿样本 $\tilde{x} \sim p$,令X为最终得到的样本。要使 $X=x$ 成立,我们必须要么采样到 $\tilde{x} = x$ 然后接受它,要么在 $\tilde{x}$(任何值)被拒绝后重新采样得到它。因此:

$$ \mathbb{P}(X=x) $$

$$ =\mathbb{P}(\tilde{x}=x) \mathbb{P}(\tilde{x}\ accepted \mid \tilde{x}=x)+\mathbb{P}(\tilde{x}\ rejected) \mathbb{P}(X=x \mid \tilde{x}\ rejected) $$

对于第一项,我们应用接受规则:

$$\begin{aligned} \begin{aligned} &\mathbb{P}(\tilde{x}=x) \mathbb{P}(\tilde{x} \text { accepted } \mid \tilde{x}=x) \\ &=p(x) \min \left(1, \frac{q(x)}{p(x)}\right) \end{aligned} \end{aligned}$$ $$= \min (p(x), q(x))$$

对于第二个条件项,我们应用重采样规则:

$$\mathbb{P}(X=x \mid \tilde{x} \text { rejected })=(q(x)-p(x))_{+}$$

其中 $(.)_+$ 表示:

$$(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}$$

最后,我们计算拒绝的概率:

$$\begin{aligned} \begin{aligned} \mathbb{P}(\tilde{x} \text { rejected }) & =1-\mathbb{P}(\tilde{x} \text { accepted }) \\ & =1-\sum_{x^{\prime}} \mathbb{P}\left(X=x^{\prime}, \tilde{x} \text { accepted }\right) \\ & =1-\sum_{x^{\prime}} \min \left(p\left(x^{\prime}\right), q\left(x^{\prime}\right)\right) \\ & =\sum_{x^{\prime}} \max \left(0, q\left(x^{\prime}\right)-p\left(x^{\prime}\right)\right) \\ & =\sum_{x^{\prime}} q\left(x^{\prime}\right)-\min \left(p\left(x^{\prime}\right), q\left(x^{\prime}\right)\right) \\ & =\sum_{x^{\prime}} \max \left(0, q\left(x^{\prime}\right)-p\left(x^{\prime}\right)\right) \end{aligned} \end{aligned}$$

这等于 $(q(x) - p(x))_+$ 的分母,所以:

$$\mathbb{P}(\tilde{x} \ rejected)\mathbb{P}(X=x|\tilde{x} \ rejected) = \max(0, q(x) - p(x))$$

因此:

$$\begin{aligned} \begin{aligned} \mathbb{P}(X=x) \\ = \min(p(x), q(x)) &+ \max(0, q(x) - p(x)) \\ = q(x) \end{aligned} \end{aligned}$$

我们就恢复了所需的目标分布。