On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes

A1 主要贡献

本文针对自回归序列模型在知识蒸馏(Knowledge Distillation, KD)过程中存在的训练与推理阶段输出序列分布不匹配的问题,提出了一种名为广义知识蒸馏(Generalized Knowledge Distillation, GKD)的新方法。

实验结果表明,与基线KD方法相比,在策略GKD在不同规模的T5学生模型上取得了显著的相对性能提升:摘要任务提升2.1倍,机器翻译任务提升1.7倍,算术推理任务提升1.9倍。

图1:比较GKD与不同学生模型规模下的KD方法。我们使用经过监督FT训练的T5模型作为学生模型。我们使用监督FT的T5-XL(约3B参数)作为教师模型,其性能由水平线表示。监督KD和FT使用真实的输出序列进行训练,而SeqKD则在教师模型生成的输出序列上进行训练。在策略GKD在从学生模型采样的输出序列上进行训练。对于GKD,我们在WMT上使用JSD(0.1),在其他任务上使用前向KL散度。评估时,XSum和GSM8K使用贪婪采样,WMT使用波束搜索。
图1:比较GKD与不同学生模型规模下的KD方法。我们使用经过监督FT训练的T5模型作为学生模型。我们使用监督FT的T5-XL(约3B参数)作为教师模型,其性能由水平线表示。监督KD和FT使用真实的输出序列进行训练,而SeqKD则在教师模型生成的输出序列上进行训练。在策略GKD在从学生模型采样的输出序列上进行训练。对于GKD,我们在WMT上使用JSD(0.1),在其他任务上使用前向KL散度。评估时,XSum和GSM8K使用贪婪采样,WMT使用波束搜索。

A3 背景知识

自回归生成序列模型

我们将输入和输出序列分别表示为x和y。令V表示包含M个词元(token)的词汇表,$y_{<n+1} = (y_1, y_2, ..., y_n)$表示到第n个词元为止已生成的输出序列,$L_y$表示序列y的长度。一个词元级的自回归策略$p(\cdot|y_{<n}, x) \in (0, 1)^M$会输出一个在V中所有词元上的下一词元概率分布,该分布以输入x和输出序列$y_{<n}$为条件。此外,$y \sim p(\cdot|x)$对应于给定输入x采样得到的输出序列y。为简化符号,我们定义$p(y_n|x) := p(y_n|y_{<n}, x)$。自回归生成涉及逐个预测词元,每个预测都基于先前生成的词元。预测第n个词元$y_n$的概率$p(y_n|x)$通过带有温度$\gamma$的softmax函数确定:$p(y_n|x) = \frac{\exp(z_n/\gamma)}{\sum_{i=1}^{M} \exp(z_i/\gamma)}$,其中$z_n$是词元$y_n$的logit分数。较高的$\gamma$值会引入更多随机性,而较低的值则通过偏爱最可能的词语使输出更具确定性。在训练期间,学生的温度保持为1。评估时,我们使用贪婪采样($\gamma \to 0$)或温度采样($\gamma > 0$)。

基于KL的散度

两个概率分布之间的散度是衡量分布相似性的度量,其中KL散度是一种常用的度量。两个离散分布P(C)和Q(C)之间的KL散度由以下公式给出:$D_{KL}(P \| Q) = \sum_{c \in C} P(c) \log \frac{P(c)}{Q(c)}$。KL散度是不对称的:$D_{KL}(P \| Q) \neq D_{KL}(Q \| P)$。因此,我们称$D_{KL}(P \| Q)$为P和Q之间的前向KL散度,而$D_{KL}(Q \| P)$为反向KL散度。在经验数据分布下的前向KL散度对应于最大似然估计,这是我们在监督学习中优化的目标。在模型容量不匹配的情况下,当使用分布$Q_\theta(C)$来近似$P(C)$时,最小化反向和前向KL散度分别会导致“模式寻求”(mode-seeking)和“模式覆盖”(mode-covering)的行为(如图A.16所示)。

广义JSD

虽然KL散度可能是无界的,但一种著名的即使对于支撑集不相交的概率分布也是有界的散度是广义JSD(Jensen-Shannon散度)。JSD($\beta$)使用有界系数$0 < \beta < 1$在前向和反向KL散度之间进行插值:

$$\mathcal{D}_{JSD(\beta)}(P \| Q)=\beta \mathcal{D}_{KL}(P \| \beta P+(1-\beta) Q)+(1-\beta) \mathcal{D}_{KL}(Q \| \beta P+(1-\beta) Q)$$

Huszár(2015)【15, How (not) to train your generative model: Scheduled sampling, likelihood, adversary? by Ferenc Huszár, 2015, arXiv preprint】表明,$\lim_{\beta \to 0} D_{JSD(\beta)}(P \| Q) / \beta = D_{KL}(P \| Q)$。因此,当$\beta$分别接近0和1时,JSD($\beta$)的梯度行为分别类似于前向KL和反向KL。

A2 方法细节

3. 蒸馏自回归序列模型

问题设置。我们有两个不同容量的自回归序列模型,分别用$p_S$(学生)和$p_T$(教师)表示。我们假设学生模型具有可学习的参数$\theta$,并且$p_\theta^S$对$\theta$是可微的。我们还有一个输入数据集X。可选地,我们也可以假设可以访问一个输入-输出序列对的数据集(X, Y)。如果没有提供,这样的数据集可以通过从教师模型中采样序列来生成。对于一个散度D,我们将$p_T$和$p_S$的词元级分布之间的差异定义为:

