Transformer Quality in Linear Time

A1 主要贡献

核心问题: 尽管Transformer模型【35, Vaswani, A. 等人, Attention is all you need, NIPS 2017】在深度学习领域取得了巨大成功,但其在处理长序列时面临着二次方复杂度的限制。这一限制阻碍了模型处理长期信息的能力,而这在许多应用中至关重要。

现有方法缺陷: 许多旨在通过高效注意力机制加速Transformer的现有技术,虽然在理论上可能具有线性复杂度,但在实践中存在一个或多个以下缺陷:
1. 质量不佳 (Inferior Quality): 与经过简单调整后的增强版Transformer(本文称为Transformer++)相比,许多高效注意力方法存在显著的性能下降,这种下降超过了其效率优势。
2. 实践中的开销 (Overhead in Practice): 这些方法通常会使Transformer层变得复杂,并需要大量的内存重格式化操作,导致其在GPU或TPU等加速器上的理论复杂度与实际速度之间存在巨大差距。
3. 自回归训练效率低下 (Inefficient Auto-regressive Training): 大多数注意力线性化技术虽然在推理时解码速度快,但在自回归任务(如语言建模)的训练中却非常缓慢。这主要是因为它们需要在大量步骤上进行类似RNN的顺序状态更新,无法充分利用现代加速器的并行计算能力。

图1:在Wiki40B上进行自回归语言建模时,FLASH相对于vanilla Transformer (TFM)和增强型Transformer (TFM++)的TPU-v4训练加速比。所有模型大小均在110M左右,使用2^18个token的批次大小训练125K步。
图1:在Wiki40B上进行自回归语言建模时,FLASH相对于vanilla Transformer (TFM)和增强型Transformer (TFM++)的TPU-v4训练加速比。所有模型大小均在110M左右,使用2^18个token的批次大小训练125K步。

研究目标与创新点: 本文旨在开发一个名为FLASH的新模型家族,该模型首次在质量上与完全增强的Transformer相媲美,同时在现代加速器上真正实现了对上下文大小的线性可扩展性。其开发分为两步:

  1. 提出门控注意力单元 (Gated Attention Unit, GAU):

    • 首先,本文设计了一种名为GAU的新层,如图2所示。GAU通过引入门控机制来减轻自注意力的负担。
    • 相比Transformer层,GAU层计算成本更低,且其质量对注意力机制的精确度依赖性更小。
    • 实验证明,使用一个较小的单头、无softmax的注意力机制的GAU,其性能与标准Transformer相当。
    • GAU虽然仍具有二次方复杂度,但它削弱了注意力的作用,为后续进行低质量损失的近似计算创造了条件。
  2. 提出混合分块注意力 (Mixed Chunk Attention):

    • 其次,本文提出了一种高效的方法来近似GAU中的二次方注意力,从而得到一个具有线性复杂度的层变体。
    • 核心思想是将token分组为块(chunks),然后在块内使用精确的二次方注意力,在块间使用快速的线性注意力,如图4所示。
    • 这种方法可以自然地转化为加速器友好的实现,仅需几行代码更改即可在实践中实现线性可扩展性。

图2:(a)一个增强的Transformer层,由门控线性单元(GLU)和多头自注意力(MHSA)两个块组成。(b)我们提出的门控注意力单元(GAU)。(c)门控注意力单元的伪代码。为了简洁,(a)和(b)中省略了残差分支上的跳跃连接和输入归一化。
图2:(a)一个增强的Transformer层,由门控线性单元(GLU)和多头自注意力(MHSA)两个块组成。(b)我们提出的门控注意力单元(GAU)。(c)门控注意力单元的伪代码。为了简洁,(a)和(b)中省略了残差分支上的跳跃连接和输入归一化。

主要成果: 实验表明,在512到8K的上下文长度范围内,FLASH在质量上与完全增强的Transformer (Transformer++)相当,同时在Wiki-40B上的自回归语言建模任务中实现了高达4.9倍的训练加速,在PG-19上实现了高达12.1倍的加速,在C4上的掩码语言建模任务中实现了4.8倍的加速。

A2 方法细节

2. 门控注意力单元 (Gated Attention Unit)

