MHLA: Restoring Expressivity of Linear Attention via Token-Level Multi-Head

作者/机构: Kewei Zhang, Ye Huang, Yufan Deng (北京大学); Jincheng Yu, Junsong Chen, Huan Ling, Enze Xie, Daquan Zhou (NVIDIA)

A1 主要贡献

本文旨在解决 Transformer 架构中自注意力机制的二次方复杂度问题,该问题严重限制了其在大规模应用(如高分辨率图像和视频生成)中的可扩展性。线性注意力作为一种高效的替代方案,虽然将复杂度降至线性,但通常以牺牲模型性能为代价,因为它将所有键(keys)和值(values)压缩成一个全局摘要,从而失去了 softmax 注意力为每个查询(query)单独适应上下文的能力。现有的一些修复方法,如引入深度可分离卷积(depthwise separable convolution),又会重新带来计算开销,违背了初衷。

本文通过深入分析,指出现有线性注意力方法的一个关键失败模式是全局上下文坍塌(global context collapse)。具体来说,当所有 tokens 被压缩成一个单一的全局键值摘要(KV summary)供所有查询共享时,模型的表征多样性会显著降低。作者通过评估注意力权重矩阵的秩(rank)来量化这种多样性,发现共享的全局 KV 摘要将模型的表征能力限制在一个固定的秩上,随着序列变长,注意力权重趋于均匀分布,从而降低了模型在需要关注特定子集任务上的性能。

为了在不牺牲线性时间复杂度和不引入繁重辅助模块的前提下,恢复查询依赖的多样性,本文提出了多头线性注意力(Multi-Head Linear Attention, MHLA)。MHLA 的核心思想是沿 token 维度将序列划分为多个不重叠的块(在空间维度上称为“头”),在每个块内计算局部的键值摘要。然后,每个查询块通过一个查询条件下的混合机制(Multi-Head Mixing),对这些局部摘要进行加权组合,从而检索到一个量身定制的上下文。在此基础上,通过一个查询依赖的重加权模块进一步细化所选块内 token 的贡献。这种设计仅依赖于标准的通用矩阵乘法(GEMM)操作,保持了 O(N) 的线性复杂度,并与流式/有状态执行兼容。

图1 (a) 使用MHLA微调的SANA模型的生成结果。(b) 所提出的MHLA与基线的性能和效率比较。吞吐量在NVIDIA H100 Tensor Core GPU上测试。按照先前的方法,我们在表格中报告了256 × 256分辨率下的FID。(c) MHLA的多领域性能。我们在不同领域评估MHLA,展示其强大和通用的性能。(d) DiT-S/2在4096分辨率下在不同设备上的吞吐量。所有改进仅归功于MHLA,并可与正交技术进一步结合以实现更大的加速。
图1 (a) 使用MHLA微调的SANA模型的生成结果。(b) 所提出的MHLA与基线的性能和效率比较。吞吐量在NVIDIA H100 Tensor Core GPU上测试。按照先前的方法,我们在表格中报告了256 × 256分辨率下的FID。(c) MHLA的多领域性能。我们在不同领域评估MHLA,展示其强大和通用的性能。(d) DiT-S/2在4096分辨率下在不同设备上的吞吐量。所有改进仅归功于MHLA,并可与正交技术进一步结合以实现更大的加速。
图2 所提出的MHLA与其他线性注意力的比较。MHLA在token维度上划分多个头。通过多头混合,MHLA通过以查询特定的权重混合KV摘要来恢复查询条件的自适应性,在保持线性复杂度的同时提高了token级别的多样性。
图2 所提出的MHLA与其他线性注意力的比较。MHLA在token维度上划分多个头。通过多头混合,MHLA通过以查询特定的权重混合KV摘要来恢复查询条件的自适应性,在保持线性复杂度的同时提高了token级别的多样性。

本文的主要贡献总结如下:

A3 背景知识与关键观察

3.1 预备知识

注意力权重的通用计算公式。首先,本文对自注意力和线性注意力机制的权重计算进行公式化表述。给定一个输入 token 序列 $X \in R^{N \times d}$,我们首先通过可学习的投影矩阵 $W_Q, W_K, W_V \in R^{d \times d}$ 计算查询(queries)、键(keys)和值(values),即 $Q = XW_Q, K = XW_K, V = XW_V$。那么,token i 的注意力输出可以表示为:

$$Y_{i}=\frac{\sum_{j=1}^{N}\operatorname{Sim}(Q_{i},K_{j})V_{j}}{\sum_{m=1}^{N}\operatorname{Sim}(Q_{i},K_{m})},$$

其中 Sim(·, ·) 用于计算输入矩阵之间的相似度。在 softmax attention 中(【49,Attention is all you need,NeurIPS 2017】),$Sim(Q_i, K_j) = exp(Q_i K_j^T / \sqrt{d})$,需要计算所有成对的相似度并为每个查询进行归一化,这导致了 $O(N^2)$ 的复杂度。

线性注意力的简化计算。线性注意力将指数核替换为一个正特征图 $\phi(\cdot)$,使得相似度计算可以近似分解。这样,输出 $Y_i$ 的计算可以重排,其分子和分母可以预先计算为一个全局键值摘要 $G = \sum_j \phi(K_j)^T V_j$ 和归一化项 $z = \sum_m \phi(K_m)^T$。这种方式将计算复杂度从 $O(N^2)$ 降低到 $O(Nd_{\phi})$,从而实现了随序列长度线性扩展。

