Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free

文章标题:门控注意力在大型语言模型中的应用:非线性、稀疏性与无注意力沉溺
作者/机构:Zihan Qiu∗1, Zekun Wang∗1, Bo Zheng∗1, Zeyu Huang∗2, Kaiyue Wen3, Songlin Yang4, Rui Men1, Le Yu1, Fei Huang1, Suozhi Huang5, Dayiheng LiuB1, Jingren Zhou1, Junyang LinB1
1 Qwen Team, Alibaba Group 2 University of Edinburgh 3 Stanford University 4 MIT 5 Tsinghua University

A1 主要贡献

本文系统性地研究了在标准 Softmax 注意力机制中引入门控(gating)机制的影响。

图 1:左:研究中应用门控操作的位置。中:在不同位置应用门控的 15B MoE 模型的性能比较(测试 PPL 和 MMLU)。在 SDPA 之后(G1)应用门控取得了最佳的综合效果。在 Value 层之后(G2)应用门控也显示出显著的改进,尤其是在 PPL 方面。右:在相同超参数下,基线模型和 SDPA 门控的 1.7B 密集模型在 3.5T token 上的训练损失比较(平滑系数 0.9)。门控带来了更低的最终损失和显著增强的训练稳定性,减少了损失尖峰。这种稳定性使得模型可能使用更高的学习率,并有助于更好的扩展。
图 1:左:研究中应用门控操作的位置。中:在不同位置应用门控的 15B MoE 模型的性能比较(测试 PPL 和 MMLU)。在 SDPA 之后(G1)应用门控取得了最佳的综合效果。在 Value 层之后(G2)应用门控也显示出显著的改进,尤其是在 PPL 方面。右:在相同超参数下,基线模型和 SDPA 门控的 1.7B 密集模型在 3.5T token 上的训练损失比较(平滑系数 0.9)。门控带来了更低的最终损失和显著增强的训练稳定性,减少了损失尖峰。这种稳定性使得模型可能使用更高的学习率,并有助于更好的扩展。

A3 背景知识

多头 Softmax 注意力机制回顾

给定输入 $X \in \mathbb{R}^{n \times d_{model}}$,其中 $n$ 是序列长度,$d_{model}$ 是模型维度,Transformer 注意力层的计算【索引 47,Attention is all you need,A Vaswani,2017,Advances in Neural Information Processing Systems】可分为四个阶段。

A2 方法细节

使用门控机制增强注意力层

表 1:门控变体性能与结果。我们在 400B token 上训练了 15A2B MoE 模型。dk 是头维度,dmodel 是模型的隐藏维度,n 是 token 数量。q 指的是查询头的数量,k 指的是键值头的数量。‘Act Func’ 是公式 5 中的激活函数。‘Score Shape’ 是输入 X ∈ Rn,dmodel 的门控分数形状。‘added param’ 表示增加的参数(百万)。
表 1:门控变体性能与结果。我们在 400B token 上训练了 15A2B MoE 模型。dk 是头维度,dmodel 是模型的隐藏维度,n 是 token 数量。q 指的是查询头的数量,k 指的是键值头的数量。‘Act Func’ 是公式 5 中的激活函数。‘Score Shape’ 是输入 X ∈ Rn,dmodel 的门控分数形状。‘added param’ 表示增加的参数(百万)。

A4 实验环境

A4 实验结果

主要结果

门控注意力在 MoE 模型中的表现

门控注意力在密集模型中的表现

表 2:不同方法在不同学习率、批大小和模型配置下的性能表现。'SDPA' 指的是在公式 3 的 SDPA 之后应用 sigmoid 门控,'sandwitch norm'【索引 16】表示在将 attention/ffn 输出添加到残差连接之前对其进行归一化。使用门控时,我们减小了 FFN 的宽度,以使所有方法的参数数量相同。'-' 表示模型在训练过程中发散。
表 2:不同方法在不同学习率、批大小和模型配置下的性能表现。'SDPA' 指的是在公式 3 的 SDPA 之后应用 sigmoid 门控,'sandwitch norm'【索引 16】表示在将 attention/ffn 输出添加到残差连接之前对其进行归一化。使用门控时,我们减小了 FFN 的宽度,以使所有方法的参数数量相同。'-' 表示模型在训练过程中发散。