$$\mathcal{D}(p_{\text{T}}\|p_{\text{S}}^{\theta})(y|x) := \frac{1}{L_{y}} \sum_{n=1}^{L_{y}} \mathcal{D}(p_{\text{T}}(\cdot|y_{<n}, x) \| p_{\text{S}}^{\theta}(\cdot|y_{<n}, x)),$$ <p>这是一个针对输入x和输出序列y的定义。例如,在公式2中使用JSD($\beta$)作为D,会得到$D_{JSD(\beta)}(p_T \| p_\theta^S)(y|x) = \frac{1}{L_y} \sum_n D_{JSD(\beta)}(p_T(\cdot|y_{&lt;n}, x) \| p_\theta^S(\cdot|y_{&lt;n}, x))$。

监督微调(Supervised FT)。如果我们只有一个固定的真实输出序列数据集,但无法查询教师策略,那么一个简单的方法是最小化这些序列在学生策略下的负对数似然:$L_{SFT}(\theta) = \mathbb{E}_{(x,y) \sim (X,Y)}[-\log p_\theta^S(y|x)]$。

序列级知识蒸馏(Sequence-Level KD)。由Kim和Rush(2016)【19, Sequence-level knowledge distillation by Yoon Kim and Alexander M. Rush, 2016, arXiv preprint】提出,SeqKD最大化由教师模型生成的高概率序列的似然,可以看作是在教师生成的输出上进行监督微调。

监督知识蒸馏(Supervised KD)。由Hinton等人(2015)【12, Distilling the knowledge in a neural network by Geoffrey Hinton, Oriol Vinyals, and Jeff Dean, 2015, arXiv preprint】和Sanh等人(2019)【39, Distilbert, a distilled version of bert: smaller, faster, cheaper and lighter by Victor Sanh et al., 2019, arXiv preprint】推广,这是一种广泛使用的技术,其中学生模型被训练来模仿教师模型的词元级概率分布。学生模型$p_S$通过监督目标$L_{SD}$进行训练,该目标针对教师模型$p_T$的目标词元级概率:

$$L_{SD}(\theta) := \mathbb{E}_{(x,y)\sim(X,Y)} \left[ \mathcal{D}_{KL} (p_{\text{T}} \Vert p_{\text{S}}^{\theta})(y|x) \right],$$

其中期望是在数据集的样本上计算的。这种监督目标通过利用教师模型的完整词元级分布,提供了丰富的训练信号。

3.1 广义知识蒸馏(GKD)

训练-推理分布不匹配问题。如上所述,常用的KD方法使用固定的输出序列数据集,无论是真实目标还是教师生成的序列。然而,使用这些方法蒸馏自回归学生模型会导致训练-推理分布不匹配。这是因为学生在推理时自回归生成阶段遇到的部分序列,可能与训练阶段看到的序列大相径庭。

在策略模仿学习的应用。由于在自回归模型中,任何步骤的预测都取决于之前的步骤,这种不匹配会产生级联效应,即早期步骤的预测错误会影响未来的预测,从而导致文本生成质量差。为了解决这种不匹配,我们大量借鉴了模仿学习(IL)。特别是,在策略(on-policy)模仿方法(例如Ross等人,2011【38, A reduction of imitation learning and structured ´ prediction to no-regret online learning by Stephane Ross, Geoffrey Gordon, and Drew Bagnell, 2011, JMLR Workshop and Conference Proceedings】)会迭代地使用学生策略收集序列,获取这些序列的专家标签,然后在这个数据集上重新训练学生。尽管在机器人学和深度强化学习中很受欢迎(Parisotto等人,2015【31, Actor-mimic: Deep multitask and transfer reinforcement learning by Emilio Parisotto, Jimmy Lei Ba, and Ruslan Salakhutdinov, 2015, arXiv preprint】;Kelly等人,2019【18, Hg-dagger: Interactive imitation learning with human experts by Michael Kelly et al., 2019, ICRA】;Agarwal等人,2022【1, Reincarnating reinforcement learning: Reusing prior computation to accelerate progress by Rishabh Agarwal et al., 2022, Advances in Neural Information Processing Systems】),但在策略方法通常不用于蒸馏自回归模型。

在策略蒸馏的反馈机制。将策略性模仿学习扩展到蒸馏领域,我们提出了在策略知识蒸馏。在蒸馏过程中使用策略性数据时,学生模型会从教师模型的logits中获得针对其自生成输出序列中错误词元的具体反馈。这形成了一种类似于我们在强化学习中观察到的反馈循环,有助于最小化训练-推理分布不匹配。此外,随着学生在训练过程中的演进,它生成的数据质量也会提高。给定一个输入x,学生生成输出序列y,并模仿教师在中间状态$y_{&lt;n}$上的词元级分布$p_T(y_n|x)$。具体而言,在策略损失$L_{OD}$由以下公式给出:

$$L_{OD}(\theta) := \mathbb{E}_{x \sim X} \left[ \mathbb{E}_{y \sim p_{\text{S}}(\cdot | x)} \left[ \mathcal{D}_{KL}(p_{\text{T}} \| p_{\text{S}}^{\theta})(y | x) \right] \right],$$

