文章标题:深度自回归模型的分块并行解码
作者/机构:Mitchell Stern (加州大学伯克利分校), Noam Shazeer (Google Brain), Jakob Uszkoreit (Google Brain)
核心问题:
深度自回归序列到序列模型,如Transformer和卷积序列到序列模型,虽然在训练时利用并行计算取得了显著的速度提升,但在推理(生成)时仍然是一个固有的顺序过程,即一次生成一个令牌(token)。这对于许多实际应用构成了巨大的挑战。
研究目标:
本文旨在提出一种新颖的解码方案,以克服自回归模型生成过程的顺序性限制,从而在不牺牲或仅牺牲少量模型质量的情况下,显著提升生成速度。
创新点/主要贡献:
本文提出了一种新颖的分块并行解码(blockwise parallel decoding)方案,其核心思想如下:
1. 并行预测:除了基础的自回归模型外,额外训练一组辅助模型,用于并行地预测未来多个时间步的令牌。在解码的每一步,这些模型会并行地提出一个包含k个候选令牌的块。
2. 并行验证:利用基础模型能够并行处理输出序列的特性(如Transformer),对提出的k个候选令牌进行并行打分,以确定其中最长的、与标准贪心解码结果一致的前缀。
3. 前缀接受:接受这个经过验证的最长前缀,从而在一次解码迭代中前进多个时间步,减少了总的解码迭代次数。
核心成果:
- 该方法在不损失模型质量的情况下,可将生成速度相对于贪心解码提升约2倍。
- 结合知识蒸馏和近似解码策略,可以在机器翻译任务中实现高达5倍的解码迭代次数减少(伴随轻微性能下降),在图像超分辨率任务中达到7倍的减少。
- 在实际墙钟时间(wall-clock time)上,最快的模型相比标准贪心解码实现了高达4倍的加速。
- 该方法的一大优势是可以在现有模型之上以最小的修改实现,并且代码已在Tensor2Tensor库中开源。
序列到序列问题定义。在序列到序列问题中,给定一个输入序列 $x = (x_1, \dots, x_n)$,目标是预测相应的输出序列 $y = (y_1, \dots, y_m)$。这些序列可以是机器翻译中的源语言和目标语言句子,也可以是图像超分辨率中的低分辨率和高分辨率图像。解决该问题的一个常用方法是学习一个自回归评分模型 $p(y | x)$,该模型根据从左到右的分解方式进行建模:
推理问题则是在此基础上找到最优的 $y^* = \text{argmax}_y p(y | x)$。
贪心解码作为近似搜索。由于输出空间是指数级巨大的,精确搜索是不可行的。作为一种近似方法,我们可以执行贪心解码来获得一个预测结果 $\hat{y}$。该过程从一个空序列 $\hat{y}$ 和计数器 $j=0$ 开始,重复地用得分最高的令牌 $\hat{y}_{j+1} = \text{argmax}_y p(y_{j+1} | \hat{y}_{\le j}, x)$ 来扩展预测序列,并更新 $j \leftarrow j + 1$,直到满足终止条件。对于语言生成问题,通常在生成一个特殊的序列结束令牌后停止。对于图像生成问题,则简单地解码一个固定的步数。
利用并行性加速解码。标准贪心解码为生成一个长度为 $m$ 的输出需要 $m$ 个步骤,即使模型本身能用恒定的顺序操作次数高效地为序列评分。虽然当词汇表很大时,对长度大于1的输出扩展进行暴力枚举是不可行的,但我们仍然可以尝试通过训练一组辅助模型来提出候选扩展,从而利用模型内部的并行性。
分块并行解码算法。假设原始模型为 $p_1 = p$,并且我们还学习了一组辅助模型 $p_2, \dots, p_k$,其中 $p_i(y_{j+i} | y_{\le j}, x)$ 是给定前 $j$ 个令牌时,第 $(j + i)$ 个令牌为 $y_{j+i}$ 的概率。我们提出以下分块并行解码算法(如图1所示),该算法保证产生与贪心解码相同的预测 $\hat{y}$,但最少仅需 $m/k$ 步。像之前一样,我们从一个空预测 $\hat{y}$ 开始,并设置 $j = 0$。然后重复以下三个子步骤直到满足终止条件:
* 预测(Predict):并行获取块预测 $\hat{y}_{j+i} = \text{argmax}_y p_i(y_{j+i} | \hat{y}_{\le j}, x)$,其中 $i = 1, \dots, k$。
* 验证(Verify):找到最大的 $\hat{k}$,使得对于所有 $1 \le i \le \hat{k}$,都有 $\hat{y}_{j+i} = \text{argmax}_{y_{j+i}} p_1(y_{j+i} | \hat{y}_{\le j+i-1}, x)$ 成立。请注意,根据 $\hat{y}_{j+1}$ 的定义,$\hat{k} \ge 1$ 总是成立的。
* 接受(Accept):用 $\hat{y}_{j+1}, \dots, \hat{y}_{j+\hat{k}}$ 扩展 $\hat{y}$,并设置 $j \leftarrow j + \hat{k}$。
解码子步骤详解。在预测子步骤中,我们找到基础评分模型 $p_1$ 和辅助提议模型 $p_2, \dots, p_k$ 的局部贪心预测。由于这些模型是分离的,每个预测都可以并行计算,因此与单个贪心预测相比,时间损失应该很小。接下来,在验证子步骤中,我们找到提出的长度为 $k$ 的扩展中,原本会被 $p_1$ 产生的最长前缀。如果评分模型能用少于 $k$ 步来处理这 $k$ 个令牌的序列,并且只要有一个以上的令牌是正确的,这个子步骤将有助于节省总体时间。最后,在接受子步骤中,我们用已验证的前缀扩展我们的假设。通过在基础模型和提议模型的预测开始出现分歧时及早停止,我们确保能恢复出与使用 $p_1$ 进行贪心解码所产生的相同输出。
对模型并行能力的要求。该方案提升解码性能的潜力,关键取决于基础模型 $p_1$ 是否有能力并行执行验证子步骤中的所有预测。在我们的实验中,我们使用了Transformer模型【14, Attention is all you need, NIPS 2017, http://arxiv.org/abs/1706.03762】。虽然解码期间执行的总操作数与预测数量呈二次方关系,但必需的顺序操作数是恒定的,与输出长度无关。这使得我们能够并行地对多个位置执行验证子步骤,而无需增加额外的墙钟时间。
减少模型调用次数的动机。当使用Transformer进行评分时,第3节中提出的算法版本每一步需要两次模型调用:一次是在预测子步骤中并行调用 $p_1, \dots, p_k$,另一次是在验证子步骤中调用 $p_1$。这意味着即使有完美的辅助模型,我们也只能将模型调用次数从 $m$ 次减少到 $2m/k$ 次,而不是期望的 $m/k$ 次。
合并验证与预测步骤。事实证明,如果我们假设一个组合的评分和提议模型,我们可以进一步将模型调用次数从 $2m/k$ 减少到 $m/k + 1$,在这种情况下,第 $n$ 次的验证子步骤可以与第 $(n + 1)$ 次的预测子步骤合并。
组合模型实现细节。具体来说,假设我们有一个单一的Transformer模型,它在验证子步骤中,能在常数数量的操作内计算出所有 $i=1, \dots, k$ 和 $i_0=1, \dots, k$ 的 $p_i(y_{j+i_0+i} | \hat{y}_{\le j+i_0}, x)$。例如,这可以通过将最终投影层的维度增加 $k$ 倍,并在每个位置计算 $k$ 个独立的softmax来实现。在预测子步骤中将 $k$ 个未来预测输入模型后调用该模型,即可得到期望的输出。
合并后的解码流程。在这种设置下,当验证过程中计算出 $\hat{k}$ 后,我们实际上已经计算出了所有 $i=1, \dots, k$ 的 $p_i(y_{j+\hat{k}+i} | y_{\le j+\hat{k}}, x)$,这正是下一次解码迭代中预测子步骤所需要的内容。因此,这些子步骤可以合并在一起,除了第一次迭代外,所有后续迭代的模型调用次数都减少了一半。图2展示了这个过程。值得注意的是,尽管在验证子步骤中必须为每个位置计算提议,但所有的预测仍然可以并行进行。
放宽验证标准。到目前为止,我们所描述的分块并行解码方法产生的输出与标准贪心解码完全相同。通过放宽验证过程中使用的标准,我们可以以可能偏离贪心输出为代价,换取额外的加速。
Top-k验证标准。我们不必强求预测与评分模型的预测完全匹配,而是可以要求它位于前 $k$ 个候选项之内。为实现这一点,我们将验证标准替换为:
基于距离的验证标准。在输出空间存在自然距离度量 $d$ 的问题中,我们可以用近似匹配来代替与最高分元素进行精确匹配:
在图像生成的情况下,我们令 $d(u, v) = |u-v|$ 为给定颜色通道内强度 $u$ 和 $v$ 之间的绝对差。
强制最小接受长度。在某个步骤中,第一个非贪心预测有可能是错误的,这种情况下只有一个令牌会被添加到假设中。为了确保最小的加速效果,我们可以要求在每个解码步骤中至少添加 $1 < \grave{} \le k$ 个令牌。设置 $\grave{} = k$ 将对应于使用固定大小为 $k$ 的块进行并行解码。
模型架构修改。我们在实验中实现了第4节描述的组合评分和提议模型。给定一个针对特定任务预训练好的基线Transformer模型,我们在解码器输出层和最终投影层之间插入一个单层前馈网络。该前馈层的隐藏层大小为 $k \times d_{\text{hidden}}$,输出大小为 $k \times d_{\text{model}}$,其中 $d_{\text{hidden}}$ 和 $d_{\text{model}}$ 与网络其余部分使用的层维度相同。输入和 $k$ 个输出中的每一个之间都包含一个残差连接。原始的投影层被同样地应用于 $k$ 个输出中的每一个,以获得 $p_1, \dots, p_k$ 的logits。具体见图3的说明。
训练策略。由于训练时的内存限制,我们无法使用对应于 $p_1, \dots, p_k$ 的 $k$ 个交叉熵损失的均值作为总损失。取而代之的是,我们为每个minibatch随机均匀地选择其中一个子损失,以获得对完整损失的无偏估计。在推理时,所有的logits都可以并行计算,相对于基础模型的边际成本很小。
微调策略的权衡。一个重要的问题是,是否应该为修改后的联合预测任务微调预训练模型的原始参数。如果保持它们冻结,我们能确保原始模型的质量得以保留,但可能会以未来预测精度较低为代价。如果对它们进行微调,我们可能会提高模型的内部一致性,但最终性能可能会有所损失。我们在实验中对这两种选择都进行了研究。
知识蒸馏的应用。知识蒸馏(Knowledge distillation)【4, Geoffrey Hinton et al., Distilling the knowledge in a neural network, arXiv 2015】【7, Yoon Kim and Alexander M Rush, Sequence-level knowledge distillation, arXiv 2016】是一种训练模型学习另一个模型输出的做法,已被证明可以在多种任务上提高性能,甚至在教师和学生模型具有相同架构和模型大小时也可能有效【1, Tommaso Furlanello et al., Born again neural networks, NIPS Workshop 2017】。我们假设序列级别的蒸馏对于分块并行解码可能特别有用,因为它往往会产生一个更具可预测性的训练集,这是由于教师模型打破模式的一致性造成的。对于我们的语言任务,我们使用原始训练数据和蒸馏训练数据进行了实验,以确定其影响程度。蒸馏数据是通过使用一个与基线模型具有相同超参数但随机种子不同的预训练模型进行波束解码(beam decoding)生成的。波束搜索的超参数与【14, Ashish Vaswani et al., Attention is all you need, CoRR 2017】中的相同。
transformer_base超参数集的Transformer。img2img_transformer_b3超参数集。k,训练了多个组合评分和提议模型。训练数据分为原始数据和蒸馏数据两种。参数更新策略分为冻结基线参数和微调所有参数两种。k̂(速度)。k值,性能有所下降。使用蒸馏数据可以减轻性能下降的幅度。平均块大小最高的模型(4.95)其BLEU分数仅比在蒸馏数据上训练的初始模型低0.81。k=2或k=3)会导致BLEU分数大幅下降,而平均块大小改善甚微,表明偶尔只接受一个token的灵活性很重要。k值,在冻结参数和微调参数两种设置下训练组合模型。k=10时,平均接受块大小达到了6.79,意味着解码迭代次数减少了近7倍。k=6)和近似解码的模型。这可能是因为更难的训练任务和近似接受标准引入了轻微的噪声和变化,使图像看起来比基线模型平滑的输出更自然。k无限增长。对于翻译任务,k=8时达到峰值3.3倍加速(对应4.7倍迭代减少);对于超分任务,k=6时达到峰值4.0倍加速(对应5.3倍迭代减少)。更大的k值虽然能进一步减少迭代次数,但其更高的计算成本导致墙钟时间反而增加。k=10近似解码输出的对比图。从视觉上看,分块并行解码的输出质量与标准贪心解码相当。本文提出了一种名为分块并行解码的简单通用技术,用于提升那些架构允许跨输出位置并行评分的深度自回归模型的解码性能。该技术相对容易地可以被添加到现有模型中。实验证明,在机器翻译和条件图像生成任务上,该方法在不损失或仅有微小质量损失的情况下,显著提升了解码速度。
未来的工作计划是研究将该技术与可能正交的方法相结合,例如基于离散潜变量序列的方法【5, Łukasz Kaiser et al., Fast decoding in sequence models using discrete latent variables, arXiv 2018】。