$$\text{Sim}(Q_i, K_j) \approx \phi(Q_i)\phi(K_j)^\top, \quad Y_i = \frac{\phi(Q_i)(\sum_{j=1}^N \phi(K_j)^\top V_j)}{\phi(Q_i)(\sum_{m=1}^N \phi(K_m)^\top)},$$

3.2 全局上下文坍塌

线性注意力的信息瓶颈。线性注意力通过在所有查询中重用一个全局键值摘要 $G = \sum_{j=1}^{N} \phi(K_j)^T V_j \in R^{d \times d}$ 来实现线性时间复杂度。但是,这种固定大小的设计引入了一个内在的信息瓶颈。

观察:全局上下文坍塌。随着序列长度 N 的增加,需要表示的信息超出了固定大小的 $d \times d$ 矩阵的容量,导致性能饱和。我们将这种现象称为全局上下文坍塌

量化指标:秩与稀疏性。这个观察可以用两个互补的指标来量化:注意力矩阵的秩和稀疏性。

秩限制。注意力矩阵的秩被广泛研究为注意力机制中特征多样性和表征能力的关键指标(【3,Low-rank bottleneck in multi-head attention models,2020】;【22,Breaking the low-rank dilemma of linear attention,CVPR 2025】;【24,Flatten transformer: Vision transformer using focused linear attention,ICCV 2023】)。具体来说,设 $\tilde{Q} = \phi(Q)$ 和 $\tilde{K} = \phi(K)$,全局线性注意力的注意力矩阵 $A_{lin}$ 的秩受到严格限制。

$$A_{\operatorname{lin}}=\widetilde{Q} \widetilde{K}^{\top} \in \mathbb{R}^{n \times n}, \quad \operatorname{rank}\left(A_{\operatorname{lin}}\right) \leq \min \{\operatorname{rank}(\widetilde{Q}), \operatorname{rank}(\widetilde{K})\} \leq d.$$

结论1:表征能力的严格上限。无论序列长度 N 为何,线性注意力矩阵 $A_{lin}$ 的表征能力严格受限于维度 d。尽管之前有研究尝试增加键值摘要的秩(【5,Saga: Selective adaptive gating for efficient and expressive linear attention,2025】;【22,Breaking the low-rank dilemma of linear attention,CVPR 2025】),但当 $n \gg d$ 时,这个界限导致了对完整 $n \times n$ 注意力矩阵的严重秩亏近似,从而限制了模型捕捉多样化的、查询条件的注意力模式的能力。

秩限制的经验验证。我们在图 3b 中通过经验验证了这一效应,该图显示了基于线性注意力的模型中,注意力分数的秩始终被头维度(通常 $d_h \leq 72$)所限制,并且随着序列长度的增加,注意力图的相对表达能力会下降。

稀疏性的丧失。注意力矩阵的稀疏性是影响注意力机制性能的关键因素。稀疏分布通常表现出较低的熵,将概率质量集中在少数信息丰富的 token 上(【15,Superiority of softmax: Unveiling the performance edge over linear attention,2023】;【57,Attention entropy is a key factor...,2025】),这有利于模型优化。然而,线性注意力首先将所有键值对压缩成一个单一的全局摘要,每个查询只与这个共享的表示交互一次。相比之下,softmax 注意力利用指数函数使每个查询 $q_i$ 能够产生一个独特的 token 分布(见附录 B)。由于线性注意力对所有查询都依赖于相同的聚合表示,它无法根据查询特定的相关性来重新加权单个键。

结论2:注意力权重趋于均匀。随着序列长度 N 的增加,每个 token 的贡献变得微不足道。因此,注意力权重分布趋于均匀,降低了稀疏性,并削弱了模型选择性地强调信息丰富 token 的能力。

熵的量化效应。为了量化这种效应,我们计算了 500 个随机样本上各注意力变体的平均熵。对于注意力分数矩阵的每一行,较低的熵表示分布更接近于独热向量,反映了更强的单 token 专注度。如图 3a 和图 3b 所示,线性注意力表现出显著更高的熵,证实了其与基于 softmax 的注意力相比缺乏焦点。

现象总结与研究动机。综合来看,这些发现揭示了线性注意力对单一全局键值摘要的依赖导致了表征能力的严重坍塌,表现为注意力图的秩亏和熵升高。我们将此现象称为全局上下文坍塌。图 3a 可视化了注意力分数和图,清晰地展示了线性注意力无法捕捉细粒度信息。这一观察促使我们开发能够在保持线性时间复杂度的同时,恢复查询条件的 token 级多样性的方法。

图3 (a) MHLA及基线的注意力分数和注意力图的可视化。(b) DeiT-T的注意力分数的平均秩和熵,显示MHLA产生更丰富和更集中的注意力。
图3 (a) MHLA及基线的注意力分数和注意力图的可视化。(b) DeiT-T的注意力分数的平均秩和熵,显示MHLA产生更丰富和更集中的注意力。

A2 方法细节

4.1 概述

MHLA的计算流程。本节我们正式介绍所提出的多头线性注意力(MHLA)。如图 4a 所示,MHLA 的操作方式是沿 token 维度将序列分割成多个“头”,并在这些“头”上并行运行线性注意力。设输入序列为 $X \in R^{N \times d}$,通过投影得到查询、键和值:$Q = XW_Q, K = XW_K, V = XW_V$,其中 $Q, K, V \in R^{N \times d}$。为了提高效率,我们采用核化公式,对于选定的特征图 $\phi(\cdot)$,记作 $\tilde{Q} = \phi(Q), \tilde{K} = \phi(K)$。