其中,我们不通过学生的采样分布$p_S(\cdot|x)$进行反向传播,这与策略性模仿学习类似。不通过采样过程反向传播使得训练更加稳定和计算高效。在策略知识蒸馏中,训练是在学生可能生成的输出序列上进行的。训练期间,我们使用温度$\gamma = 1$来鼓励学生生成序列的多样性。此外,对于未标记的输入提示,由于模型大小的差异,使用学生模型生成序列比使用教师模型在计算上更便宜。

GKD的统一框架。在策略KD的基础上,我们统一了监督方法和在策略方法,并提出了一个更通用的方法,我们称之为广义KD(GKD)。在GKD中,我们可以选择优化的散度以及用于训练的输出序列。具体来说,我们可以优化教师和学生词元级概率分布之间的任何散度。对于输出序列,GKD使用固定数据集(教师生成或真实标签)和在策略学生生成序列的混合。抽象地说,GKD最小化一个形式如下的目标:

$$L_{\mathrm{GKD}}(\theta):=(1-\lambda) \mathbb{E}_{(x, y) \sim(X, Y)}\left[\mathcal{D}\left(p_{\mathrm{T}} \| p_{\mathrm{S}}^{\theta}\right)(y | x)\right]+\lambda \mathbb{E}_{x \sim X}\left[\mathbb{E}_{y \sim p_{\mathrm{S}}(\cdot | x)}\left[\mathcal{D}\left(p_{\mathrm{T}} \| p_{\mathrm{S}}^{\theta}\right)(y | x)\right]\right]$$

其中$D(p_T, p_S)(y|x)$是教师和学生分布之间的散度(公式2),$\lambda \in [0, 1]$是一个超参数,用于控制学生数据比例,即在策略学生生成输出的比例。与在策略KD类似,我们不通过学生的采样过程反向传播梯度。在策略KD和监督KD分别是GKD的实例,其中散度D分别设置为前向KL,学生数据比例$\lambda$分别为1和0。尽管如此,GKD允许$\lambda$和散度有其他选择,我们在这项工作中进行了探索。

备注:学生模型初始化。与随机初始化的学生模型不同,我们假设可以访问一个能够生成足够质量序列的学生模型,教师可以对其提供反馈。在我们的实验中,我们从经过监督微调(supervised FT)的学生模型开始。这类似于两阶段的RLHF训练,该方法广泛用于语言模型,即首先进行SFT,然后进行在线RL微调。因此,GKD可以利用RLHF的超参数调优经验,并且可以与RLHF结合,计算开销小,且无需额外超参数。

GKD中散度的选择。虽然前向KL散度常用于蒸馏,但它要求学生覆盖教师词元级分布$p_T(\cdot|y_{&lt;n}, x)$的整个支撑集。这样做可能会导致学生为在$p_T(\cdot|y_{&lt;n}, x)$下概率较低的词元v分配概率质量,从而可能导致幻觉和低质量的生成。当学生的模型容量远低于教师时,使用温度采样时很可能会出现这个问题(例如,图A.16)。或者,模式寻求型散度,如反向KL散度,会优先考虑教师分配高概率的词元,这可以避免低质量的生成,但代价是对于给定输入生成的样本多样性较低。我们的实验表明,最优散度似乎是任务依赖的。总的来说,在选择GKD散度时,需要考虑特定任务的多样性和性能权衡(例如,图4、10)。

3.2 强化学习微调 + 在策略GKD

结合RL的目标。在某些任务中,从教师模型进行蒸馏可能只是我们主要目标的一个代理,而这个主要目标也可能是不可微的。我们可以使用强化学习(RL)直接优化这个目标。方便的是,在策略GKD可以很容易地与来自人类(RLHF)或AI反馈(RLAIF)的RL微调相结合,因为它只需要学生的输出样本。

正则化RL微调目标。实际上,如果我们希望为一个标量奖励r优化学生策略,同时保持其与教师策略的接近度,我们会得到一个形式如下的正则化RL微调目标:

$$\mathbb{E}_{x \sim X}\Big[(1-\alpha) \underbrace{E_{y \sim p_{\mathrm{S}}^{\theta}(\cdot \mid x)}[r(y)]}_{\text {RL objective }}-\alpha \underbrace{\mathbb{E}_{y \sim p_{\mathrm{S}}(\cdot \mid x)}\left[\mathcal{D}\left(p_{\mathrm{T}} \| p_{\mathrm{S}}^{\theta}\right)(y \mid x)\right]}_{\text {Generalized On-Policy Distillation }}\Big],$$

其中$\alpha \in [0, 1]$控制蒸馏损失相对于RL目标的强度。当$\alpha = 1$时,它将只执行蒸馏。上述目标使我们能够最大化奖励,同时通过蒸馏提升其他模型能力,这可能减少在将语言模型与人类偏好对齐时通用模型能力的下降,即“对齐税”(Ouyang等人,2022【30, Training language models to follow instructions with human feedback by Long Ouyang et al., 2022, Advances in Neural Information Processing Systems】)。我们将上述思想应用于使用RLAIF减轻幻觉,同时通过蒸馏提高下游性能(图5)。

备注:与现有RLHF流程的集成。在RLHF或RLAIF中,我们通常使用反向KL散度来约束学习到的策略接近初始策略。如果只想对现有的RL微调工作流程进行微小修改,我们建议在将GKD与RL集成时使用反向KL或JSD(0.9)。

A4 实验环境

A4 实验结果

4.1 案例研究:摘要生成 (XSum)

