核心问题: 尽管Transformer模型【35, Vaswani, A. 等人, Attention is all you need, NIPS 2017】在深度学习领域取得了巨大成功,但其在处理长序列时面临着二次方复杂度的限制。这一限制阻碍了模型处理长期信息的能力,而这在许多应用中至关重要。
现有方法缺陷: 许多旨在通过高效注意力机制加速Transformer的现有技术,虽然在理论上可能具有线性复杂度,但在实践中存在一个或多个以下缺陷:
1. 质量不佳 (Inferior Quality): 与经过简单调整后的增强版Transformer(本文称为Transformer++)相比,许多高效注意力方法存在显著的性能下降,这种下降超过了其效率优势。
2. 实践中的开销 (Overhead in Practice): 这些方法通常会使Transformer层变得复杂,并需要大量的内存重格式化操作,导致其在GPU或TPU等加速器上的理论复杂度与实际速度之间存在巨大差距。
3. 自回归训练效率低下 (Inefficient Auto-regressive Training): 大多数注意力线性化技术虽然在推理时解码速度快,但在自回归任务(如语言建模)的训练中却非常缓慢。这主要是因为它们需要在大量步骤上进行类似RNN的顺序状态更新,无法充分利用现代加速器的并行计算能力。
研究目标与创新点: 本文旨在开发一个名为FLASH的新模型家族,该模型首次在质量上与完全增强的Transformer相媲美,同时在现代加速器上真正实现了对上下文大小的线性可扩展性。其开发分为两步:
提出门控注意力单元 (Gated Attention Unit, GAU):
提出混合分块注意力 (Mixed Chunk Attention):
主要成果: 实验表明,在512到8K的上下文长度范围内,FLASH在质量上与完全增强的Transformer (Transformer++)相当,同时在Wiki-40B上的自回归语言建模任务中实现了高达4.9倍的训练加速,在PG-19上实现了高达12.1倍的加速,在C4上的掩码语言建模任务中实现了4.8倍的加速。
本节介绍了门控注意力单元(GAU),一个比Transformer层更简单但性能更优的层。尽管GAU对上下文长度仍具有二次方复杂度,但它为第3节将要介绍的近似方法提供了更好的基础。首先介绍相关的层:
Vanilla MLP: 对于输入表示$X \in R^{T \times d}$,Transformer的MLP输出可表示为 $O = \phi(XW_u)W_o$,其中$W_u \in R^{d \times e}$,$W_o \in R^{e \times d}$,$d$是模型大小,$e$是扩展的中间层大小,$\phi$是逐元素的激活函数。
门控线性单元 (Gated Linear Unit, GLU): 这是对MLP的改进,增加了门控机制【7, Dauphin, Y. N. 等人, Language modeling with gated convolutional networks, ICML 2017】。GLU在许多场景下被证明是有效的【31, Shazeer, N., GLU variants improve transformer, CoRR 2020】、【23, Narang, S. 等人, Do transformer modifications transfer across implementations and applications?, CoRR 2021】,并被用于最先进的Transformer语言模型中【10, Du, N. 等人, Glam: Efficient scaling of language models with mixtureof-experts, CoRR 2021】、【34, Thoppilan, R. 等人, Lamda: Language models for dialog applications, CoRR 2022】。
其中$\otimes$代表逐元素相乘。在GLU中,每个表示$u_i$都由与同一token关联的另一个表示$v_i$进行门控。
门控注意力单元 (Gated Attention Unit, GAU): 核心思想是将注意力和GLU统一为一个层,并尽可能共享它们的计算。这不仅提高了参数/计算效率,还自然地实现了一种强大的注意力门控机制。具体来说,GAU将GLU中的公式(2)泛化为:
其中$A \in R^{T \times T}$包含token间的注意力权重。与GLU总是使用$v_i$来门控$u_i$(两者都与同一token关联)不同,GAU用一个可能更相关的表示$\hat{v}_i = \sum_j a_{ij}v_j$来替代$v_i$,这个$\hat{v}_i$是通过注意力从所有可用的token中“检索”得到的。当A是单位矩阵时,该公式退化为GLU。
GAU中的简化注意力机制: 与Liu等人【22, Liu, H. 等人, Pay attention to mlps, NeurIPS 2021】的发现一致,门控机制的存在使得我们可以使用比MHSA简单/弱得多的注意力机制而不会损失质量。GAU的注意力计算如下:
其中$Z$是共享表示(维度$s \ll d$),$Q$和$K$是两个廉价的变换,它们对$Z$应用逐维度的缩放和偏移(类似于LayerNorm中的可学习变量),$b$是相对位置偏置。我们还发现,在GAU中,MHSA中的softmax可以简化为一个常规的激活函数。GAU层及其伪代码如图2所示。
GAU的参数效率: 与Transformer的MHSA带有$4d^2$个参数不同,GAU的注意力机制在GLU的基础上只增加了一个小的密集矩阵$W_z$($ds$个参数),以及$Q$和$K$中的少量缩放和偏移参数。通过为GAU设置$e = 2d$,这种紧凑的设计允许我们用两个GAU层替换一个Transformer块(MLP/GLU + MHSA),同时保持相似的模型大小和训练速度。
GAU与Transformer的性能对比: 图3显示,在不同的模型大小下,GAU在TPU上与Transformer(MHSA + MLP/GLU)具有竞争力。这些实验是在相对较短的上下文长度(512)下进行的。在第4节中可以看到,当上下文长度更长时,由于其注意力容量的减少,GAU的性能甚至更好。
层消融研究: 表1和表2的消融研究表明,GAU和Transformer各自在其设计空间内都是局部最优的。这意味着将GAU的组件(如单头注意力)直接用于Transformer,或者将Transformer的组件(如多头注意力)用于GAU,都会导致性能下降。
从第2节中得到的两个观察结果,启发我们将GAU扩展到长序列建模:
- 首先,GAU中的门控机制允许使用更弱的(单头、无softmax)注意力而不损失质量。如果将这一直觉进一步应用于使用注意力建模长序列,GAU也可能提升近似(弱)注意力机制(如局部、稀疏和线性化注意力)的有效性。
- 此外,使用GAU时,注意力模块的数量自然增加了一倍——根据第2节,MLP+MHSA ≈ 2×GAU的成本。由于近似注意力通常需要更多层来捕捉完整的依赖关系【6, Dai, Z. 等人, Transformer-xl: Attentive language models beyond a fixed-length context, arXiv 2019】、【4, Child, R. 等人, Generating long sequences with sparse transformers, arXiv 2019】,这一特性也使得GAU在处理长序列时更具吸引力。
部分注意力 (Partial Attention): 这类方法尝试用不同的部分/稀疏模式来近似全注意力矩阵。包括局部窗口【6, Dai, Z. 等人, Transformer-xl, arXiv 2019】、【26, Rae, J. W. 等人, Compressive transformers for long-range sequence modelling, arXiv 2019】,局部+稀疏【4, Child, R. 等人, Generating long sequences with sparse transformers, arXiv 2019】、【21, Li, S. 等人, Enhancing the locality and breaking the memory bottleneck of transformer on time series forecasting, NIPS 2019】、【2, Beltagy, I. 等人, Longformer: The long-document transformer, arXiv 2020】、【37, Zaheer, M. 等人, Big bird: Transformers for longer sequences, NeurIPS 2020】,轴向【15, Ho, J. 等人, Axial attention in multidimensional transformers, arXiv 2019】、【16, Huang, Z. 等人, Ccnet: Criss-cross attention for semantic segmentation, ICCV 2019】,以及通过哈希【20, Kitaev, N. 等人, Reformer: The efficient transformer, arXiv 2020】或聚类【30, Roy, A. 等人, Efficient content-based sparse attention with routing transformers, TACL 2021】学习的模式。这些方法的关键问题是它们涉及大量不规则或规则的内存重格式化操作(如gather、scatter、slice和concatenation),这对现代大规模并行加速器(特别是像TPU这样的专用ASIC)不友好。因此,它们的实际收益(速度和RAM效率)在很大程度上取决于加速器的选择,并且通常落后于理论分析。
线性注意力 (Linear Attention): 另一类研究通过分解注意力矩阵然后重排矩阵乘法顺序来线性化注意力计算【5, Choromanski, K. 等人, Rethinking attention with performers, arXiv 2020】、【36, Wang, S. 等人, Linformer: Self-attention with linear complexity, arXiv 2020】、【19, Katharopoulos, A. 等人, Transformers are rnns: Fast autoregressive transformers with linear attention, ICML 2020】、【25, Peng, H. 等人, Random feature attention, ICLR 2021】。其计算可表示为:
其中$Q, K, V \in R^{T \times d}$。重排计算将关于$T$的复杂度从二次降为线性。
基于现有方法的优缺点,我们提出了混合分块注意力,它融合了部分注意力和线性注意力的优点。
- 准备工作: 输入序列首先被分块为$G$个大小为$C$的非重叠块。然后,根据GAU公式为每个块$g$生成$U_g \in R^{C \times e}$、$V_g \in R^{C \times e}$和$Z_g \in R^{C \times s}$。接下来,从$Z_g$生成四种注意力头:$Q_{quad_g}$、$K_{quad_g}$、$Q_{lin_g}$、$K_{lin_g}$。
块内局部注意力: 首先,对每个长度为$C$的块独立应用局部二次方注意力,生成部分预门控状态:
这部分的复杂度为$O(G \times C^2 \times s) = O(TCs)$,在$C$保持不变的情况下,对$T$是线性的。
块间全局注意力: 此外,采用全局线性注意力机制来捕捉块间的长程交互:
注意公式(7)和(8)中的求和是在块级别上执行的。对于因果(自回归)情况,这将token级线性注意力中cumsum的元素数量减少了$C$倍(实验中$C$通常为256),从而显著提高了训练速度。
最终组合: 最后,将$\hat{V}_{quad_g}$和$\hat{V}_{lin_g}$相加,然后进行门控和后注意力投影,类似于公式(3):
混合分块注意力实现简单,其伪代码在Code 1中给出。
快速自回归训练: 由于分块,自回归情况下的顺序依赖性从标准线性注意力中的$T$步减少到分块版本中的$G = T/C$步。因此,当块大小为{128, 256, 512}时,自回归训练变得非常快。同时,该模型仍然享有每步解码$O(Cd^2)$的恒定内存和计算开销。
关于非重叠局部注意力: 我们的方法中使用非重叠的块。理论上,任何部分注意力变体都可以用来替代非重叠的局部注意力。例如,我们探索了让每个块额外关注其邻近块,类似于Longformer【2, Beltagy, I. 等人, Longformer, arXiv 2020】和BigBird【37, Zaheer, M. 等人, Big bird, NeurIPS 2020】。虽然重叠的局部注意力可以持续提高质量,但它也引入了许多内存重格式化操作,明显损害了实际运行速度。在TPU上的初步实验中,我们发现使用重叠局部注意力的成本效益权衡可能不如增加更多层。
与Combiner的联系: 与我们的方法类似,Combiner【29, Ren, H. 等人, Combiner: Full attention transformer with sparse computation cost, NeurIPS 2021】也将序列分割成非重叠的块,并在每个块内使用二次方局部注意力。关键区别在于如何总结和组合远程信息。我们的混合分块注意力允许每个块有更大的有效内存,从而可能带来更好的质量。更详细的讨论见附录A。
本文提出了FLASH,一个旨在解决现有高效Transformer变体在质量和实际速度方面问题的实用解决方案。这一目标通过设计一个高性能层(门控注意力单元GAU)并将其与一个对加速器友好的近似策略(混合分块注意力)相结合来实现。在双向和自回归语言建模任务上的实验表明,FLASH在质量(困惑度)上与完全增强的Transformer相当,而训练速度则比现有最先进的模型快得多。未来的工作包括研究这个新模型家族的缩放定律以及其在更多下游任务上的表现。
信息压缩方式: Combiner【29, Ren, H. 等人, Combiner, NeurIPS 2021】将每个块总结为摘要键和值向量$K_{sum}, V_{sum} \in R^{T/C \times d}$,并将它们拼接到局部二次方注意力中。这相当于将一个大小为$C$的块压缩成一个$O(d)$的向量。而我们的分块线性注意力部分将每个块压缩成一个大小为$O(sd)$的矩阵$K_{lin_h}^T V_h$,其大小是Combiner的$s$倍。这意味着我们的方法压缩程度更低,保留了更多信息,因此可能具有质量优势。
信息组合方式: Combiner复用二次方注意力来组合不同块的压缩信息,而我们的分块线性注意力仅执行(累积)求和。尽管可以模仿Combiner的方式构建一个额外的$[T/C \times T/C]$注意力矩阵来组合块摘要,但这会使模型设计复杂化,并且要求模型存储和关注所有块摘要,导致自回归解码的复杂度增加到$O((C+T/C)d^2)$,不再是常数。因此,默认配置中未包含此特性。
实验中使用的所有模型的详细规格总结在表8、9和10中。FLASH-Quad和FLASH使用SiLU/Swish作为非线性激活函数,因为它在我们的模型中略优于GELU。在一些掩码语言模型中使用了ScaleNorm,因为它在TPU-v4上比LayerNorm运行稍快且不影响模型质量。
- C4上的MLM模型配置 (表8)
- Wiki-40B上的LM模型配置 (表9)
- PG-19上的LM模型配置 (表10)
自回归训练的低效率问题不仅限于TPU。如表11所示,在单个Nvidia V100 GPU上,Performer由于需要对所有token进行顺序的累积求和,其延迟最高。相比之下,当上下文长度超过1024时,本文提出的FLASH延迟最低,证明了混合分块注意力机制的有效性。
块大小的选择会影响FLASH的质量和训练成本。在极端情况下,当块大小等于上下文长度时,FLASH退化为FLASH-Quad;当块大小为1时,它变成一个低效的线性注意力模型。图8展示了在1K到8K的上下文长度下,四种不同块大小在质量和训练成本之间的权衡。
下面展示了FLASH-Quad和FLASH的详细实现代码。
- 相对位置偏置伪代码 (Code 4)
def rel_pos_bias(n):
"""Relative position bias."""
if n < 512:
# Construct Toeplitz matrix directly when the sequence length is less than 512.
w = tf.get_variable(
'weight',
shape=[2 * n - 1],
dtype=tf.float32,
initializer=WEIGHT_INITIALIZER)
t = tf.pad(w, [[0, n]])
t = tf.tile(t, [n])
t = t[..., :-n]
t = tf.reshape(t, [n, 3 * n - 2])
r = (2 * n - 1) // 2
t = t[..., r:-r]
else:
# Construct Toeplitz matrix using RoPE when the sequence length is over 512.
a = tf.get_variable(
'a',
shape=[128],
dtype=dtype,
initializer=WEIGHT_INITIALIZER)
b = tf.get_variable(
'b',
shape=[128],
dtype=dtype,
initializer=WEIGHT_INITIALIZER)
a = rope(tf.tile(a[None, :], [n, 1]), axis=0)
b = rope(tf.tile(b[None, :], [n, 1]), axis=0)
t = tf.einsum('mk,nk->mn', a, b)
return t
def _get_scaledsin(embeddings):
"""Create sinusoidal position embedding with a scaling factor."""
hidden_size = int(embeddings.shape[-1])
pos = tf.range(tf.shape(embeddings)[1])
pos = tf.cast(pos, tf.float32)
half_d = hidden_size // 2
freq_seq = tf.cast(tf.range(half_d), tf.float32) / float(half_d)
inv_freq = 10000 ** -freq_seq
sinusoid = tf.einsum('s,d->sd', pos, inv_freq)
scaledsin = tf.concat([tf.sin(sinusoid), tf.cos(sinusoid)], axis=-1)
scalar = tf.get_variable(
'scaledsin_scalar',
shape=(),
initializer=tf.constant_initializer(1 / hidden_size ** 0.5))
scaledsin *= scalar
return scaledsin