局部键值摘要的计算。标准的线性注意力将所有 tokens 聚合到一个所有查询共享的全局 $d \times d$ 摘要中,这通过压缩 token 级别的多样性降低了表达能力。为了缓解这个问题,我们将序列分成 M 个不重叠的块(MHLA 的“头”),其中块 b 包含 $N_b$ 个 tokens,且 $\sum_{b=1}^{M} N_b = N$。在视觉模型中,块通常在空间(2D)或时空(3D)网格上定义,而不是通过展平为 1D 来定义。对于每个块 b,我们计算一个局部的键值摘要及其归一化项:

$$S_b = \sum_{j \in b} \widetilde{K}_j V_j^\top \in \mathbb{R}^{d \times d}, \quad z_b = \sum_{j \in b} \widetilde{K}_j \in \mathbb{R}^d.$$

多头混合以恢复查询自适应性。为了恢复查询的自适应性,MHLA 通过多头混合(Multi-Head Mixing)为每个查询块 i 构建一个所有键值摘要的独特混合。块 i 中的查询可以关注这个混合体,其中不同的键值摘要根据当前查询块的注意力偏好进行加权。设 $m_i \in R^M$ 表示块 i 的非负、可学习的混合系数,这些系数在训练期间进行优化。混合后的摘要定义为 $\tilde{S}_i = \sum_{b=1}^{M} m_{i,b} S_b$,相应的归一化项为 $\tilde{z}_i = \sum_{b=1}^{M} m_{i,b} z_b$。

图4 (a) 所提出的多头线性注意力概述。(b) 我们可视化了当M为25时,分别对应于块1和块14的初始化可学习系数矩阵的两行。为了更好地理解,我们将这两行和M维度重塑为2D。
图4 (a) 所提出的多头线性注意力概述。(b) 我们可视化了当M为25时,分别对应于块1和块14的初始化可学习系数矩阵的两行。为了更好地理解,我们将这两行和M维度重塑为2D。

输出计算。这个过程可以通过键值摘要与由 $m_i$ 组成的系数矩阵 $M_c \in R^{M \times M}$ 之间进行高度硬件高效的 GEMM 操作来完成。给定一个来自块 i 的查询向量 $\tilde{q} \in R^d$,输出为:

$$o = \frac{\widetilde{q}^{\top} \widetilde{S}_{i}}{\widetilde{q}^{\top} \widetilde{z}_{i}} = \frac{\sum_{b=1}^{M} m_{i, b} \widetilde{q}^{\top} S_{b}}{\sum_{b=1}^{M} m_{i, b} \widetilde{q}^{\top} z_{b}}.$$

输出的解释。因此,每个输出元素可以被解释为整个值序列的查询特定、块依赖的重组。在语言建模和视频生成等任务中,当序列变长时,为了更好的训练稳定性,可以省略归一化项(【39,The devil in linear transformer,2022】)。

4.2 多头混合

学习系数矩阵的作用。MHLA 自适应性的核心是一个学习到的系数矩阵 $M_c \in R^{M \times M}$。位于 $(i, j)$ 位置的元素表示查询块 i 与块 j 的局部键值摘要之间的亲和度。等价地,Mc 的第 i 行,表示为 $m_i$,指定了查询块 i 如何将 M 个局部摘要线性组合成一个查询特定的全局摘要。

系数矩阵的初始化与学习。每一行 $m_i$ 都是端到端生成和学习的;在实践中,我们强制其非负性和归一化。由于块是沿着空间或时空轴定义的,我们将 $M_c$ 初始化为偏好局部性:对于第 i 行,我们将初始系数设置为 $m_{i,j}^{(0)} \propto 1 - \text{dist}(i, j) / \max_k(\text{dist}(i, k))$,其中 $\text{dist}(i, j)$ 衡量欧几里得距离,$\max_k \text{dist}(i, k)$ 是从 i 到任何位置 k 的最大距离。然后对系数进行归一化,使得 $\sum_j m_{i,j}^{(0)} = 1$。这种初始化的可视化可以在图 4b 中找到。这种偏好局部性的初始化产生了更稳定和更快的收敛,同时让 $M_c$ 在训练期间自由适应。为了进一步确保稳定性,我们在每次更新时将系数裁剪到区间 (0, 1) 内。

多头混合在 token 级别的效果。多头混合在 token 级别的效果是显而易见的。设 $b(t)$ 表示 token t 的块索引。将每个局部摘要写成其 token 的和,$G_j = \sum_{t \in \text{block } j} \tilde{K}_t V_t^T$,查询块 i 的混合摘要展开为:

$$\widetilde{S}_{i}=\sum_{j=1}^{M} m_{i, j} S_{j}=\sum_{t=1}^{N} m_{i, b(t)} \widetilde{K}_{t} V_{t}^{\top} \in \mathbb{R}^{d \times d} .$$

两阶段加权机制。对于一个查询向量 $\tilde{q} = \phi(q)$(来自块 i),核化更新的分子变为:

$$\widetilde{q}^{\top} \widetilde{S}_{i}=\sum_{t=1}^{N} m_{i, b(t)}\left(\widetilde{q}^{\top} \widetilde{K}_{t}\right) V_{t}^{\top} \in \mathbb{R}^{d}.$$

公式 5 使机制变得透明:每个查询块通过 $m_i$ 重新缩放整个块的贡献,并且在每个块内,通常的核内积 $\tilde{q}^T \tilde{K}_t$ 区分了不同的 tokens。因此,MHLA 以一种两阶段的方式(块选择 × 块内重加权)恢复了查询条件的、token 级别的加权。重要的是,所有操作都简化为块级摘要计算和 M 个大小为 $d \times d$ 的矩阵的线性组合,因此渐进复杂度在 N 上保持线性,而表达能力则大幅增加。