图2:在XSum上比较GKD与基线方法,从T5-XL蒸馏到T5-large。在策略GKD变体通常优于基线方法。
图2:在XSum上比较GKD与基线方法,从T5-XL蒸馏到T5-large。在策略GKD变体通常优于基线方法。
图3:扩展训练数据。我们使用温度采样(γ = 1)评估蒸馏后的T5-small。GKD比基线方法数据效率更高。
图3:扩展训练数据。我们使用温度采样(γ = 1)评估蒸馏后的T5-small。GKD比基线方法数据效率更高。
图4:散度对性能和多样性的影响。利用不同散度的在策略GKD,我们通过改变采样温度来评估蒸馏后学生模型的生成质量和多样性之间的权衡。我们使用Self-BLEU量化多样性,其中100分表示确定性输出,0分表示最大多样性。从前向KL散度过渡到反向KL散度,通过广义JSD,导致多样性降低,这归因于散度增强的模式寻求特性。模式寻求散度通常能产生更优的质量,尤其是在高温(γ = 1)下。降低温度会限制多样性,同时缩小不同散度之间的性能差异。
图4:散度对性能和多样性的影响。利用不同散度的在策略GKD,我们通过改变采样温度来评估蒸馏后学生模型的生成质量和多样性之间的权衡。我们使用Self-BLEU量化多样性,其中100分表示确定性输出,0分表示最大多样性。从前向KL散度过渡到反向KL散度,通过广义JSD,导致多样性降低,这归因于散度增强的模式寻求特性。模式寻求散度通常能产生更优的质量,尤其是在高温(γ = 1)下。降低温度会限制多样性,同时缩小不同散度之间的性能差异。
图5:RLAIF + 在策略GKD。我们展示了在XSum上奖励最大化和摘要性能之间的权衡。我们报告了相对于原始T5-base学生的改进。我们使用T5-XXL NLI分类器的文本蕴含分数作为奖励。α控制了带有JSD (0.9)的在策略GKD损失的强度。随着α增加,ROUGE-2增加,而事实一致性的改善减少。作为比较,我们展示了12倍大的T5-XL教师的相对性能。RLEF*对应于Roit等人(2023)的RLAIF方法,其中学生模型被正则化到原始学生模型本身而不是教师模型。在策略GKD + RL与RLEF*相比实现了更高的ROUGE-2,同时生成了比教师模型更具事实一致性的摘要。
图5:RLAIF + 在策略GKD。我们展示了在XSum上奖励最大化和摘要性能之间的权衡。我们报告了相对于原始T5-base学生的改进。我们使用T5-XXL NLI分类器的文本蕴含分数作为奖励。α控制了带有JSD (0.9)的在策略GKD损失的强度。随着α增加,ROUGE-2增加,而事实一致性的改善减少。作为比较,我们展示了12倍大的T5-XL教师的相对性能。RLEF*对应于Roit等人(2023)的RLAIF方法,其中学生模型被正则化到原始学生模型本身而不是教师模型。在策略GKD + RL与RLEF*相比实现了更高的ROUGE-2,同时生成了比教师模型更具事实一致性的摘要。

4.2 机器翻译 (WMT)

图6:在WMT en→de任务上改变GKD中的学生数据比例和散度。评估时,我们使用波束搜索并报告蒸馏后学生相对于原始学生BLEU分数的提升。结果是三次随机种子实验的平均值。我们观察到,仅使用学生生成的输出样本优于其他GKD变体。我们使用在WMT上监督微调的T5-XL(约3B参数)作为教师,其BLEU分数为28。(左)我们使用T5-small(77M参数)作为学生,其BLEU分数为25.58。(右)学生为T5-base(250M参数),其BLEU分数为26.98。
图6:在WMT en→de任务上改变GKD中的学生数据比例和散度。评估时,我们使用波束搜索并报告蒸馏后学生相对于原始学生BLEU分数的提升。结果是三次随机种子实验的平均值。我们观察到,仅使用学生生成的输出样本优于其他GKD变体。我们使用在WMT上监督微调的T5-XL(约3B参数)作为教师,其BLEU分数为28。(左)我们使用T5-small(77M参数)作为学生,其BLEU分数为25.58。(右)学生为T5-base(250M参数),其BLEU分数为26.98。

4.3 算术推理 (GSM8K)

图7:在GSM8K上对GKD进行消融实验。我们从微调后的T5-XL蒸馏到T5-Base,它们在使用贪心采样时分别获得了27.9和10.16的准确率。
图7:在GSM8K上对GKD进行消融实验。我们从微调后的T5-XL蒸馏到T5-Base,它们在使用贪心采样时分别获得了27.9和10.16的准确率。
图8:在GSM8K上改变在策略数据的比例。当我们将学生生成数据的比例增加到25%以上时,性能通常会提高。
图8:在GSM8K上改变在策略数据的比例。当我们将学生生成数据的比例增加到25%以上时,性能通常会提高。
图9:在GSM8K上使用少样本CoT提示进行蒸馏。在策略GKD显著优于其他方法。作为参考,我们还提供了GPT-3 davinci-002以及PaLM(540B)的结果(不使用计算器)。我们分别对在策略和监督GKD使用前向KL和反向KL。
图9:在GSM8K上使用少样本CoT提示进行蒸馏。在策略GKD显著优于其他方法。作为参考,我们还提供了GPT-3 davinci-002以及PaLM(540B)的结果(不使用计算器)。我们分别对在策略和监督GKD使用前向KL和反向KL。

4.4 任务无关蒸馏:指令微调

