文章标题:Transformer即RNN:采用线性注意力的快速自回归Transformer
作者/机构:Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, François Fleuret
核心问题:
标准的Transformer模型虽然在多种任务中表现出色,但其核心组件自注意力(self-attention)的计算和内存复杂性与输入序列长度 N 呈二次方关系,即 $O(N^2)$。这使得处理非常长的序列时,其计算成本过高、速度过慢,从而限制了模型的上下文长度,影响了时间连贯性和捕捉长期依赖的能力。尽管现有的一些高效Transformer方法(如稀疏分解、局部敏感哈希)在训练上降低了复杂性,但它们并未加速自回归推理过程。
研究目标:
本文旨在提出一种线性Transformer模型,该模型能够显著减少内存占用,并将计算复杂度降低至与上下文长度呈线性关系,即 $O(N)$。同时,该模型需要能大幅提升自回归推理的速度。
创新点:
1. 线性化自注意力机制:作者将自注意力表达为核函数特征图(kernel feature maps)的线性点积。通过利用矩阵乘法的结合律,成功将计算复杂度从 $O(N^2)$ 降低到 $O(N)$。
2. 高效的因果掩码:本文提出了一种适用于线性化注意力的因果掩码(causal masking)实现,其同样具有线性的复杂度和恒定的内存占用。
3. 揭示Transformer与RNN的关系:该线性化框架揭示了自回归Transformer与循环神经网络(RNN)之间的内在联系。基于此,作者将Transformer层重写为RNN形式,使其能够在自回归推理任务中实现数千倍的速度提升。
本节形式化地提出了线性Transformer。通过将传统的softmax注意力改为基于特征图的点积注意力,实现了更好的时间和内存复杂度,并获得了一个能像RNN一样以线性时间进行序列生成的因果模型。
Transformer层级结构:一个Transformer定义为一个函数 $T: R^{N \times F} \rightarrow R^{N \times F}$,由L个Transformer层 $T_1(\cdot), \dots, T_L(\cdot)$ 组合而成。
其中函数 $f_l(\cdot)$ 独立地转换每个特征,通常由一个小型两层前馈网络实现。$A_l(\cdot)$ 是自注意力函数,是Transformer中唯一跨序列操作的部分。
标准自注意力:自注意力函数 $A_l(\cdot)$ 为每个位置计算所有其他位置特征表示的加权平均值,权重与表示之间的相似度得分成正比。具体地,输入序列x通过三个矩阵 $W_Q \in R^{F \times D}$, $W_K \in R^{F \times D}$, $W_V \in R^{F \times M}$ 投影得到相应的表示Q, K, V。所有位置的输出 $A_l(x) = V'$ 计算如下:
注意,上式中的softmax函数是按行应用于 $QK^T$ 的。Q、K、V通常被称为“查询”、“键”和“值”。
通用注意力形式:我们可以为任何相似性函数编写一个通用的注意力方程,如下所示,其中用i索引矩阵表示取其第i行作为一个向量:
Softmax注意力的等价性:如果我们将相似性函数替换为 $sim(q, k) = exp(\frac{q^Tk}{\sqrt{d_k}})$,那么方程3就等价于方程2。
基于核的注意力:方程2中的注意力定义是通用的,可用于定义其他注意力实现,如多项式注意力或RBF核注意力【37, Tsai et al. Transformer dissection: An unified understanding for transformer’s attention via the lens of kernel. 2019. EMNLP-IJCNLP】。为了让方程3定义一个注意力函数,我们只需要对 $sim(\cdot)$ 施加非负约束,这包括了所有的核函数 $k(x, y): R^{2 \times F} \rightarrow R^+$。
核函数重写注意力:给定这样一个带有特征表示 $\phi(x)$ 的核函数,我们可以将方程2重写为:
利用矩阵乘法结合律简化:通过利用矩阵乘法的结合律,我们可以进一步将其简化为:
分子向量化形式:上述方程的分子以向量化形式书写时更易理解:
注意,特征映射 $\phi(\cdot)$ 是按行应用于矩阵Q和K的。
复杂度分析:从方程2可以看出,softmax注意力的计算成本与序列长度N的平方成正比,即 $O(N^2)$。内存需求也是如此,因为必须存储完整的注意力矩阵来计算关于查询、键和值的梯度。相比之下,我们提出的线性Transformer(方程5)的时间和内存复杂度为 $O(N)$,因为我们可以计算一次 $\sum_{j=1}^{N} \phi(K_j)V_j^T$ 和 $\sum_{j=1}^{N} \phi(K_j)$,并对每个查询重用它们。
自回归模型的因果掩码:Transformer架构可以通过掩蔽注意力计算来高效训练自回归模型,使得第i个位置只能受位置j(当且仅当 $j \le i$)的影响。形式上,这种因果掩码将方程3改变为:
线性化因果掩码:遵循3.2节的思路,我们将掩码注意力线性化如下:
引入累加状态变量:通过引入 $S_i$ 和 $Z_i$ 如下:
简化后的线性因果注意力:我们可以将方程9简化为:
注意,$S_i$ 和 $Z_i$ 可以在常数时间内从 $S_{i-1}$ 和 $Z_{i-1}$ 计算得出,从而使得带有因果掩码的线性Transformer的计算复杂度与序列长度呈线性关系。
梯度公式:给定分子 $\bar{V}_i$ 和标量损失函数L相对于分子的梯度 $\nabla_{\bar{V}} L$,我们可以推导出 $\nabla_{\phi(Q_i)}L$, $\nabla_{\phi(K_i)}L$ 和 $\nabla_{V_i}L$ 如下:
复杂度和实现:方程9和13-15中的累加和项可以在线性时间内计算,并且相对于序列长度只需要常数内存。对于给定的C维特征图,这产生了一个计算复杂度为 $O(NCM)$、内存为 $O(N\max(C, M))$ 的算法。算法1中给出了分子前向和后向传播的伪代码实现。
对比模型:
softmax: 带softmax注意力的标准Transformer【38, Vaswani et al. Attention is all you need. 2017. NIPS】。Reformer (lsh-X): SOTA的加速Transformer架构【14, Kitaev et al. Reformer: The efficient transformer. 2020. arXiv】,使用PyTorch重现版本,不使用可逆层。X表示哈希轮数。linear: 本文提出的线性Transformer。软件配置:
linear模型使用公式7的特征图。硬件配置:
实验任务与模型参数:
softmax批大小1,linear和reformer批大小4。所有模型训练7天。softmax模型相同,并且优于由哈希引入噪声的lsh模型。softmax模型的内存和时间消耗随序列长度呈二次方增长;而linear和Reformer模型则呈线性增长。在所有配置下,linear模型都比基线模型速度更快、内存占用更少。MNIST (表1, 图3):linear模型在性能(bits/dim)上与softmax模型几乎持平(0.83 vs 0.82),但在图像生成吞吐量上快了300多倍(142张/秒 vs 0.45张/秒)。这是因为其恒定的内存需求使其能在单GPU上同时生成10000张图像。从图3的样本来看,生成图像清晰,图像补全效果令人信服。
CIFAR-10 (表2, 图4):由于序列更长,linear模型的优势更加明显。在固定的7天训练时间内,linear模型完成了softmax模型3倍的训练轮数(150 vs 49),并取得了更好的困惑度(3.14 vs 3.26)。在生成吞吐量上,linear模型比softmax快4000多倍(17.85张/秒 vs 0.004张/秒)。图4展示了模型生成的图像具有空间一致性,且能令人信服地完成图像补全。
linear模型在音素错误率(PER)和训练速度上均大幅优于LSTM和Reformer基线。softmax模型虽然取得了最低的PER(7.3%),但其训练速度明显慢于其他模型,linear模型每个epoch的训练速度是其3倍以上(290秒 vs 900秒)。本文提出了线性Transformer,通过利用矩阵乘法的结合律,成功地将自注意力的计算和内存成本降低到与序列长度呈线性关系。研究表明,该模型可以与因果掩码结合使用,并保持其线性的渐进复杂度。最终,本文将Transformer模型表达为一种循环神经网络(RNN),使其能够在自回归任务上实现数千倍的推理加速。
未来工作展望:
1. 深入研究Transformer与RNN的关系:这一特性为未来研究RNN和Transformer中信息的存储与检索机制开辟了多种方向。
2. 探索新的特征图:另一个值得探索的研究方向是为线性注意力选择不同的特征图。例如,使用随机傅里叶特征来近似RBF核,可能允许我们直接使用以softmax注意力预训练的模型。
目标:本节详细推导了因果掩码线性Transformer的梯度,并证明它们可以在线性时间和常数内存中计算。具体地,是推导一个标量损失函数相对于以下方程分子的梯度:
分母和整个分数的梯度可以由自动求导系统高效处理。不失一般性,我们可以假设Q和K已经包含了由$\phi(\cdot)$映射过的向量。
分子定义:给定分子
和 $\nabla_{\bar{V}} L$,目标是计算 $\nabla_Q L, \nabla_K L$ 和 $\nabla_V L$。
梯度推导过程:首先,将上述方程以非向量形式表示单个元素:
随后,通过对任意 $Q_{lt}$ 求偏导来推导Q的梯度:
将上式写成梯度矩阵的乘积形式,即证明了主论文中的方程13:
对于K和V,由于 $K_j$ 会影响所有 $i \ge j$ 的 $\bar{V}_i$,因此需要对i进行求和。对 $K_{lt}$ 的偏导如下:
其向量化形式证明了主论文中的方程14:
对V的梯度推导过程类似,证明了方程15。值得注意的是,Q和K的梯度累加和矩阵大小相同,但一个是从1到N的正向累加,另一个是从N到1的反向累加,类似于RNN中的BPTT。
linear模型的收敛情况与softmax相当,显著优于两种reformer变体。lsh-1和linear模型完成的epoch数远多于softmax和lsh-4,并取得了更好的性能。随着序列长度进一步增加,这种差距预计会更大。softmax在收敛性上显著优于Reformer和linear。但linear每个epoch的速度快3倍,这意味着它完成的epoch数大约是softmax的4倍。尽管softmax在该任务上更优,但linear在收敛和最终性能上均显著优于Reformer。C.1. 有状态的softmax注意力:为了进行更公平的比较,本节创建了一个名为stateful-softmax的基线,它将softmax自回归transformer实现为一个循环模型,即缓存所有的键和值。这个循环模型的状态大小与序列长度成正比,这与我们提出的状态大小固定的模型有本质区别。如表4所示,stateful-softmax比普通transformer快得多,但其复杂度仍然是二次的。对于CIFAR-10,我们的linear模型仍然比它快50多倍。
C.2. 统一批次大小:本节评估了生成单个图像的延迟(latency),即将批次大小设为1。结果如表5所示,所有方法在GPU上都未充分利用资源,吞吐量远低于表4。linear transformer在所有方法中速度最快,在CIFAR-10上生成一张图片比softmax快近6.6倍。一个有趣的现象是,linear模型在CPU上的运行速度在所有情况下都比GPU快,这是因为其RNN式的注意力计算成本极低,主要的计算瓶颈变成了不可避免的序列外层循环。
Reformer模型在无条件样本中的变化明显较少。此外,所有模型在图像补全任务上的表现都比无条件生成要好得多,表明这是一个相对容易的任务。