MHLA 的分块并行形式。线性注意力通常采用分块并行训练(【29,Transformer quality in linear time,ICML 2022】;【46,Retentive network: A successor to Transformer for large language models,2023】)来在因果掩码下保持线性时间复杂度,方法是将序列划分为块并为每个块更新一个运行摘要。MHLA 自然地适应了这种设置:每个头可以直接映射到一个块,我们为每个块维护一个局部摘要 $S_b$。在训练时,我们使用学习到的混合系数 $m_{i,b}$ 来聚合这些局部摘要,形成混合前缀摘要 $\tilde{S}_i = \sum_{b \le i} m_{i,b} S_b$,然后用于块级注意力。由于混合计算每个块执行一次并被该块中的所有查询重用,因此总体复杂度与分块线性注意力相同。详细的推导和相应的推理过程见附录 C。

4.3 多头线性注意力的分析

秩分析。我们将序列划分为 M 个大小为 $N_b$ 的不重叠块。设查询矩阵为 $\tilde{Q} = [\tilde{Q}_1^T, \dots, \tilde{Q}_M^T]^T$,其中 $\tilde{Q}_b \in R^{n_b \times d}$。根据公式 5,在计算注意力分数时,查询块 i 所见的混合键序列可以表示为:

$$Y_{i}=\left[m_{i, b(1)} k_{1}, m_{i, b(2)} k_{2}, \ldots, m_{i, b(n)} k_{n}\right] \in \mathbb{R}^{d \times n},$$

其中 $m_{i,b(t)}$ 是选择 token t 所在块的混合系数。由查询块 i 贡献的注意力子矩阵是 $A_i = \tilde{Q}_i Y_i \in R^{N_b \times N}$,完整的注意力矩阵是 $A_{MHLA} = [A_1^T, A_2^T, \dots, A_M^T]^T \in R^{n \times n}$。然后应用标准秩不等式可得:

$$\operatorname{rank}\left(A_{b}\right) \leq \min \left\{\operatorname{rank}\left(\widetilde{Q}_{b}\right), \operatorname{rank}\left(Y_{b}\right)\right\} \leq \min \left(n_{b}, d\right),$$

全局秩的上界。这得出了全局上界 $\text{rank}(A_{MHLA}) \le \min(n, \sum_{b=1}^{M} \min(n_b, d))$。在温和、通用的条件下,这个上界是可以达到的:如果每个块的乘积 $\tilde{Q}_b Y_b$ 具有满行秩 $\tilde{r}_b = \min(n_b, d)$,并且 $\{\tilde{Q}_b Y_b\}_{b=1}^{M}$ 的行空间是线性无关的,那么我们得到 $\text{rank}(A_{MHLA}) = \min(n, \sum_{b=1}^{M} r_b)$。即使线性无关的假设不完全满足,块级混合仍然扩展了行空间的多样性,导致 $\text{rank}(A_{MHLA})$ 随着 M 近似加性增长。我们在图 3b 中通过经验验证了这一行为,其中 MHLA 始终比其他线性注意力变体实现显着更高的注意力分数秩——并且不依赖于深度卷积等辅助组件。这证实了 MHLA 内在地恢复了在全局线性注意力中丢失的大部分表征能力,后者的秩无论序列长度 N 如何都严格受限于 d。

稀疏性分析。学习到的系数矩阵 $M_c$ 允许每个查询块为更相关的块子集分配更高的权重,有效地在块级别修剪不相关的 tokens。在每个选定的块内,核内积 $\tilde{q}^T \tilde{K}_t$ 进一步区分了 token 的贡献,导致更尖锐和更集中的注意力分布。我们在图 3b 中通过经验验证了这一效果,其中 MHLA 始终比其他线性注意力基线甚至 softmax 注意力产生更低的注意力熵。这证实了 MHLA 保留了查询条件的自适应性并实现了显着更高的稀疏性,使模型能够关注一个小的、语义相关的 token 子集,而不是均匀地分散注意力。

不同注意力机制的比较。表1总结了自注意力、线性注意力和MHLA在计算复杂度、最大可达秩、内存复杂度和查询条件自适应性方面的比较。

表1 自注意力、线性注意力和MHLA之间的比较。我们报告了计算复杂度、最大可达秩、内存复杂度和查询条件的自适应性。
表1 自注意力、线性注意力和MHLA之间的比较。我们报告了计算复杂度、最大可达秩、内存复杂度和查询条件的自适应性。

效率分析。MHLA 的计算包括局部键值摘要计算、多头混合和输出计算,时间复杂度为 $O(M N_b d^2 + M^2 d^2 + M N_b d^2) = O(N d^2 + M^2 d^2)$。为了在保证效率的同时更好地捕捉局部信息,块的数量 M 通常设置为满足 $M^2 \le N$。因此,$N d^2$ 成为主导项,MHLA 的时间复杂度为 $O(N d^2)$。自注意力、线性注意力和 MHLA 的比较总结在表 1 中。我们还在附录 F.4 中提供了 N 和 M 之间缩放关系的经验分析,验证了所推导的复杂度。

A4 实验环境

A5 实验结果

5.1 图像分类

我们在 DeiT 和 VLT 两种代表性架构上集成了 MHLA,并在 ImageNet-1K 上进行了训练。