图10:在FLAN上的任务无关蒸馏。使用反向KL的在策略GKD优于其他方法。MMLU和BBH基准套件的评估指标都是少样本提示准确率(精确匹配),我们对所有任务取未加权平均。这些评估基准是未参与训练的(不包含在蒸馏数据中)。这里,由于从教师模型生成数据的计算效率低下,我们没有运行SeqKD。教师模型FLAN T5-XL在MMLU和BBH上的准确率分别为52.4%和41%,而学生模型T5-large在MMLU和BBH上的准确率分别为35.6%和31.25%。
图10:在FLAN上的任务无关蒸馏。使用反向KL的在策略GKD优于其他方法。MMLU和BBH基准套件的评估指标都是少样本提示准确率(精确匹配),我们对所有任务取未加权平均。这些评估基准是未参与训练的(不包含在蒸馏数据中)。这里,由于从教师模型生成数据的计算效率低下,我们没有运行SeqKD。教师模型FLAN T5-XL在MMLU和BBH上的准确率分别为52.4%和41%,而学生模型T5-large在MMLU和BBH上的准确率分别为35.6%和31.25%。

A7 补充细节

知识蒸馏

监督知识蒸馏(Bucilua等人,2006【5, Model compression by Cristian Bucilua, Rich Caruana, and Alexandru Niculescu-Mizil, 2006, Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining】;Hinton等人,2015【12, Distilling the knowledge in a neural network by Geoffrey Hinton, Oriol Vinyals, and Jeff Dean, 2015, arXiv preprint】)是一种经典方法,并已成功用于蒸馏自回归模型(Sanh等人,2019【39, Distilbert, a distilled version of bert: smaller, faster, cheaper and lighter by Victor Sanh et al., 2019, arXiv preprint】)。另一种蒸馏此类模型的方法是序列级知识蒸馏(Kim & Rush, 2016【19, Sequence-level knowledge distillation by Yoon Kim and Alexander M. Rush, 2016, arXiv preprint】)。在策略GKD显著优于监督KD和SeqKD(图1)。其他KD方法训练学生匹配从教师那里获得的不同量,例如隐藏状态(Jiao等人,2020【16, Tinybert: Distilling bert for natural language understanding by Xiaoqi Jiao et al., 2020, Findings of the Association for Computational Linguistics: EMNLP 2020】)或注意力分数(Wang等人,2020【43, Minilm: Deep selfattention distillation for task-agnostic compression of pre-trained transformers by Wenhui Wang et al., 2020, Advances in Neural Information Processing Systems】)。然而,这些方法都没有建立蒸馏与模仿学习之间的联系,纯粹的监督方法可能会遭受训练-推理不匹配的影响,这也被称为暴露偏差(Ranzato等人,2015【35, Sequence level training with recurrent neural networks by Marc’Aurelio Ranzato et al., 2015, arXiv preprint】;Bengio等人,2015【3, Scheduled sampling for sequence prediction with recurrent neural networks by Samy Bengio et al., 2015, Advances in neural information processing systems】)。虽然He等人(2019)【11, Exposure bias versus selfrecovery: Are distortions really incremental for autoregressive text generation? by Tianxing He et al., 2019, arXiv preprint】认为这种不匹配可能不重要,但多篇论文证明暴露偏差会导致糟糕的文本生成(Zhang等人,2019【49, Bridging the gap between training and inference for neural machine translation by Wen Zhang et al., 2019, Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics】;Chiang & Chen,2021【6, Relating neural text degeneration to exposure bias by Ting-Rui Chiang and Yun-Nung Chen, 2021, arXiv preprint】;Arora等人,2022【2, Why exposure bias matters: An imitation learning perspective of error accumulation in language generation by Kushal Arora et al., 2022, arXiv preprint】)。

ImitKD与f-distill

ImitKD(Lin等人,2020【22, Autoregressive knowledge distillation through imitation learning by Alexander Lin, Jeremy Wohlwend, Howard Chen, and Tao Lei, 2020, arXiv preprint】)通过从学生和固定数据集中采样序列来识别这种联系,但没有进一步推进这个想法。与GKD不同,ImitKD没有探索纯粹的在策略数据收集,也没有整合RL微调。此外,ImitKD在词元级别保持前向KL,当可以访问教师的对数概率而不仅仅是样本时,这并非必要。此外,GKD展示了该思想的可扩展性,处理的学生模型比ImitKD探索的模型大约26倍。ImitKD可以被看作是使用前向KL和对λ采用非递增调度(一个简单的选择是λ=0.5)的GKD。最近,f-distill(Wen等人,2023【45, f-divergence minimization for sequence-level knowledge distillation by Yuqiao Wen, Zichao Li, Wenyu Du, and Lili Mou, 2023, arXiv preprint】)将序列级KD表述为最小化一个f-散度,并提出了一个基于词元级学生和教师分布之间总变差距离的可行目标。实质上,ImitKD和f-distill都是GKD的特定实例,我们证明它们在经验上导致的结果比在策略GKD差(图2,9)。

与MiniLLM的比较

