XAttention: Block Sparse Attention with Antidiagonal Scoring
XAttention: Block Sparse Attention with Antidiagonal Scoring
作者/机构: Ruyi Xu * 1, Guangxuan Xiao * 2, Haofeng Huang 1, Junxian Guo 3, Song Han 2 4
GitHub: https://github.com/mit-han-lab/x-attention
A1 主要贡献
核心问题: 长上下文Transformer模型(LCTMs)在处理视频理解、视频生成等需要长序列信息的任务中至关重要,但其注意力机制的二次方复杂度导致计算成本高昂,特别是在预填充(pre-filling)阶段形成了性能瓶颈,阻碍了其在现实世界中的应用。
研究目标: 旨在设计一种能够显著加速长上下文Transformer模型推理,同时不牺牲模型准确性的块稀疏注意力机制,从而真正释放其在实际应用中的潜力。
现有方法的不足: 块稀疏注意力通过将计算集中在关键区域来降低成本,但现有方法在平衡准确性和效率方面存在困难。其主要问题在于,用于识别和确定重要注意力块的机制要么计算开销过大,抵消了稀疏化带来的收益,要么不够精确,导致模型性能下降。
核心创新点 (XAttention):
本文提出了XAttention,一个即插即用的框架,其核心创新在于发现注意力矩阵中反向对角线(从左下到右上)值的总和可以作为块重要性的一个强大且计算高效的代理指标。与依赖于计算密集型和信息有损的解决方案(如token池化)来识别重要块的现有方法不同,XAttention利用这种简单的评分机制,提供了一种更直接、更快速、更准确的方法来识别关键的注意力块。
实现方式与效果:
通过这种反向对角线评分算法,XAttention能够积极地发现并剪枝非必要的计算,实现高稀疏度而不牺牲准确性。该方法通过三个步骤优化注意力计算:
1. 带步长的反向对角线评分 (Strided Antidiagonal Scoring):对每个块(例如8x8)沿着其带步长的反向对角线(例如步长为4)求和来评分。
2. 块选择 (Block Selection):根据评分选择得分高的块。
3. 块稀疏注意力 (Block Sparse Attention):仅在被选中的块上计算注意力,从而实现显著的计算节省。
实验验证:
在包括RULER和LongBench(语言)、VideoMME(视频理解)和VBench(视频生成)在内的多个具有挑战性的长上下文基准上进行了全面评估。结果表明,XAttention在实现与全注意力相当的准确性的同时,带来了显著的计算增益,在注意力计算方面实现了高达13.5倍的加速。这些结果突显了XAttention解锁块稀疏注意力实用潜力的能力,为LCTM在现实世界应用中的可扩展和高效部署铺平了道路。
A2 方法细节
本节介绍我们的方法XAttention。XAttention算法包含三个主要部分:(1) 注意力图块的重要性预测,(2) 重要注意力块的选择,以及 (3) 注意力头最小阈值的预测。
2.1. 重要性预测
现有池化方法的局限性。注意力图的内在稀疏性要求一种强大的策略来预测注意力块的重要性。虽然像MInference【17,Minference 1.0: Accelerating pre-filling for longcontext llms via dynamic sparse attention,2024】和FlexPrefill【1,Flexprefill: A context-aware sparse attention mechanism for efficient long-sequence inference,2025】等方法利用池化和“垂直斜线检测”的组合,但我们的消融研究表明,仅依赖平均或总和池化会产生不准确的预测。当一个块内只存在少数显著的垂直或斜线模式时,池化方法尤其无效,无法捕捉到这些关键的重要性指标。
现有搜索方法的挑战。MInference和FlexPrefill试图通过分析输入查询的最后一部分来识别重要的“垂直和斜线索引”以克服这一限制。然而,这种方法面临两个主要挑战:首先,重要的注意力模式可能不会持续存在于最后的查询段中;其次,搜索算法本身引入了巨大的计算开销(如图5所示)。
提出的反向对角线选择法。从根本上说,一个有效的块重要性预测方法应该能够自动且稳健地识别重要模式,包括关键的垂直和斜线模式。为实现这一目标,我们提出了反向对角线选择法。在每个大小为B的块内,我们使用步长S沿反向对角线选择元素(如图1所示)。这些选定元素的总和作为相应注意力块整体重要性的代理。
方法的有效性分析。该方法的有效性可以从两个角度来理解:(1) 信息保持:这种选择策略确保了所有令牌的信息都被考虑在内,因为每个令牌都至少对一个反向对角线总和有贡献。(2) 模式检测:如图2所示,反向对角线与块内的每一种可能的垂直和斜线模式相交。XAttention的反向对角线模式与块内的垂直和斜线模式均相交,从而能够有效检测这些模式并指导有效的稀疏注意力计算。这确保了在重要性估计过程中不会错过任何关键模式。
2.2. 阈值块选择
稀疏注意力块选择算法。基于反向对角线评分模式,我们提出了以下稀疏注意力块选择算法。设S为步长,B为稀疏注意力块的大小。该过程始于反向对角线求和,即我们在注意力图的每个S×S块内沿反向对角线选择元素,并计算每个反向对角线上这些元素的总和。随后,我们通过对这些反向对角线总和应用softmax函数来进行softmax归一化,从而得到反向对角线上的概率分布。最后,对于块选择,我们使用find_blocks函数来识别一个最小的块集合,这些块的反向对角线概率累积和超过预定义的阈值τ。形式上,这可以表示为:
其中A是注意力图,B是一个块的集合,|B|代表集合中块的数量。这个过程根据反向对角线评分模式和指定的阈值,有效地确定了注意力图中最重要的一些块。
算法1 块选择
需要:查询矩阵 Q ∈ R^(L×d),键矩阵 K ∈ R^(L×d),块大小 B,步长 S,头维度 dh,阈值 τ
确保:稀疏掩码 M
1: NB ← ⌊L/B⌋ {块的数量}
2: for b = 0 to NB - 1 do
3: Q_slice ← Q[bB : (b + 1)B, :] {提取 Q 块}
4: Q_reshaped ← []
5: for i = S - 1 down to 0 do
6: Q_reshaped.append(Q_slice[i :: S, :]) {以步长 S 沿反向对角线重塑}
7: end for
8: K_reshaped ← []
9: for i = 0 to S - 1 do
10: K_reshaped.append(K[i :: S, :]) {以步长 S 沿反向对角线重塑}
11: end for
12: A_approx ← Softmax(Q_reshaped * K_reshaped^T / √(dh·S)) {近似注意力分数}
13: M_b ← find_blocks(A_approx, τ) {根据阈值找到块}
14: end for
15: M ← concatenate(M_0, M_1, . . . , M_(NB-1)) {连接块掩码}
2.3. 最小阈值预测
动态规划方法。我们提出一种动态规划方法来确定每个注意力头的最优阈值。先前的研究表明,不同的注意力头表现出不同的稀疏级别和重要性。因此,为单个头动态调整阈值以优化准确性与计算效率之间的平衡是有益的。
问题形式化。考虑一个有H个注意力头的模型。我们定义一个动态规划表$D[h][m]$,其中$h \in \{1, 2, . . . , H\}$表示第h个头,而$m \in \{1, 2, . . . , M \}$表示已进行的阈值调整次数。$D[h][m]$存储了在前h个头上精确进行m次阈值调整时可达到的最佳性能。
动态规划。我们的目标是为每个头找到最优阈值,使得它们的共同贡献在最小化计算的同时最大化准确性。DP表的递推关系为:
其中$P(h, m)$表示当第h个头的阈值第m次调整时的模型性能。这对应于在优化过程中,相对于状态$D[h-1][m-1]$,将第h个头的阈值降低一步后的模型性能。
阈值调整策略。我们在每一步将每个头的阈值降低10%:
$$t_{h}(m)=t_{h}(m-1) \times 0.9$$这确保了在保持每个头对准确性的贡献的同时,逐步减少计算量。
方法的可选性。需要注意的是,这种动态阈值预测方法可以进一步优化XAttention的稀疏性,但并非强制性组件。我们将在消融研究中展示详细结果。
A4 实验环境与结果
实验环境
-
模型:
- 自然语言任务: Llama-3.1-8B-Instruct【10,The llama 3 herd of models,2024】。对该模型应用了精确阈值预测方法。
- 视频理解: Qwen2-VL-7B-Instruct【32,Qwen2-vl: Enhancing vision-language model’s perception of the world at any resolution,2024】。
- 视频生成: HunyuanVideo【18,Hunyuanvideo: A systematic framework for large video generative models,2025】。
-
基线模型:
- 密集注意力: FlashAttention【8,FlashAttention-2: Faster attention with better parallelism and work partitioning,2023】,通过FlashInfer【40,Cascade inference: Memory bandwidth efficient shared prefix batch decoding,2024】框架实现。
- 稀疏注意力:
- MInference【17,Minference 1.0: Accelerating pre-filling for longcontext llms via dynamic sparse attention,2024】(使用官方配置,所有注意力头采用“Vertical-Slash”稀疏模式)。
- FlexPrefill【1,Flexprefill: A context-aware sparse attention mechanism for efficient long-sequence inference,2025】(超参数设置为$\gamma = 0.95$和$\tau = 0.1$)。
- SeerAttention【12,Seerattention: Learning intrinsic sparse attention in your llms,2024】(在Gare权重上进行了预训练)。
-
数据集:
-
自然语言:
- RULER【15,Ruler: What’s the real context size of your long-context language models?,2024】:一个专为评估LLM长上下文能力的合成基准。
- LongBench【2,Longbench: A bilingual, multitask benchmark for long context understanding,2023】:包含真实世界的长上下文任务。
-
视频理解: Video-MME【11,Video-mme: The first-ever comprehensive evaluation benchmark of multi-modal llms in video analysis,2024】,包含900个视频,总时长254小时。
- 视频生成: VBench【16,VBench: Comprehensive benchmark suite for video generative models,2024】中的946个GPT增强的文本提示。
-
-
硬件与软件: 论文未明确提供GPU型号、数量、网络配置等硬件信息,也未提供具体的软件库版本。致谢部分提到NVIDIA捐赠了DGX服务器。
实验结果
准确性结果
-
RULER基准测试:
- 实验内容: 在Llama-3.1-8B-Instruct上,使用动态规划方法进行最小阈值预测(步长S=8和S=16,最大调整次数M=1000),得到平均阈值为0.8。
- 实验结果: 如表1所示,随着上下文长度增加,MInference和SeerAttention性能显著下降。相比之下,XAttention不仅超过了FlexPrefill,在多个序列长度上甚至优于全注意力,显示了其处理超长上下文的鲁棒性。
表1. 不同方法在Llama-3.1-8B-Instruct和RULER上不同序列长度的准确性比较。XAttention配置为步长S=8和S=16,并使用精确预测的最小阈值。 -
LongBench基准测试:
- 实验内容: 在Llama-3.1-8B-Instruct上,使用与RULER相同的配置,在LongBench的真实世界任务上进行评估。
- 实验结果: 如表2所示,XAttention在所有任务中取得了最高的平均分,证明了其在实际场景中的有效性。其在单个任务上的性能与全注意力相当,表明该方法在提高效率的同时保持了准确性。
表2. 不同注意力方法在Llama-3.1-8B-Instruct模型上于真实世界LongBench任务的比较。XAttention配置为步长8和精确预测的最小阈值,取得了所有基线中最好的平均分。 -
视频理解 (Video-MME):
- 实验内容: 在QwenVL-2-7B模型上应用XAttention(步长S=16,阈值$\tau=0.9$)。
- 实验结果: 如表3所示,MInference和FlexPrefill在长视频任务上表现不佳。XAttention在所有稀疏注意力方法中取得了最好的平均分,甚至在长视频上超过了FlashAttention。
表3. 不同方法在Video-MME视频理解任务中于QwenVL-2-7B上的比较。XAttention配置为步长S=16,阈值τ=0.9。XAttention在长视频任务上优于全注意力,并在所有稀疏注意力方法中取得了最佳平均性能。 -
视频生成 (VBench):
- 实验内容: 使用HunyuanVideo模型(DiT架构,非因果注意力)进行视频生成。由于基线不支持非因果注意力,仅与全注意力进行比较。XAttention配置为S=8,$\tau=0.9$和$\tau=0.95$。
- 关键发现与解决方案: 直接应用XAttention会导致生成视频的布局轻微偏移。受扩散模型研究启发【38,Fastcomposer: Tuning-free multi-subject image generation with localized attention,2023c;21,Distrifusion: Distributed parallel inference for high-resolution diffusion models,2024】,引入了“预热”阶段:前5个去噪步骤使用全注意力,之后切换到XAttention。
- 实验结果: 如表4和图3所示,加入预热后,XAttention生成的视频与全注意力基线在视觉上高度保真。定量指标(PSNR高达23.5, SSIM高达0.822, LPIPS低至0.155)也证实了这一点。两种$\tau$设置都实现了超过50%的稀疏度。
表4. 在HunyuanVideo模型上应用XAttention于VBench基准测试的定量结果,使用了5步全注意力预热。较高的(τ)以略微降低的稀疏度(较高的密度)为代价,产生更好的保真度(更高的PSNR,更高的SSIM,更低的LPIPS)。两种(τ)设置都展示了与全注意力基线的高度相似性。 图3. 使用VBench数据集中第一个提示的视频生成结果的定性比较。各行显示了使用以下方法生成的视频帧:(1) 全注意力(基线),(2) XAttention无预热(τ=0.95),(3) XAttention有5步预热(τ=0.9),以及(4) XAttention有5步预热(τ=0.95)。带预热的XAttention实现了与全注意力基线高度的视觉保真度。
效率结果
-
注意力加速:
- 实验内容: 测量在8k到256k token序列长度范围内的预填充(prefill)阶段的注意力加速比。
- 实验结果: 如图4所示,XAttention在各种上下文长度上均表现出优异的加速效果。在256k上下文长度时,XAttention实现了最高13.5倍的预填充注意力加速(对应的密度见表5)。
图4. 不同注意力方法在不同上下文长度下的加速比比较,相对于FlashInfer实现的FlashAttention。XAttention持续优于其他稀疏注意力方法,在256K token时达到13.5倍的加速。 表5. 不同上下文长度下的密度。步长S=8实现较低的稀疏度,并且随着上下文长度的增加,稀疏度通常会增加(密度降低)。 -
注意力时间分解:
- 实验内容: 将计算时间分解为模式选择和稀疏注意力计算两个部分。
- 实验结果: 如图5所示,XAttention的反向对角线模式及其高效的块选择算法,使得其模式选择时间显著快于依赖垂直斜线索引搜索的MInference和FlexPrefill(分别快24.9倍和5.9倍)。同时,更高的模式选择准确性带来了更低的注意力密度,进一步加速了稀疏注意力计算本身。
图5. 预填充注意力时间分解。Xattention显著减少了模式选择时间,同时保持了密度,与现有方法相比实现了大幅加速。
消融研究
- 反向对角线模式的有效性: 与随机模式和对角线模式相比,反向对角线模式在所有任务中均以最低的密度实现了最高的准确性(见表6)。
- 步长大小的影响: 比较了步长S=4, 16, 64。结果表明,过大的步长(如64)无法准确检测到斜线注意力模式,导致性能下降(见表7)。
- 块选择策略: 比较了Top-K、Top-Ratio和本文的阈值块选择(动态稀疏度)。结果显示,基于阈值的方法在不同输入序列长度下,在计算和准确性之间取得了最佳平衡(见表8)。
- 最小阈值预测: 将动态预测的阈值(平均0.8)与固定的$\tau=0.9$进行比较。动态方法在RULER基准上实现了更低的密度和更高的准确性,证明了其有效性(见表9)。
A7 补充细节 (相关工作)
4.1. 长上下文大语言模型
扩展上下文能力的方法。工程和算法的进步扩展了大语言模型(LLM)的上下文长度能力。主要有两种方法:(1)编译大型长文本数据集进行持续预训练或微调【30,Yarn: Efficient context window extension of large language models,2023;5,Extending context window of large language models via positional interpolation,2023】;(2)利用外部记忆或检索增强技术来增强长程上下文处理能力【4,Memory transformer,2021;35,Infllm: Training-free longcontext extrapolation for llms with an efficient context memory,2024a;33,Retrieval head mechanistically explains long-context factuality,2024】。这些进步使得LLM能够处理需要对扩展序列进行推理的日益复杂的任务。
4.2. 稀疏注意力
稀疏注意力的内在特性与挑战。LLM核心的注意力机制表现出固有的稀疏性,意味着许多注意力权重可以被忽略不计而不会显著降低性能【6,Generating long sequences with sparse transformers,2019a】。随着上下文长度的增加,这种稀疏性变得更加明显,为优化推理速度提供了机会。然而,这种稀疏性的动态和输入依赖性(它在不同输入、注意力头甚至层之间变化)对有效利用构成了重大挑战。
现有稀疏注意力方法。像Sparse Transformer【7,Generating long sequences with sparse transformers,2019b】、LongFormer【3,Longformer: The long-document transformer,2020】、BigBird【42,Big bird: Transformers for longer sequences,2020】和Selective Attention【20,Selective attention improves transformer,2024】等方法通过局部或基于块的注意力来降低复杂性,但通常需要重新训练,限制了实用性。H2O【44,H2o: Heavy-hitter oracle for efficient generative inference of large language models,2023】和TOVA【28,Transformers are multi-state rnns,2024】根据查询模式丢弃令牌。StreamingLLM【37,Efficient streaming language models with attention sinks,2023b】保留初始和最近的令牌以实现一致的延迟和内存使用,从而能够处理比预训练长度更长的序列。基于检索头的方法【33,Retrieval head mechanistically explains long-context factuality,2024;39,Duoattention: Efficient long-context llm inference with retrieval and streaming heads,2024b】通过将计算集中在关键的检索头上,加速模型解码。
针对预填充阶段的加速方法。为了加速预填充阶段,最近的方法采用了稀疏注意力模式。MInference【17,Minference 1.0: Accelerating pre-filling for longcontext llms via dynamic sparse attention,2024】和FlexPrefill【1,Flexprefill: A context-aware sparse attention mechanism for efficient long-sequence inference,2025】都利用模式选择算法在预填充期间实现显著加速。然而,这些选择算法的开销仍然是一个瓶颈。SeerAttention【12,Seerattention: Learning intrinsic sparse attention in your llms,2024】通过预训练和微调门控参数实现高稀疏度,提高了效率同时保持了低困惑度。然而,它需要昂贵的训练过程,并且在下游任务上表现有限。因此,需要一种无训练且选择算法开销极小的方法,以解决与不断增长的上下文长度相关的日益增长的预填充时间。
4.3. LLM推理加速
系统级优化。已开发出多种技术来加速LLM推理。系统级解决方案侧重于优化原始的注意力计算,以更好地利用硬件特性。著名的例子包括FlashAttention【9,FlashAttention: Fast and memory-efficient exact attention with IO-awareness,2022;8,FlashAttention-2: Faster attention with better parallelism and work partitioning,2023】,它优化了内存访问模式以实现更快的注意力计算,以及RingAttention【25,Ring attention with blockwise transformers for near-infinite context,2023】,它将注意力计算分布到多个设备上。其他系统级方法包括FlashDecoding【14,Flashdecoding++: Faster large language model inference on gpus,2024】和PagedAttention【19,Efficient memory management for large language model serving with pagedattention,2023】,它们分别专注于优化计算过程和KV缓存管理。
模型压缩技术。模型压缩技术,如量化,也被广泛用于减小模型大小和内存占用,从而实现更快的推理。例子包括SmoothQuant【36,SmoothQuant: Accurate and efficient post-training quantization for large language models,2023a】、AWQ【23,Awq: Activation-aware weight quantization for llm compression and acceleration,2024】和QServe【24,Qserve: W4a8kv4 quantization and system co-design for efficient llm serving,2024】,它们将模型权重和/或激活量化到更低的位宽,从而减少内存带宽需求并加速计算。
4.4. 近期工作
稀疏注意力领域的最新进展。最近,一些杰出的工作致力于推进稀疏注意力。Sparse VideoGen【34,Sparse videogen: Accelerating video diffusion transformers with spatial-temporal sparsity,2025】通过利用空间和时间头在保持生成质量的同时加速视频生成模型。NSA【41,Native sparse attention: Hardware-aligned and natively trainable sparse attention,2025】引入了一种原生可训练的稀疏注意力机制,用于高效的长上下文建模。MoBA【26,Moba: Mixture of block attention for long-context llms,2025】通过采用混合专家方法,解决了传统注意力机制的二次复杂度问题,而不依赖于强偏置结构(如sink或窗口注意力)。Fast Video Generation【43,Fast video generation with sliding tile attention,2025】通过滑动瓦片注意力减少了计算需求,该方法采用局部化的时空窗口代替全注意力计算。我们的工作与这些努力一致,旨在通过降低计算成本和实现高效部署来普及人工智能。
A5 结论
本文提出了XAttention,一个新颖的即插即用框架,用于加速Transformer模型中的长上下文推理。通过利用注意力矩阵中反向对角线总和可作为块重要性的稳健代理这一洞见,XAttention能有效识别并剪枝非必要块,在不牺牲准确性的前提下实现了显著的计算节省。我们在自然语言理解(RULER, LongBench)、视频理解(VideoMME)和视频生成(VBench)等具有挑战性的长上下文基准上的评估表明,XAttention在保持与全注意力相当性能的同时,注意力计算速度提升高达13.5倍。这些结果凸显了XAttention解锁块稀疏注意力实用潜力的能力,为长上下文Transformer模型在真实世界应用中的高效和可扩展部署铺平了道路。
💬 评论讨论
欢迎在这里分享您的想法和见解!