表2 图像分类任务的比较。MHLA在DeiT模型上以最小的参数开销实现了最佳准确率,并优于基于Transformer、LA和Mamba的SOTA模型。标有*的结果是在与MHLA-VLT相同的训练设置下复现的。

(a) DeiT上不同注意力的比较。
 (a) DeiT上不同注意力的比较。

(b) 与ImageNet-1K上的SOTA模型比较。
 (b) 与ImageNet-1K上的SOTA模型比较。

5.2 图像生成

(a) 不同模型上注意力类型的比较。

表3 类到图像生成。在所有模型尺寸中,MHLA都取得了最佳性能。值得注意的是,在L和XL规模上,它在不依赖任何额外模块的情况下,达到了与自注意力相当的性能。
表3 类到图像生成。在所有模型尺寸中,MHLA都取得了最佳性能。值得注意的是,在L和XL规模上,它在不依赖任何额外模块的情况下,达到了与自注意力相当的性能。

表4 T2I模型比较。
表4 T2I模型比较。

图5 损失比较。
图5 损失比较。

5.3 视频生成

在处理极长序列的视频生成任务中,二次方注意力变得不可行。实验将预训练的 Wan2.1-1.3B 模型中的 FlashAttention 替换为 MHLA,并与替换为普通线性注意力(LA)的版本进行比较。序列长度达到 31,500 tokens。如表 5 所示,MHLA 的性能远超普通 LA,并与原始的 FlashAttention 模型相当,同时推理速度提升了 2.1 倍。普通 LA 因为全局上下文坍塌问题性能严重下降。图 6 的损失曲线也验证了这一点,普通 LA 的损失很快停滞在较高水平,而 MHLA 则能快速适应并收敛。

表5 MHLA在视频生成中的应用。Wan-FA表示预训练的Wan2.1-1.3B。Wan-MHLA和Wan-LA分别将所有层替换为MHLA和线性注意力。Wan-MHLA-H仅替换了2/3的层。
表5 MHLA在视频生成中的应用。Wan-FA表示预训练的Wan2.1-1.3B。Wan-MHLA和Wan-LA分别将所有层替换为MHLA和线性注意力。Wan-MHLA-H仅替换了2/3的层。

图6 Wan-2.1-1.3B上的损失比较。MHLA显示出更强的收敛能力。
图6 Wan-2.1-1.3B上的损失比较。MHLA显示出更强的收敛能力。

5.4 自然语言处理

表6 NLP中的MHLA。我们报告了在10B token上训练的模型的评估结果。我们突出显示了最好和次好的条目。
表6 NLP中的MHLA。我们报告了在10B token上训练的模型的评估结果。我们突出显示了最好和次好的条目。

表8 MHLA在LongBench上的表现。我们报告了在10B tokens上训练的340M模型的评估结果。我们突出显示了最好和次好的条目。
表8 MHLA在LongBench上的表现。我们报告了在10B tokens上训练的340M模型的评估结果。我们突出显示了最好和次好的条目。

5.5 消融研究

表7 所提出MHLA的消融研究。

(a) DeiT-T上初始化策略的消融实验。LB-init表示基于局部性的初始化。
(a) DeiT-T上初始化策略的消融实验。LB-init表示基于局部性的初始化。

A6 结论

本文提出了一种新颖的线性注意力机制,称为多头线性注意力(MHLA)。通过将 tokens 划分为多个组,MHLA 有效地保留了 token 级的多样性。在不依赖于深度卷积或混合自注意力层等额外模块的情况下,MHLA 实现了与基于自注意力的模型相当甚至更好的性能。作者设想这项工作可以建立一个基础性的注意力机制,从而惠及广泛的下游应用,例如高质量图像生成、长时程视频合成和大规模语言建模。

A7 附录

A 完整的相关工作

Transformer。自 Transformer 架构(【49,Attention is all you need,NeurIPS 2017】)被提出以来,自注意力已成为自然语言处理(【4,Language models are few-shot learners,NeurIPS 2020】;【17,Bert: Pre-training of deep bidirectional transformers for language understanding,NAACL 2019】)、计算机视觉(【18,An image is worth 16x16 words...,ICLR 2021】;【28,Coordinate attention for efficient mobile network design,CVPR 2021】;【33,Swin transformer...,ICCV 2021】;【58,Deepvit: Towards deeper vision transformer,2021】)和生成建模(【19,Taming transformers for high-resolution image synthesis,CVPR 2021】;【42,Photorealistic text-to-image diffusion models...,NeurIPS 2022】)等众多领域的主导机制。自注意力的表达能力源于其对所有 token 间成对交互的建模能力,但这带来了计算和内存上的二次方成本。这一限制在大规模或实时应用中尤为突出,从而推动了对更高效注意力机制的探索。人们提出了多种策略,如稀疏注意力(【2,Longformer...,2020】;【8,Generating long sequences with sparse transformers,NeurIPS 2019】;【56,Big bird...,NeurIPS 2020】)、低秩近似(【51,Linformer...,NeurIPS 2021】;【53,Nyströmformer...,AAAI 2021】)以及硬件优化的变体如 FlashAttention(【11,FlashAttention-2...,ICLR 2024】;【13,Flashattention...,NeurIPS 2022】)。尽管取得了这些进展,设计既能保持可扩展性又能保证准确性的高效注意力机制仍然是一个开放的挑战。

