Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

文章标题:Transformer即RNN:采用线性注意力的快速自回归Transformer
作者/机构:Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, François Fleuret


A1 主要贡献

核心问题
标准的Transformer模型虽然在多种任务中表现出色,但其核心组件自注意力(self-attention)的计算和内存复杂性与输入序列长度 N 呈二次方关系,即 $O(N^2)$。这使得处理非常长的序列时,其计算成本过高、速度过慢,从而限制了模型的上下文长度,影响了时间连贯性和捕捉长期依赖的能力。尽管现有的一些高效Transformer方法(如稀疏分解、局部敏感哈希)在训练上降低了复杂性,但它们并未加速自回归推理过程。

研究目标
本文旨在提出一种线性Transformer模型,该模型能够显著减少内存占用,并将计算复杂度降低至与上下文长度呈线性关系,即 $O(N)$。同时,该模型需要能大幅提升自回归推理的速度。

创新点
1. 线性化自注意力机制:作者将自注意力表达为核函数特征图(kernel feature maps)的线性点积。通过利用矩阵乘法的结合律,成功将计算复杂度从 $O(N^2)$ 降低到 $O(N)$。
2. 高效的因果掩码:本文提出了一种适用于线性化注意力的因果掩码(causal masking)实现,其同样具有线性的复杂度和恒定的内存占用。
3. 揭示Transformer与RNN的关系:该线性化框架揭示了自回归Transformer与循环神经网络(RNN)之间的内在联系。基于此,作者将Transformer层重写为RNN形式,使其能够在自回归推理任务中实现数千倍的速度提升。


A3 背景知识/关键Observation/设计原则

2.1. 高效的Transformer

2.2. 理解自注意力

2.3. 线性化的Softmax


A2 方法细节

本节形式化地提出了线性Transformer。通过将传统的softmax注意力改为基于特征图的点积注意力,实现了更好的时间和内存复杂度,并获得了一个能像RNN一样以线性时间进行序列生成的因果模型。

3.1. Transformer

3.2. 线性化注意力

3.2.1. 特征图与计算成本

3.3. 因果掩码

3.3.1. 梯度计算
3.3.2. 训练与推理

3.4. Transformer即RNN


A4 实验环境


A4 实验结果

4.1. 合成任务

4.2. 图像生成

4.3. 自动语音识别


A5 结论

本文提出了线性Transformer,通过利用矩阵乘法的结合律,成功地将自注意力的计算和内存成本降低到与序列长度呈线性关系。研究表明,该模型可以与因果掩码结合使用,并保持其线性的渐进复杂度。最终,本文将Transformer模型表达为一种循环神经网络(RNN),使其能够在自回归任务上实现数千倍的推理加速。

未来工作展望
1. 深入研究Transformer与RNN的关系:这一特性为未来研究RNN和Transformer中信息的存储与检索机制开辟了多种方向。
2. 探索新的特征图:另一个值得探索的研究方向是为线性注意力选择不同的特征图。例如,使用随机傅里叶特征来近似RBF核,可能允许我们直接使用以softmax注意力预训练的模型。


A6 附录

A. 梯度推导

B. 训练过程

C. 图像生成吞吐量讨论

D. 图像生成的定性结果