本节介绍了门控注意力单元(GAU),一个比Transformer层更简单但性能更优的层。尽管GAU对上下文长度仍具有二次方复杂度,但它为第3节将要介绍的近似方法提供了更好的基础。首先介绍相关的层:


其中$\otimes$代表逐元素相乘。在GLU中,每个表示$u_i$都由与同一token关联的另一个表示$v_i$进行门控。

图3:GAU与Transformer在短上下文长度(512)下自回归和掩码语言建模的对比
图3:GAU与Transformer在短上下文长度(512)下自回归和掩码语言建模的对比

表1:各种修改对GAU的影响
表1:各种修改对GAU的影响

表2:各种修改对MHSA的影响
表2:各种修改对MHSA的影响

3. 使用GAU实现快速线性注意力

从第2节中得到的两个观察结果,启发我们将GAU扩展到长序列建模:
- 首先,GAU中的门控机制允许使用更弱的(单头、无softmax)注意力而不损失质量。如果将这一直觉进一步应用于使用注意力建模长序列,GAU也可能提升近似(弱)注意力机制(如局部、稀疏和线性化注意力)的有效性。
- 此外,使用GAU时,注意力模块的数量自然增加了一倍——根据第2节,MLP+MHSA ≈ 2×GAU的成本。由于近似注意力通常需要更多层来捕捉完整的依赖关系【6, Dai, Z. 等人, Transformer-xl: Attentive language models beyond a fixed-length context, arXiv 2019】、【4, Child, R. 等人, Generating long sequences with sparse transformers, arXiv 2019】,这一特性也使得GAU在处理长序列时更具吸引力。

3.1. 现有的线性复杂度变体

3.2. 我们的方法:混合分块注意力 (Mixed Chunk Attention)

基于现有方法的优缺点,我们提出了混合分块注意力,它融合了部分注意力和线性注意力的优点。
- 准备工作: 输入序列首先被分块为$G$个大小为$C$的非重叠块。然后,根据GAU公式为每个块$g$生成$U_g \in R^{C \times e}$、$V_g \in R^{C \times e}$和$Z_g \in R^{C \times s}$。接下来,从$Z_g$生成四种注意力头:$Q_{quad_g}$、$K_{quad_g}$、$Q_{lin_g}$、$K_{lin_g}$。

代码1:混合分块注意力的伪代码。
代码1:混合分块注意力的伪代码。

3.2.1. 讨论

A4 实验环境

A4 实验结果

4.1. 双向语言建模 (MLM on C4)

图5:C4数据集上掩码语言建模的验证集结果 — 所有模型大小约110M,使用2^18个token的批次大小训练125K步。质量以负对数困惑度衡量。
图5:C4数据集上掩码语言建模的验证集结果 — 所有模型大小约110M,使用2^18个token的批次大小训练125K步。质量以负对数困惑度衡量。

4.2. 自回归语言建模 (LM on Wiki-40B & PG-19)

4.3. 微调 (Fine-tuning on TriviaQA)

4.4. 消融研究

A5 结论

本文提出了FLASH,一个旨在解决现有高效Transformer变体在质量和实际速度方面问题的实用解决方案。这一目标通过设计一个高性能层(门控注意力单元GAU)并将其与一个对加速器友好的近似策略(混合分块注意力)相结合来实现。在双向和自回归语言建模任务上的实验表明,FLASH在质量(困惑度)上与完全增强的Transformer相当,而训练速度则比现有最先进的模型快得多。未来的工作包括研究这个新模型家族的缩放定律以及其在更多下游任务上的表现。

A6 附录

A. 与Combiner的联系

B. 实验设置

B.1. 超参数

B.2. 模型规格

实验中使用的所有模型的详细规格总结在表8、9和10中。FLASH-Quad和FLASH使用SiLU/Swish作为非线性激活函数,因为它在我们的模型中略优于GELU。在一些掩码语言模型中使用了ScaleNorm,因为它在TPU-v4上比LayerNorm运行稍快且不影响模型质量。
- C4上的MLM模型配置 (表8)

表8:第4节中C4数据集上MLM实验的模型配置。
表8:第4节中C4数据集上MLM实验的模型配置。

- Wiki-40B上的LM模型配置 (表9)
表9:第4节中Wiki-40B数据集上LM实验的模型配置。
表9:第4节中Wiki-40B数据集上LM实验的模型配置。