同期的工作MiniLLM(Gu等人,2023【10, Knowledge distillation of large language models by Yuxian Gu, Li Dong, Furu Wei, and Minlie Huang, 2023, arXiv preprint】)也利用了与模仿学习的联系,并将蒸馏框架化为一个RL问题。特别是,MiniLLM使用策略梯度方法在序列级别上优化教师和学生之间的反向KL散度(而似然最大化是前向的)。然而,我们认为GKD更简单、更稳定,更接近监督训练,因为它不通过学生的采样过程进行反向传播。事实上,MiniLLM依赖于一些稳定技巧来解决高方差、奖励操纵和生成长度偏差问题。GKD也更通用,因为它也可以与前向KL或JSD等其他散度一起使用,这些散度可能比反向KL表现更好(图6,7)。

RL微调

现在有许多语言模型通过RL进行微调的例子,无论是为了优化某个指标的奖励(Wu等人,2018【46, A study of reinforcement learning for neural machine translation by Lijun Wu et al., 2018, arXiv preprint】),还是使用人类反馈学习(Ouyang等人,2022【30, Training language models to follow instructions with human feedback by Long Ouyang et al., 2022, Advances in Neural Information Processing Systems】)。在这些方法中,通常会将RL微调的模型正则化到初始(通常是监督微调的)模型。然而,据我们所知,我们是第一个同时进行蒸馏和RL微调的(图5)。这可能看起来很自然,但从优化的角度来看,这是非常不同的,因为它将对初始策略的正则化改变为对教师策略的正则化,我们通过经验证明这是一种可行的方法。

带推理轨迹的蒸馏

思维链提示(Nye等人,2021【29, Show your work: Scratchpads for intermediate computation with language models by Maxwell Nye et al., 2021, arXiv preprint】;Wei等人,2022【44, Chain of thought prompting elicits reasoning in large language models by Jason Wei et al., 2022, arXiv preprint】)最近证明了LLMs可以通过提示逐步解决复杂的推理任务。这个想法很快被应用于KD,通过为学生微调扩展教师数据集的CoT提示(Magister等人,2022【26, Teaching small language models to reason by Lucie Charlotte Magister et al., 2022, arXiv preprint】;Ho等人,2022【13, Large language models are reasoning teachers by Namgyu Ho, Laura Schmid, and Se-Young Yun, 2022, arXiv preprint】;Hsieh等人,2023【14, Distilling step-by-step! outperforming larger language models with less training data and smaller model sizes by Cheng-Yu Hsieh et al., 2023, arXiv preprint】)。蒸馏仍然是以监督方式进行的,并且可以考虑其他类型的增强提示(Li等人,2022【21, Explanations from large language models make small reasoners better by Shiyang Li et al., 2022, arXiv preprint】;Mukherjee等人,2023【27, Orca: Progressive learning from complex explanation traces of gpt-4 by Subhabrata Mukherjee et al., 2023, arXiv preprint】)。我们采用相同的方法,但将其与各种散度的在策略蒸馏相结合。这显示了GKD的多功能性,并改进了纯粹的监督方法,如我们在GSM8K上的结果所示(图9)。

在推测解码中的应用

Zhou等人(2023)【51, Distillspec: Improving speculative decoding via knowledge distillation by Yongchao Zhou et al., 2023, arXiv preprint】和Liu等人(2023)【24, Online speculative decoding by Xiaoxuan Liu et al., 2023, arXiv preprint】应用GKD来改善草稿模型和目标模型之间的一致性,以从推测解码中获得更好的推理加速。

A5 结论

本文提出了广义知识蒸馏(GKD)方法,以解决蒸馏自回归语言模型时出现的训练-推理分布不匹配问题。在摘要生成、机器翻译和算术推理这三个语言生成任务上,GKD的性能始终优于常用的知识蒸馏方法。我们进一步证明,GKD可以与强化学习相结合,以在蒸馏大型教师模型知识的同时优化序列级奖励。我们相信这可以改进目前广泛用于语言模型的RLHF训练阶段。未来一个有趣的研究方向是将GKD扩展到用于音频、视频和文本到图像生成的自回归序列模型。我们希望我们的工作能对致力于提升生成式自回归序列模型性能和效率的研究人员和实践者有所价值。

A6 附录

A.1 自蒸馏

自蒸馏。我们研究了GKD是否适用于自蒸馏(Yim等人,2017【47, A gift from knowledge distillation: Fast optimization, network minimization and transfer learning by Junho Yim et al., 2017, Proceedings of the IEEE conference on computer vision and pattern recognition】),即我们希望将知识从一个教师模型转移到一个具有相同架构和大小的学生模型。为了研究这一点,我们考虑在GSM8K上使用FLAN-T5 large作为学生和教师进行自蒸馏,其中教师模型在GSM8K上进行了监督微调。如图A.11所示,自蒸馏后的学生模型在测试集上的表现超过了教师。此外,使用学生生成数据进行蒸馏的效果优于监督KD,其中在策略GKD表现最佳。

图A.11:在GSM8K上的自蒸馏。这里,GKD对应于在策略GKD(λ = 1)。在策略GKD变体优于包括监督KD在内的其他方法。教师FLAN T5-Large在GSM8K上进行了监督微调,准确率为20.5%,而学生FLAN T5-large(未在GSM8K上训练)在测试集上的准确率为14.4%。
图A.11:在GSM8K上的自蒸馏。这里,GKD对应于在策略GKD(λ = 1)。在策略GKD变体优于包括监督KD在内的其他方法。教师FLAN T5-Large在GSM8K上进行了监督微调,准确率为20.5%,而学生FLAN T5-large(未在GSM8K上训练)在测试集上的准确率为14.4%。

A.2 T5 模型

