文章标题:快速 Transformer 解码:一个写入头足矣
作者/机构:Noam Shazeer, Google
核心问题与研究目标:Transformer 神经网络序列模型在增量推理(incremental inference)时的速度是一个主要挑战。在现代计算硬件上,其速度受限于为重载注意力层状态所需的大型“键”(keys)和“值”(values)张量而产生的内存带宽。本文旨在提出一种架构变体,以在仅有轻微质量下降的情况下,大幅提升推理速度。
创新点:本文提出了一种名为多查询注意力(multi-query attention)的变体。在这种结构中,键(keys)和值(values)在所有不同的注意力“头”(heads)之间共享。这种设计极大地减小了这些张量的尺寸,从而降低了增量解码过程中的内存带宽需求。实验证明,采用该方法的模型解码速度确实快得多,并且与基线模型相比,模型质量仅有轻微的下降。
神经注意力的基本功能:由 【Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate, 2014】引入的神经注意力是处理可变长度表示的强大工具。一个神经注意力函数接收一个查询向量 q 和一组 m 个不同的(键向量,值向量)对(由矩阵 K 和 V 表示),并生成一个输出向量 y。输出 y 是不同值向量的加权和,其权重通过比较查询向量与各个键向量得出。
点积注意力的计算方式:一种常见的实现方式是,权重由查询向量与不同键向量的点积经过 softmax 函数计算得出。以下代码描述了此过程。
def DotProductAttention(q, K, V):
""" Dot-Product Attention on one query.
Args:
q: a vector with shape [k]
K: a matrix with shape [m, k]
V: a matrix with shape [m, v]
Returns:
y: a vector with shape [v]
"""
logits = tf.einsum("k,mk->m", q, K)
weights = tf.softmax(logits)
return tf.einsum("m,mv->v", weights, V)
Einsum 标记法说明:文中的代码示例使用了在 TensorFlow 和 numpy 中定义的 einsum 标记法,用于任意维度张量间的广义收缩运算。在这种表示法中,一个等式指定了输入和输出张量的维度名称。该计算在数值上等同于将每个输入广播到拥有所有维度的并集,然后按元素相乘,并对所有不在期望输出形状中的维度进行求和。
多头注意力的并行结构:Transformer 序列到序列模型【Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017】并行使用 h 个不同的注意力层(头)。这 h 个不同层的查询向量(query vectors)由输入向量 x 经过 h 个不同的学习线性投影 Pq 得到。同样,键(keys)和值(values)由 m 个不同输入向量的集合 M 经过 h 个不同的学习线性投影 Pk 和 Pv 得到。这 h 个层的输出本身再通过不同的学习线性投影 Po,然后相加。为简化起见,本文假设输入和输出向量具有相同的维度 d。
def MultiheadAttention(x, M, P_q, P_k, P_v, P_o):
""" Multi-head Attention on one query.
Args:
x: a vector with shape [d]
M: a matrix with shape [m, d]
P_q: a tensor with shape [h, d, k]
P_k: a tensor with shape [h, d, k]
P_v: a tensor with shape [h, d, v]
P_o: a tensor with shape [h, d, v]
Returns:
y: a vector with shape [d]
"""
q = tf.einsum("d,hdk->hk", x, P_q)
K = tf.einsum("md,hdk->hmk", M, P_k)
V = tf.einsum("md,hdv->hmv", M, P_v)
logits = tf.einsum("hk,hmk->hm", q, K)
weights = tf.softmax(logits)
o = tf.einsum("hm,hmv->hv", weights, V)
y = tf.einsum("hv,hdv->d", o, P_o)
return y
关于缩放因子的说明:【Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017】在 logits 上包含了一个恒定的缩放因子。本文在代码中省略了它,因为它可以被整合到线性投影 Pq 或 Pk 中。
两种批处理方式:在实践中,将多个查询批处理在一起效率更高。下面的代码增加了两种批处理方式。首先,我们从一个序列的 n 个不同位置生成查询,这些查询都与相同的键和值进行交互。此外,我们一次处理 b 个不同的非交互序列。遵循【Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017】的做法,在一个自回归模型中,我们可以通过向 logits 添加一个在非法位置值为 -∞ 的“掩码”(mask)来防止信息的反向流动。
def MultiheadAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
""" Multi-head Attention.
Args:
X: a tensor with shape [b, n, d]
M: a tensor with shape [b, m, d]
mask: a tensor with shape [b, h, n, m]
P_q: a tensor with shape [h, d, k]
P_k: a tensor with shape [h, d, k]
P_v: a tensor with shape [h, d, v]
P_o: a tensor with shape [h, d, v]
Returns:
Y: a tensor with shape [b, n, d]
"""
Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
K = tf.einsum("bmd,hdk->bhmk", M, P_k)
V = tf.einsum("bmd,hdv->bhmv", M, P_v)
logits = tf.einsum("bhnk,bhmk->bhnm", Q, K)
weights = tf.softmax(logits + mask)
O = tf.einsum("bhnm,bhmv->bhnv", weights, V)
Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
return Y
性能分析:为了简化性能分析,本文做出以下几点假设:$m = n$;$k = v = d/h$,如【Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017】所建议;$n \le d$。总的算术操作数量为 $\Theta(bnd^2)$(因为在简化假设下,上述每个 tf.einsum 操作的复杂度为 $O(bnd^2)$)。需要访问的内存总量等于所有相关张量大小的总和:$O(bnd + bhn^2 + d^2)$。其中第一项来自 X, M, Q, K, V, O 和 Y,第二项来自 logits 和 weights,第三项来自投影张量 Pq, Pk, Pv 和 Po。将两者相除,我们发现内存访问与算术操作的比率为 $O(1/k + 1/(bn))$。这种低比率对于在现代 GPU/TPU 硬件上获得良好性能是必要的,因为这些硬件的计算能力可能比内存带宽高出两个数量级。
增量计算的必要性:在某些情况下,数据依赖性使得无法并行处理来自多个位置的查询。一个例子是自回归语言模型(如 Transformer)中的自注意力层。在每个位置产生的查询会关注到截至该位置(包括该位置)的所有位置产生的键值对。在训练期间,由于真实的(ground-truth)目标序列是已知的,我们可以使用类似于 2.3 节中的高效并行实现。然而,当从训练好的模型生成序列时,特定位置的自注意力层输出会影响下一位置生成的 token,而这个 token 又会影响下一位置该层的输入。这阻止了并行计算。
def MultiheadSelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
""" Multi-head Self-Attention (one step).
Args:
x: a tensor with shape [b, d]
prev_K: tensor with shape [b, h, m, k]
prev_V: tensor with shape [b, h, m, v]
P_q: a tensor with shape [h, d, k]
P_k: a tensor with shape [h, d, k]
P_v: a tensor with shape [h, d, v]
P_o: a tensor with shape [h, d, v]
Returns:
y: a tensor with shape [b, d]
new_K: tensor with shape [b, h, m+1, k]
new_V: tensor with shape [b, h, m+1, v]
"""
q = tf.einsum("bd,hdk->bhk", x, P_q)
new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum("bd,hdk->bhk", M, P_k), axis=2)], axis=2)
new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd,hdv->bhv", M, P_v), axis=2)], axis=2)
logits = tf.einsum("bhk,bhmk->bhm", q, new_K)
weights = tf.softmax(logits)
o = tf.einsum("bhm,bhmv->bhv", weights, new_V)
y = tf.einsum("bhv,hdv->bd", O, P_o)
return y, new_K, new_V
性能分析与瓶颈:本文采用与 2.3.1 节相同的简化假设。在 n 次调用中,总算术操作数为 $\Theta(bnd^2)$。总内存访问量为 $\Theta(bn^2d + nd^2)$,第一项来自 K 和 V,第二项来自 Pq, Pk, Pv 和 Po。将内存除以计算量,我们发现内存访问与算术操作的比率为 $\Theta(n/d + d/(bn))$。当 $n \approx d$ 或 $b \approx 1$ 时,该比率接近 1,导致内存带宽成为现代计算硬件上的主要性能瓶颈。为了使增量生成高效,我们必须将这两个项都减少到远小于 1。$d/(bn)$ 项比较容易处理,只需在内存允许的情况下使用更大的批量大小即可。减少 $n/d$ 项则更难,该项与每一步重新加载代表内存的 K 和 V 张量(大小为 $bhmk = bn d k/d = bn^2$)的开销有关。一个解决方案是限制序列长度 n。另一个是减少被关注的位置数量,可以通过关注局部邻域,或如 【Peter J Liu, Mohammad Saleh, Etienne Pot, Ben Goodrich, Ryan Sepassi, Lukasz Kaiser, and Noam Shazeer. Generating wikipedia by summarizing long sequences. In Proceedings of the International Conference on Learning Representations, 2018】、【Biao Zhang, Deyi Xiong, and Jinsong Su. Accelerating neural transformer via an average attention network, 2018】、【Daniel Povey, Hossein Hadian, Pegah Ghahremani, Ke Li, and Sanjeev Khudanpur. A time-restricted selfattention layer for ASR. In Proceddings of the IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018】 中那样压缩内存位置的数量。本文提出了一种正交的方法来减小 K 和 V 张量的大小,即移除它们的“头”维度,同时在查询中保留“头”维度。
多查询注意力的核心设计:我们引入多查询注意力作为【Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017】中描述的多头注意力的一种变体。多头注意力由多个并行的注意力层(头)组成,对查询、键、值和输出进行不同的线性变换。多查询注意力与此相同,唯一的区别是不同的头共享同一组键(keys)和值(values)。
代码实现:(增量式)多查询(自)注意力的代码与上面列出的多头注意力的代码相同,只是我们从 tf.einsum 方程中代表 K、V、Pk 或 Pv 的“头”维度的字母 "h" 移除了。
def MultiqueryAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
""" Multi-Query Attention.
Args:
X: a tensor with shape [b, n, d]
M: a tensor with shape [b, m, d]
mask: a tensor with shape [b, h, n, m]
P_q: a tensor with shape [h, d, k]
P_k: a tensor with shape [d, k]
P_v: a tensor with shape [d, v]
P_o: a tensor with shape [h, d, v]
Returns:
Y: a tensor with shape [b, n, d]
"""
Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
K = tf.einsum("bmd,dk->bmk", M, P_k)
V = tf.einsum("bmd,dv->bmv", M, P_v)
logits = tf.einsum("bhnk,bmk->bhnm", Q, K)
weights = tf.softmax(logits + mask)
O = tf.einsum("bhnm,bmv->bhnv", weights, V)
Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
return Y
def MultiquerySelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
""" Multi-query Self-Attention (one step).
Args:
x: a tensor with shape [b, d]
prev_K: tensor with shape [b, m, k]
prev_V: tensor with shape [b, m, v]
P_q: a tensor with shape [h, d, k]
P_k: a tensor with shape [d, k]
P_v: a tensor with shape [d, v]
P_o: a tensor with shape [h, d, v]
Returns:
y: a tensor with shape [b, d]
new_K: tensor with shape [b, m+1, k]
new_V: tensor with shape [b, m+1, v]
"""
q = tf.einsum("bd,hdk->bhk", x, P_q)
K = tf.concat([prev_K, tf.expand_dims(tf.einsum("bd,dk->bk", M, P_k), axis=2)], axis=2)
V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd,dv->bv", M, P_v), axis=2)], axis=2)
logits = tf.einsum("bhk,bmk->bhm", q, K)
weights = tf.softmax(logits)
o = tf.einsum("bhm,bmv->bhv", weights, V)
y = tf.einsum("bhv,hdv->bd", O, P_o)
return y, K, V
性能分析:我们采用与 2.3.1 节相同的简化假设。在 n 次调用中,总算术操作数仍为 $\Theta(bnd^2)$。总内存访问量为 $\Theta(bnd + bn^2k + nd^2)$,第一项来自 x, q, o 和 y,第二项来自 K 和 V,第三项来自 Pq, Pk, Pv, Po。将内存除以计算量,我们发现内存访问与算术操作的比率为 $\Theta(1/d + n/(dh) + d/(bn))$。
性能提升:我们已将有问题的 $n/d$ 项减小了 h 倍。理论上,在给定大批量 b 的情况下,这应能显著提高增量生成的性能。在实验部分,我们将展示性能提升是真实的,并且模型质量保持在高水平。
h 或键/值维度 k 和 v 来减小 K 和 V 大小的模型,并同样扩大前馈隐藏层以匹配参数量。tensor2tensor 和 mesh-tensorflow 库的实现。h、d_k 和 d_v 的替代方案。
本文提出了多查询注意力,作为多头注意力的替代方案,它在增量推理场景下具有低得多的内存带宽需求。我们相信,这将使得基于注意力的序列模型能够在对推理性能要求严苛的应用中得到更广泛的采用。