线性注意力。线性注意力已成为解决标准自注意力二次方复杂度的重要方向。早期工作用基于核的特征映射重新表述 softmax 操作,从而在训练和推理中实现线性时间复杂度(【9,Rethinking attention with performers,ICLR 2021】;【30,Transformers are rnns...,ICML 2020】;【36,Rwkv...,2023】;【37,Eagle and finch...,2024】;【54,Gated linear attention transformers...,2024】)。虽然这些方法使 Transformers 能够扩展到长序列,但与完整的 softmax 注意力相比,它们通常会降低表示能力,导致在视觉和生成建模等具有挑战性的任务中准确率下降。为了弥补这一差距,后续研究加入了额外的模块来丰富线性注意力的表达能力。例如,引入卷积层来捕捉局部上下文(【22,Breaking the low-rank dilemma of linear attention,CVPR 2025】;【24,Flatten transformer...,ICCV 2023】;【38,Random feature attention,ICLR 2021】;【44,Efficient attention...,WACV 2021】),提出门控机制来更好地控制信息流。最近,状态空间模型如 Mamba(【12,Transformers are SSMs...,ICML 2024】;【23,Mamba: Linear-time sequence modeling...,2023】)及其变体(【32,Vmamba...,2024】;【45,Multi-scale vmamba...,NeurIPS 2024】)也被探索作为线性注意力的有效替代方案,在长序列上显示出强大的可扩展性和有竞争力的准确性。然而,这些方法仍面临两个基本限制:(1)当以单向形式应用于需要双向注意力的任务时,它们表现出明显的性能下降;(2)当用额外模块(例如,卷积层或额外的自注意力块)增强时,它们不可避免地会产生更高的计算开销,并且仍然容易受到全局上下文坍塌的影响(见第 3.2 节),即全局摘要失去表示多样性。

稀疏注意力。除了线性注意力,稀疏注意力机制是解决 Transformer 计算瓶颈的另一个主要方法。像 Longformer(【2,Longformer: The long-document transformer,2020】)和 BigBird(【56,Big bird: Transformers for longer sequences,NeurIPS 2020】)这样的方法引入了稀疏的注意力模式,其中每个 token 只关注其他 token 的一个子集,从而减少了总的注意力操作数量。这些方法利用结构化稀疏性(例如,局部或全局注意力模式)来保持效率,同时仍然能在长序列中捕捉全局上下文。其他技术,如 Performer(【9,Rethinking attention with performers,ICLR 2021】),提出使用核近似来实现稀疏注意力,同时保持模型的表达能力。尽管稀疏注意力机制提高了可扩展性,但它们通常在准确性方面引入了权衡,尤其是在需要完整 token 交互的任务中。

线性与稀疏注意力的应用。线性和稀疏注意力机制已成功应用于包括自然语言处理、计算机视觉和生成建模在内的多个领域。在自然语言处理中,线性注意力已被用于将 BERT(【16,Bert: Pre-training of deep bidirectional transformers...,2018】)和 GPT(【40,Language models are unsupervised multitask learners,2019】)等模型扩展到更长的序列,从而能够更好地处理长文档并提高语言模型的效率(【4,Language models are few-shot learners,NeurIPS 2020】;【17,Bert: Pre-training of deep bidirectional transformers for language understanding,NAACL 2019】)。在计算机视觉中,线性注意力方法已被应用于视觉 Transformer,以提高处理大图像时的效率,如 Swin Transformer(【33,Swin transformer: Hierarchical vision transformer using shifted windows,ICCV 2021】)和 DeiT(【47,Training data-efficient image transformers & distillation through attention,ICML 2021】)等工作所示。这些应用展示了线性和稀疏注意力机制的广泛实用性,但也凸显了需要继续发展以平衡效率与复杂任务(如图像生成和视频理解)所需的表达能力。

B Softmax 注意力中的查询条件自适应性

Softmax注意力的关键优势。softmax 自注意力的一个关键优势是其查询条件的自适应性(query-conditioned selectivity)。回顾标准的注意力公式:

$$\operatorname{Attn}(Q, K, V)_i = \sum_{j=1}^N \alpha_{ij} v_j, \quad \alpha_{ij} = \frac{\exp(q_i^\top k_j)}{\sum_{t=1}^N \exp(q_i^\top k_t)}.$$

两个属性至关重要:(i)查询条件加权:每个查询 $q_i$ 产生自己的分布 $\{ \alpha_{ij} \}_{j=1}^N$,因此 token $k_j$ 的相对重要性完全依赖于 $q_i$;(ii)逐 token 加权:权重直接作用于每个 $v_j$,而不会将 V 压缩成一个全局摘要。这两个属性共同赋予 softmax 注意力产生高度自适应、 sharply 集中的上下文向量的能力。

全局线性注意力的局限。相比之下,全局线性注意力将所有 token 聚合到一个所有查询共享的单一摘要矩阵 $S_{global} = \sum_{j=1}^N \tilde{K}_j V_j^T$ 中,得到:

$$\text{Attn}_{\text{lin}}(Q, K, V)_i = \frac{\widetilde{q}_i^\top S^{\text{global}}}{\widetilde{q}_i^\top (\sum_{j=1}^N \widetilde{K}_j)},$$

这里,每个 token 的贡献不再能被 i 明确地分离开。结果是,不同的查询获得了几乎相同的上下文向量,失去了查询条件的自适应性。

MHLA如何恢复查询条件自适应性。MHLA 通过引入一个可学习的系数矩阵 $M_c$ 来弥合这一差距,该矩阵形成查询块特定的局部摘要混合:

$$\widetilde{S}_{i}=\sum_{b=1}^{M} m_{i, b} S_{b} \quad \Rightarrow \quad \operatorname{Attn}_{\mathrm{MHLA}}(Q, K, V)_{i}=\widetilde{q}_{i}^{\top} \widetilde{S}_{i} .$$