分析:非线性、稀疏性与无注意力沉溺

非线性提升了注意力中低秩映射的表达能力

表 3:不同(非)线性增强方法的性能。
表 3:不同(非)线性增强方法的性能。

门控引入了输入依赖的稀疏性

图 3:SDPA 逐元素(左)、Value 逐元素(中)以及 SDPA 逐元素头共享门控(右)的门控分数均值和分布。大多数门控分数小于 0.5,表明门控分数是稀疏的。其中,SDPA 输出门控分数表现出最强的稀疏性。
图 3:SDPA 逐元素(左)、Value 逐元素(中)以及 SDPA 逐元素头共享门控(右)的门控分数均值和分布。大多数门控分数小于 0.5,表明门控分数是稀疏的。其中,SDPA 输出门控分数表现出最强的稀疏性。

SDPA 输出门控减少了注意力沉溺

SDPA 输出门控促进了上下文长度扩展

表 5:不同方法在不同序列长度下的性能表现。‘YaRN Extended’ 表示扩展上下文长度的变体。‘(values)’ 表示扩展上下文长度后性能下降的值。
表 5:不同方法在不同序列长度下的性能表现。‘YaRN Extended’ 表示扩展上下文长度的变体。‘(values)’ 表示扩展上下文长度后性能下降的值。

A7 补充细节

相关工作

A5 结论

局限性

A6 附录

A.1 Switch Head 基线实验

表 6:不同 switch head 方法在不同参数增加和配置下的性能。‘switch kv’ 和 ‘switch v’ 分别指在键值和值组件中引入选择性计算。‘Switch kv, 8top8’ 意味着有 8 个键和值映射专家,每个 token 选择 top8 专家。注意‘Switch v, 1top1’ 等同于表 1 行 (11) 中的 v Headwise Gate。
表 6:不同 switch head 方法在不同参数增加和配置下的性能。‘switch kv’ 和 ‘switch v’ 分别指在键值和值组件中引入选择性计算。‘Switch kv, 8top8’ 意味着有 8 个键和值映射专家,每个 token 选择 top8 专家。注意‘Switch v, 1top1’ 等同于表 1 行 (11) 中的 v Headwise Gate。

A.2 关于稀疏门控分数的更多讨论

图 4:门控前后的平均绝对值。基线和门控后的值相似。
图 4:门控前后的平均绝对值。基线和门控后的值相似。

图 5:门控后低于阈值的 SDPA 输出值比例(左:1e-2,右:1e-3)。我们还包括了通过将平均门控分数与门控前隐藏状态相乘得到的稀疏性度量。
图 5:门控后低于阈值的 SDPA 输出值比例(左:1e-2,右:1e-3)。我们还包括了通过将平均门控分数与门控前隐藏状态相乘得到的稀疏性度量。

A.3 逐层的巨幅激活和注意力沉溺

图 6:不同门控配置下巨幅激活和注意力沉溺现象的比较。第 1 行(基线):第 6 层后出现显著的巨幅激活和注意力沉溺。第 2 行(SDPA 门控):激活减少,未观察到注意力沉溺。第 3 行(Value 层门控):激活与第 2 行相似,但存在残余的注意力沉溺。第 4-5 行(通过跨头共享和 NS-sigmoid 减少稀疏性):巨幅激活和注意力沉溺与基线相似。
图 6:不同门控配置下巨幅激活和注意力沉溺现象的比较。第 1 行(基线):第 6 层后出现显著的巨幅激活和注意力沉溺。第 2 行(SDPA 门控):激活减少,未观察到注意力沉溺。第 3 行(Value 层门控):激活与第 2 行相似,但存在残余的注意力沉溺。第 4-5 行(通过跨头共享和 NS-sigmoid 减少稀疏性):巨幅激活和注意力沉溺与基线相似。

A.4 更多逐层门控分数分析

图 7:SDPA 输出门控变体在不同约束下门控分数的分布。
图 7:SDPA 输出门控变体在不同约束下门控分数的分布。

A.5 稳定训练的其他尝试