作者/机构: Joshua Ainslie∗, James Lee-Thorp∗, Michiel de Jong∗ Yury Zemlyanskiy, Federico Lebrón, Sumit Sanghai (Google Research)
自回归解码器的推理过程是Transformer模型的一个严重瓶颈,主要因为在每个解码步骤都需要加载解码器权重以及所有的注意力键(keys)和值(values),这带来了巨大的内存带宽开销【Shazeer, 2019, Fast transformer decoding: One write-head is all you need, arXiv preprint】。多查询注意力(Multi-query attention, MQA)通过使用多个查询头(query heads)但共享单一的键和值头,可以显著减少加载键和值所带来的内存带宽。然而,MQA可能会导致模型质量下降和训练不稳定,并且为追求更快的推理速度而专门训练一个独立的模型在实践中可能并不可行。此外,尽管像PaLM【Chowdhery et al., 2022, Palm: Scaling language modeling with pathways】这样的一些语言模型已经采用了MQA,但许多公开可用的模型,如T5【Raffel et al., 2020, Exploring the limits of transfer learning with a unified text-to-text transformer, JMLR】和LLaMA【Touvron et al., 2023, Llama: Open and efficient foundation language models】,仍然使用多头注意力(Multi-head attention, MHA)。
针对以上问题,本文提出了两大贡献以加速大型语言模型的推理速度:
1. 提出了一种增量训练(uptraining)方法:该方法能够将已有的多头注意力(MHA)语言模型检查点,以原始预训练计算成本的一小部分(例如5%),转换为使用多查询注意力(MQA)的模型。这为同时获得高质量的MHA检查点和快速推理的MQA模型提供了一种经济高效的途径。
2. 提出了分组查询注意力(Grouped-query attention, GQA):GQA是MHA和MQA之间的一种插值方案。它将查询头分组,每组共享一个键和值头,从而使用的键值头数量介于MQA(1个)和MHA(与查询头数量相等)之间。实验证明,经过增量训练的GQA模型在质量上接近MHA,同时在推理速度上与MQA相当。
从多头模型生成多查询模型的两步法。这个过程分为两个步骤:首先,转换检查点;其次,进行额外的预训练,让模型适应其新结构。图1展示了将多头检查点转换为多查询检查点的过程。具体操作是将所有头的键(key)和值(value)的投影矩阵进行均值池化(mean pooled),融合成单一的投影矩阵。我们发现,这种方法比从多个头中选择单一的键和值头,或从头开始随机初始化新的键和值头效果更好。
增量预训练。转换后的检查点会使用与原始训练相同的预训练方案,在其原始训练步数的α比例上进行进一步的预训练。
GQA的定义与转换。分组查询注意力将查询头(query heads)分为G个组,每个组共享一个单一的键头(key head)和值头(value head)。GQA-G指的是有G个分组的分组查询。其中,GQA-1(只有一个组,因此只有一个键和值头)等同于MQA;而GQA-H(分组数等于头数)等同于MHA。图2展示了分组查询注意力与多头/多查询注意力的比较。当将一个多头检查点转换为GQA检查点时,我们通过对该组内所有原始头进行均值池化来构建每个组的键和值头。
GQA作为MHA和MQA之间的权衡。一个中间的分组数量可以得到一个插值模型,其质量高于MQA但速度快于MHA,我们后续将证明这代表了一个有利的权衡。从MHA到MQA将H个键和值头减少到单个键和值头,从而将键值缓存(key-value cache)的大小以及需要加载的数据量减少了H倍。然而,更大的模型通常会扩展头的数量,因此MQA在内存带宽和模型容量上都代表了更激进的削减。GQA则允许我们随着模型尺寸的增加,保持带宽和容量的同比例减少。
GQA对大模型的额外优势。此外,对于更大的模型,来自注意力的内存带宽开销相对较小,因为KV缓存随模型维度线性扩展,而模型的FLOPs和参数则随模型维度的平方扩展。最后,针对大型模型的标准分片技术(standard sharding)会根据模型分区的数量复制单个键和值头【Pope et al., 2022, Efficiently scaling transformer inference, arXiv preprint】;GQA通过分组避免了这种分区带来的浪费。因此,我们预期GQA对于更大的模型会是一个特别好的权衡方案。
GQA不适用于编码器。我们注意到,GQA并未应用于编码器的自注意力层;因为编码器的表示是并行计算的,所以内存带宽通常不是其主要瓶颈。
图3展示了MHA T5-Large、MHA T5-XXL以及增量训练比例α=0.05的MQA和GQA-8 XXL模型在所有数据集上的平均性能与平均推理时间的关系。实验结果表明,一个更大的、经过增量训练的MQA模型相比于MHA模型提供了一个更有利的权衡,其质量和推理速度均优于MHA-Large模型。更重要的是,GQA模型在此基础上实现了显著的质量提升,其性能接近MHA-XXL,而速度则接近MQA。表1中包含了所有数据集的完整结果。
本节在三个有代表性的任务子集上进行实验:CNN/Daily Mail(短篇摘要)、MultiNews(长篇摘要)和TriviaQA(问答),以研究不同建模选择的影响。
检查点转换方法。图4比较了不同检查点转换方法的性能。结果显示,均值池化(Mean)的效果最好,其次是选择第一个头(First),最差的是随机初始化(Random)。从直观上看,结果的排序与从预训练模型中保留信息的程度成正比。
增量训练步数。图5展示了T5 XXL模型在使用MQA和GQA时,性能如何随增量训练比例的变化而变化。首先,GQA在转换后(比例为0)就已经达到了合理的性能,而MQA需要经过增量训练才能变得有效。MQA和GQA都从5%的增量训练中获益,而增加到10%时收益递减。
分组数量。图6展示了GQA分组数量对推理速度的影响。对于更大的模型,KV缓存带来的内存带宽开销约束较小,同时由于头数增加,键值大小的缩减更为明显。因此,将分组数从1(MQA)增加时,最初只会导致适度的速度下降,但随着分组数接近MHA,成本会越来越高。我们选择8个分组作为有利的折中点。
大型语言模型在推理时成本高昂,主要原因是加载键和值(keys and values)时产生的内存带宽开销。多查询注意力(MQA)通过减少这种开销来降低成本,但代价是模型容量和质量的下降。本文提出了一种方法,可以用原始预训练计算成本的一小部分,将多头注意力(MHA)模型转换为多查询模型。此外,我们引入了分组查询注意力(GQA),它是MQA和MHA的一种插值方法,能够在保持与MQA相当的推理速度的同时,实现接近MHA的模型质量。
MQA在微调中的不稳定性。我们发现多查询注意力(MQA)在微调期间可能导致训练不稳定,尤其是在与长输入任务结合时。我们从头开始训练了多个使用MQA的T5-Large模型。在每种情况下,预训练过程都遭受了频繁的损失尖峰(loss spikes),并且最终模型在对长输入任务进行微调时立即发散(diverged)。
增量训练的改善及GQA的稳定性。经过增量训练的MQA模型更为稳定,但仍然表现出高方差。因此,对于在不稳定任务上的MQA模型,我们报告了三次微调运行的平均性能。然而,经过增量训练的分组查询注意力(GQA)模型似乎是稳定的,所以我们没有进一步探究MQA不稳定的根本原因。