因为 $m_{i,b}$ 随查询块 i 变化,MHLA 根据查询块的不同,为同一个 token 分配不同的有效权重。将 $S_b$ 展开为其 token 级别的定义,得到:

$$\widetilde{q}_{i}^{\top} \widetilde{S}_{i}=\sum_{t=1}^{N} m_{i, b(t)}\left(\widetilde{q}_{i}^{\top} \widetilde{K}_{t}\right) V_{t}^{\top},$$

两阶段加权机制。这揭示了一个两阶段的加权机制:(i)块级选择 $m_{i,b(t)}$,这是查询条件的;紧接着是(ii)块内 token 重加权,通过核内积 $\tilde{q}_i^T \tilde{K}_t$ 实现。这种设计在保持核化注意力的线性时间复杂度的同时,重新引入了查询条件的自适应性和逐 token 加权。

C MHLA 的自回归建模

因果掩码下的线性注意力挑战。在自回归建模中,因果掩码阻止每个 token 关注未来的 tokens。虽然线性注意力通常通过重用全局键值摘要达到 $O(Nd^2)$ 的复杂度,但在因果掩码下,必须为每个前缀重新计算或更新摘要,这在朴素实现下会导致整个序列的成本为 $O(N^2 d)$。为了避免这种二次方开销,一个被广泛采用的线性注意力解决方案是分块并行训练(【46,Retentive network: A successor to Transformer for large language models,2023】),它将序列分成大小为 C 的块并并行处理它们,以避免重新计算所有过去 token 的注意力的二次方成本。对于块 b,计算一个局部键值摘要 $S_b = \sum_{j \in b} \tilde{K}_j V_j^T \in R^{d \times d}$,并递归更新全局摘要:

$$S_{i}^{\text {global }}=S_{i-1}^{\text {global }}+S_{i}, \quad H_{i}=Q_{i} S_{i-1}^{\text {global }}+(Q_{i} \widetilde{K}_{i}^{\top}) V_{i}.$$

在这里,第一项通过前缀摘要 $S_{global_{i-1}}$ 传播来自前面块的上下文,而第二项捕捉块内注意力。这种分块方案保持了因果性,并允许块并行训练,每个块的复杂度为 $O(Cd^2 + C^2d)$,对于长度为 L 的序列,总成本为 $O(\frac{L}{C}(Cd^2 + C^2d))$。

MHLA与分块并行训练的结合。MHLA 通过用查询条件的局部摘要混合替换单一的全局摘要来扩展此方案。具体来说,对于块 i,我们形成一个混合摘要:

$$\widetilde{S}_{i}=\sum_{b<i} m_{i, b} S_{b}, \quad H_{i}=Q_{i} \widetilde{S}_{i-1}+m_{i, b}(Q_{i} \widetilde{K}_{i}^{\top}) V_{i}.$$ <p>其中 $m_{i,b}$ 是来自因果系数矩阵 $M_{causal_c}$(上三角项被掩码以强制因果性)的可学习混合系数。块 i 中的查询然后仅与 $\tilde{S}_i$ 交互,从而产生块特定的、查询自适应的上下文表示,而不是一个共享的全局表示。因为混合是每个块执行一次并被该块中的所有 token 重用,所以渐进复杂度与分块线性注意力相匹配。

因果推理。在推理时,我们维护过去的局部摘要集合 $\{S_1, \dots, S_{i-1}\}$,并在新 token 到达时增量更新当前块摘要 $S_i$。当一个块完成后,其对未来混合的贡献被固定和缓存。对于块 i 中的一个新 token,我们只需更新 $S_i \leftarrow S_i + K_t V_t^T$ 并通过将 $m_{i,i}$ 应用于增量更新来重新计算该块的混合摘要 $\tilde{S}_i$。这避免了对先前块的重新计算,并保持每个 token 的复杂度为 $O(d^2)$。

D 数据集

多任务验证。为了评估我们方法的有效性,我们在四个任务上进行了广泛的实验:图像分类、类到图像(C2I)生成、文本到图像(T2I)生成和自然语言处理。遵循先前的工作(【21,Rectifying magnitude neglect in linear attention,ICCV 2025】;【22,Breaking the low-rank dilemma of linear attention,CVPR 2025】;【24,Flatten transformer: Vision transformer using focused linear attention,ICCV 2023】),我们在 ImageNet-1K(【14,Imagenet: A large-scale hierarchical image database,CVPR 2009】)上训练分类和 C2I 模型,并在标准验证集上进行评估。对于 T2I 生成,我们使用从互联网上收集的 31,292k 张图片的相对较小集合对预训练模型进行微调。对于自然语言处理,我们使用 SlimPajama(【43,Slimpajama-dc: Understanding data combinations for llm training,2024】)的 5B token 子集来训练模型。

E 额外的实现细节

图像分类。在训练 DeiT 时,我们用平均池化替换了类别 token,并在相同的设置下训练所有基线,以确保公平比较。我们还遵循先前的工作,添加了核大小为 3 的 CPE(【10,Conditional positional encodings for vision transformers,2021】),以进行公平比较。对于 VLT,我们严格遵循(【22,Breaking the low-rank dilemma of linear attention,CVPR 2025】)中的设置。所有模型都训练 300 个 epoch,批大小为 1024,峰值学习率为 1e-3。对于输入大小为 224 的模型,我们将输入大小填充到 256,以便更好地分割头。对于 DeiT 模型,头数 M 设置为 16。对于 VLT 模型,两个线性注意力层的序列长度为 {3136, 784},因此我们分别为这两层设置头数 M 为 {49, 16}。