- PG-19上的LM模型配置 (表10)
表10:第4节中PG-19数据集上LM实验的模型配置。
表10:第4节中PG-19数据集上LM实验的模型配置。

C. 额外的实验结果

C.1. GPU上的自回归训练

自回归训练的低效率问题不仅限于TPU。如表11所示,在单个Nvidia V100 GPU上,Performer由于需要对所有token进行顺序的累积求和,其延迟最高。相比之下,当上下文长度超过1024时,本文提出的FLASH延迟最低,证明了混合分块注意力机制的有效性。

表11:在单个Nvidia Tesla V100 GPU上,Wiki-40B自回归语言建模每个训练步骤的延迟比较 — 延迟以毫秒报告。OOM代表CUDA内存不足错误。Performer-Matmul使用矩阵乘法实现累积和(cumsum)。
表11:在单个Nvidia Tesla V100 GPU上,Wiki-40B自回归语言建模每个训练步骤的延迟比较 — 延迟以毫秒报告。OOM代表CUDA内存不足错误。Performer-Matmul使用矩阵乘法实现累积和(cumsum)。

C.2. MLM和LM的表格化结果

C.3. 块大小的消融研究

块大小的选择会影响FLASH的质量和训练成本。在极端情况下,当块大小等于上下文长度时,FLASH退化为FLASH-Quad;当块大小为1时,它变成一个低效的线性注意力模型。图8展示了在1K到8K的上下文长度下,四种不同块大小在质量和训练成本之间的权衡。

图8:FLASH在1K到8K上下文长度下块大小(C)的消融研究。(a) 上下文长度 = 1024 (b) 上下文长度 = 2048 (c) 上下文长度 = 4096 (d) 上下文长度 = 8192
图8:FLASH在1K到8K上下文长度下块大小(C)的消融研究。(a) 上下文长度 = 1024 (b) 上下文长度 = 2048 (c) 上下文长度 = 4096 (d) 上下文长度 = 8192

D. FLASH-Quad和FLASH的伪代码

下面展示了FLASH-Quad和FLASH的详细实现代码。
- 相对位置偏置伪代码 (Code 4)

def rel_pos_bias(n):
    """Relative position bias."""
    if n < 512:
        # Construct Toeplitz matrix directly when the sequence length is less than 512.
        w = tf.get_variable(
            &#39;weight&#39;,
            shape=[2 * n - 1],
            dtype=tf.float32,
            initializer=WEIGHT_INITIALIZER)
        t = tf.pad(w, [[0, n]])
        t = tf.tile(t, [n])
        t = t[..., :-n]
        t = tf.reshape(t, [n, 3 * n - 2])
        r = (2 * n - 1) // 2
        t = t[..., r:-r]
    else:
        # Construct Toeplitz matrix using RoPE when the sequence length is over 512.
        a = tf.get_variable(
            &#39;a&#39;,
            shape=[128],
            dtype=dtype,
            initializer=WEIGHT_INITIALIZER)
        b = tf.get_variable(
            &#39;b&#39;,
            shape=[128],
            dtype=dtype,
            initializer=WEIGHT_INITIALIZER)
        a = rope(tf.tile(a[None, :], [n, 1]), axis=0)
        b = rope(tf.tile(b[None, :], [n, 1]), axis=0)
        t = tf.einsum(&#39;mk,nk->mn&#39;, a, b)

    return t
def _get_scaledsin(embeddings):
    &quot;&quot;&quot;Create sinusoidal position embedding with a scaling factor.&quot;&quot;&quot;
    hidden_size = int(embeddings.shape[-1])
    pos = tf.range(tf.shape(embeddings)[1])
    pos = tf.cast(pos, tf.float32)
    half_d = hidden_size // 2
    freq_seq = tf.cast(tf.range(half_d), tf.float32) / float(half_d)
    inv_freq = 10000 ** -freq_seq
    sinusoid = tf.einsum(&#39;s,d->sd&#39;, pos, inv_freq)
    scaledsin = tf.concat([tf.sin(sinusoid), tf.cos(sinusoid)], axis=-1)
    scalar = tf.get_variable(
        &#39;scaledsin_scalar&#39;,
        shape=(),
        initializer=tf.constant_initializer(1 / hidden_size ** 0.5))
    scaledsin *= scalar
    return scaledsin