基础检查点。我们从LM-adapted T5v1.1模型开始。这些LM-adapted模型从T5v1.1初始化,并在T5论文(Raffel等人,2020【34, Exploring the limits of transfer learning with a unified text-to-text transformer by Colin Raffel et al., 2020, The Journal of Machine Learning Research】)中讨论的LM目标上额外训练了10万步。这些检查点是开源的。

模型初始化。在我们的实验中,我们通过在原始训练数据集上进行进一步的监督微调来初始化用于蒸馏的学生和教师模型,具体如下:

优化器。与T5和FLAN-T5类似,我们的实验使用Adafactor优化器(Shazeer & Stern, 2018【40, Adafactor: Adaptive learning rates with sublinear memory cost by Noam Shazeer and Mitchell Stern, 2018, International Conference on Machine Learning】)。

GKD的计算成本。包括基线在内的所有方法都从监督微调的学生检查点开始,这需要在最小的TPUv3(8核)上训练几个小时。在GSM8K上,对于学生-教师大小比例为38倍、12倍和3.8倍的情况,学生采样的计算开销分别比从固定输出数据集中采样高出约1.8倍、2倍和2.2倍。对于RLHF + GKD,计算开销相对较小,因为我们只运行推理来获取教师的logits而不是学生的logits。此外,在实际应用中,大部分成本来自推理时的服务成本,而非微调。具体来说,如果在微调期间从学生模型采样成本太高,那么向用户提供该模型服务(用户量可能从数万到数十亿)的成本也可能过高。总的来说,在策略GKD带来的性能优势可能是值得付出计算成本的,特别是与RLHF结合时。

A.3 XSUM

学习率搜索。我们对学习率在{0.0001, 0.0003, 0.001}范围内进行了搜索,发现0.0003对T5-base和T5-large效果最好,而0.001对T5-small性能最佳。因此,我们默认使用0.0003的学习率,除非报告T5-small的结果时使用0.001。我们发现反向KL对较高的学习率更敏感,因此在使用反向KL时,我们对所有模型默认使用0.0003。

教师Softmax温度。当使用贪婪采样进行评估时,我们将教师温度设置为1。然而,当报告使用温度采样($\gamma = 1$)的学生性能时,如在图2和图3中,我们将学生的教师温度设置为0.1。

超参数详情。表A.1列出了XSum实验的详细超参数。

表A.1:XSum实验的超参数详情。
表A.1:XSum实验的超参数详情。

GKD消融实验。我们在图A.12和A.13中对不同学生模型尺寸的GKD的不同散度和学生数据比例进行了消融研究。在策略和混合变体始终优于监督变体。当使用温度采样进行评估时,模式寻求型散度表现更好,而使用贪婪采样时,散度的选择对性能影响不大。

图A.12:在XSum上使用温度采样(γ = 1)评估GKD的消融实验。我们从监督微调的T5-XL模型蒸馏到不同尺寸的学生T5模型。这里,我们使用温度采样进行评估,并在训练期间将教师温度设置为0.1。在上图中,我们报告了蒸馏后学生的ROUGE-2分数。使用反向KL和JSD(0.9)的在策略GKD方法表现最好,而前向KL表现较差。
图A.12:在XSum上使用温度采样(γ = 1)评估GKD的消融实验。我们从监督微调的T5-XL模型蒸馏到不同尺寸的学生T5模型。这里,我们使用温度采样进行评估,并在训练期间将教师温度设置为0.1。在上图中,我们报告了蒸馏后学生的ROUGE-2分数。使用反向KL和JSD(0.9)的在策略GKD方法表现最好,而前向KL表现较差。
图A.13:在XSum上使用贪婪采样评估GKD的消融实验。我们从监督微调的T5-XL模型蒸馏到不同尺寸的学生T5模型。这里,我们使用贪婪采样进行评估,并在训练期间将学生和教师的温度都设置为1。当使用贪婪采样评估时,教师模型的ROUGE-2分数为22,而学生T5-small、base和large模型的分数分别为13.4、17.9和19.6。在上图中,我们报告了蒸馏后学生的ROUGE-2分数。在策略GKD方法表现最好,不同散度之间的差异很小。此外,在策略和混合变体明显优于监督变体。
图A.13:在XSum上使用贪婪采样评估GKD的消融实验。我们从监督微调的T5-XL模型蒸馏到不同尺寸的学生T5模型。这里,我们使用贪婪采样进行评估,并在训练期间将学生和教师的温度都设置为1。当使用贪婪采样评估时,教师模型的ROUGE-2分数为22,而学生T5-small、base和large模型的分数分别为13.4、17.9和19.6。在上图中,我们报告了蒸馏后学生的ROUGE-2分数。在策略GKD方法表现最好,不同散度之间的差异很小。此外,在策略和混合变体明显优于监督变体。

A.4 GSM8K

训练与评估。我们使用由Magister等人(2022)【26, Teaching small language models to reason by Lucie Charlotte Magister et al., 2022, arXiv preprint】从Palm-540B生成的CoT输出来进行训练。我们在GSM8K数据集(Cobbe等人,2021【9, Training verifiers to solve math word problems by Karl Cobbe et al., 2021, arXiv preprint】)的原始测试集上报告准确率。我们使用蒸馏训练结束时的检查点报告结果,结果是3个随机种子实验的平均值。

超参数详情。表A.2列出了GSM8K实验的详细超参数。

表A.2:GSM8K实验的超参数详情。
表A.2:GSM8K实验的超参数详情。