F 完整的实验结果

F.1 图像生成

DiT 和 DiG 模型的完整结果。我们在表 10 和表 9 中展示了 DiT 和 DiG 模型的完整结果。我们在图 7 中提供了 SANA-MHLA 的更多生成结果。

与其他方法的比较。我们还提供了与其他近期线性注意力方法在图像生成任务上的更全面比较(【50,Lit: Delving into a simple linear diffusion transformer for image generation,2025】),并报告了 MHLA 在三次独立运行中的均值和标准差,以证明我们结果的稳定性。相应的结果总结在表 11 中。

图7 我们微调后的SANA-MHLA模型生成的更多结果。
图7 我们微调后的SANA-MHLA模型生成的更多结果。

表9 DiT-XL/2上MHLA的快速适应结果,有无指导。
表9 DiT-XL/2上MHLA的快速适应结果,有无指导。

表10 不同模型上不同注意力类型的比较。
表10 不同模型上不同注意力类型的比较。

表11 与LiT的比较。我们报告了MHLA三次独立运行的FID分数(均值 ± 标准差),以证明结果的稳定性。
表11 与LiT的比较。我们报告了MHLA三次独立运行的FID分数(均值 ± 标准差),以证明结果的稳定性。

F.2 CPE 和输出门控的消融实验

辅助模块的效果分析。我们对 DiT-S 模型中 CPE 和输出门控与 MHLA 结合的效果进行了详细分析,如表 12 所示。我们的发现表明,在较小的模型中,CPE 和输出门控作为 MHLA 的正交优化,在模型尺寸不足时有效地增强了表达能力。然而,我们在表 3a 中的实验表明,随着模型尺寸的增加,CPE 和输出门控带来的性能增益减小。在 DiTXL 模型中,单独添加 CPE 实际上导致了性能下降。相比之下,无论模型大小如何,MHLA 始终能在表达能力上提供显著的改进。

表12 MHLA与CPE和输出门控的消融研究。
表12 MHLA与CPE和输出门控的消融研究。

F.3 在更高分辨率下的分类结果

高分辨率下的有效性验证。我们进一步在 384×384 和 512×512 分辨率下进行了额外的实验,使用 DeiT-T 模型来验证 MHLA 在高分辨率分类任务上的有效性。结果如表 13 所示。

表13 DeiT-T在有无MHLA情况下的高分辨率分类准确率。
表13 DeiT-T在有无MHLA情况下的高分辨率分类准确率。

F.4 扩展性分析

不同序列长度和头数下的吞吐量。在本节中,我们进行经验性研究,以评估 MHLA 在不同任务中,在不同序列长度 N 和 token 级别头数 M 下的吞吐量。表 14 中的结果表明,当满足 M2 < N 时,MHLA 仅引入可忽略的开销,而较大的 M 会导致更明显的开销。

M的选择。然而,我们在表 7b 中的消融研究已经证明,选择 M 使得 M2 < N 足以实现强大的性能。

表14 MHLA在不同序列长度N和token级头数M下的性能分析结果。左:DiT-S/2。右:DeiT-S/16。
表14 MHLA在不同序列长度N和token级头数M下的性能分析结果。左:DiT-S/2。右:DeiT-S/16。

G 术语和计算概念的澄清

新术语的定义。在本节中,我们为我们方法中使用的术语提供正式定义。这些术语描述了 MHLA 中新颖的计算行为,这些行为在先前的线性注意力公式中没有直接的类似物。

G.1 概念 1:查询条件 (query-conditioned)

动态上下文聚合。短语“查询条件”描述了一种机制,其中上下文信息的聚合是动态的,并且针对每个查询实例是特定的,这与标准线性注意力中的固定递归不同。

具体操作流程。具体来说,该过程操作如下:
* 每个查询 token 都与一个唯一的混合系数向量相关联。
* 这些系数用于为每个查询位置独立地加权和聚合所有局部的 KV 摘要。
因此,适应是按查询发生的,而不是全局的或通过共享的递归规则。

G.2 概念 2:KV 摘要 vs. 隐藏状态

与隐藏状态的区别。我们引入术语 KV 摘要,以严格区分我们的方法与传统线性注意力论文中的隐藏状态。虽然 KV 摘要在符号上可能看似与隐藏状态相似,但底层的计算和依赖图在两个关键方面存在结构性差异:
* 独立计算。与传统线性注意力中 $h_t$ 依赖于 $h_{t-1}$ 的严格递归链不同,MHLA 独立计算每个全局 KV 摘要($S_g$),消除了跨位置的状态传播。
* 多对一聚合。虽然传统状态是通过前一步的一对一更新得出的,但 MHLA 遵循一种多对一的聚合模式,其中每个 $S_g$ 是使用特定的混合系数从所有局部摘要计算出来的。

更高的表达力。通过避免隐藏状态固有的对历史的僵化继承,MHLA 的 KV 摘要实现了更大的表达能力和灵活性。

H LLM 的使用

仅作为写作辅助。我们仅使用大型语言模型(LLM)作为写作辅助,以润色手稿的清晰度和可读性。具体来说,我们使用基于 LLM 的工具来(i)为保持学术风格一致性而优化语法和措辞,(ii)改善章节之间的逻辑流程,以及(iii)精简过于冗长的段落。LLM 没有产生任何新的研究思想、实验设计或结果;所有的科学贡献、方法论开发和实验分析都是由作者构思和执行的。