少样本CoT提示。以下是实验中使用的4-shot CoT提示:

Q: 树林里有15棵树。园丁今天将在树林里种树。他们完成后,将会有21棵树。园丁今天种了多少棵树?
A: 原来有15棵树。后来种了一些树后有21棵树。所以一定是21 - 15 = 6。答案是6。
Q: 如果停车场有3辆车,又来了2辆车,停车场现在有多少辆车?
A: 原来有3辆车。又来了2辆车。3 + 2 = 5。答案是5。
Q: 莉亚有32块巧克力,她姐姐有42块。如果她们吃了35块,她们总共还剩多少块?
A: 原来,莉亚有32块巧克力。她姐姐有42块。所以她们总共有32 + 42 = 74块。吃了35块后,她们还剩74 - 35 = 39块。答案是39。
Q: 杰森有20个棒棒糖。他给了丹尼一些棒棒糖。现在杰森有12个棒棒糖。杰森给了丹尼多少个棒棒糖?
A: 杰森开始有20个棒棒糖。给了一些给丹尼后他有12个。所以他给了丹尼20 - 12 = 8个。答案是8。

GKD消融实验。图A.14展示了在GSM8K上使用4-shot CoT的GKD变体消融实验结果。评估时使用贪婪采样,并报告蒸馏后学生模型测试准确率的提升。结果为三个随机种子的平均值。仅使用学生生成的输出样本通常优于其他GKD变体。

图A.14:在GSM8K上对使用4-shot CoT的GKD变体进行消融实验。评估时,我们使用贪婪采样,并报告蒸馏后学生模型测试准确率的提升。结果是三次随机种子实验的平均值。仅使用学生生成的输出样本通常优于其他GKD变体。我们使用监督微调的T5-XL作为教师,其准确率为27.9。(左)我们使用T5-small作为学生,其准确率为4.585。(右)学生为T5-base,准确率为20.5。
图A.14:在GSM8K上对使用4-shot CoT的GKD变体进行消融实验。评估时,我们使用贪婪采样,并报告蒸馏后学生模型测试准确率的提升。结果是三次随机种子实验的平均值。仅使用学生生成的输出样本通常优于其他GKD变体。我们使用监督微调的T5-XL作为教师,其准确率为27.9。(左)我们使用T5-small作为学生,其准确率为4.585。(右)学生为T5-base,准确率为20.5。

A.5 WMT

评估与报告。我们使用与Raffel等人(2020)【34, Exploring the limits of transfer learning with a unified text-to-text transformer by Colin Raffel et al., 2020, The Journal of Machine Learning Research】相同的超参数进行波束搜索评估。我们报告训练后最终检查点的性能。为了减少结果的方差,我们报告的结果是3个随机种子实验的平均值。

与基线对比。图A.15比较了GKD与ImitKD和f-distill在WMT上的表现。在策略GKD相比ImitKD和f-distill在BLEU提升上分别高出53%和162%。

图A.15:在WMT上比较GKD与ImitKD和f-distill。这里,GKD对应于WMT上表现最佳的变体,即λ = 1(在策略)和JSD(0.1)。在策略GKD在small和base模型上的平均BLEU提升分别比ImitKD高53%,比f-distill高162%。
图A.15:在WMT上比较GKD与ImitKD和f-distill。这里,GKD对应于WMT上表现最佳的变体,即λ = 1(在策略)和JSD(0.1)。在策略GKD在small和base模型上的平均BLEU提升分别比ImitKD高53%,比f-distill高162%。

超参数详情。表A.3列出了WMT英德翻译实验的详细超参数。

表A.3:WMT en-de实验的超参数详情。
表A.3:WMT en-de实验的超参数详情。

A.6 指令微调

超参数详情。表A.4列出了FLAN指令微调的详细超参数。

表A.4:FLAN指令微调的超参数详情。
表A.4:FLAN指令微调的超参数详情。

A.7 模式寻求 vs. 模式覆盖 KL散度

概念解释。图A.16展示了在存在容量不匹配时,最小化前向和反向KL散度学习到的分布$Q_\theta$。图中,$P$是一个混合分布,而$Q_\theta$是一个单峰高斯分布。反向KL散度是模式寻求的,因为它强制$Q_\theta$在$P$为零的地方也为零,因此使其集中在其中一个模式上(最后一个图)。然而,前向KL散度是模式覆盖的,因为它确保在$P$有质量的任何地方,$Q_\theta$下也有一些质量。

图A.16:容量不匹配下的模式寻求与模式覆盖KL散度。我们展示了在最小化一个混合分布P和一个单峰高斯分布Qθ之间的前向和反向KL散度时,学习到的分布Qθ。反向KL是模式寻求的,因为它迫使Qθ在P为零的地方也为零,从而使其集中在其中一个模式上(最后一个图)。然而,前向KL是模式覆盖的,因为它确保在P有质量的任何地方,Qθ下也有一些质量。参考Le (2017)可以复现此图。
图A.16:容量不匹配下的模式寻求与模式覆盖KL散度。我们展示了在最小化一个混合分布P和一个单峰高斯分布Qθ之间的前向和反向KL散度时,学习到的分布Qθ。反向KL是模式寻求的,因为它迫使Qθ在P为零的地方也为零,从而使其集中在其中一个模式上(最后一个图)。然而,前向KL是模式覆盖的,因为它确保在P有质量的任何地方,Qθ下也有一些质量。参考Le (2017)可以复现此图。