BLASST: Dynamic Blocked Attention Sparsity via Softmax Thresholding

作者/机构: Jiayi Yuan, Cameron Shinn, Kai Xu, Jingze Cui, George Klimiashvili, Guangxuan Xiao, Perkz Zheng, Bo Li, Yuxin Zhou, Zhouhai Ye, Weijie You, Tian Zheng, Dominic Brown, Pengbo Wang, Richard Cai, Julien Demouth, John D. Owens, Xia Hu, Song Han, Timmy Liu, Huizi Mao

A1 主要贡献

核心问题: 随着大语言模型(LLM)对长上下文处理能力的需求日益增长,标准注意力机制的计算和内存瓶颈问题变得愈发严重。注意力机制的计算复杂度和内存访问量均与序列长度n成二次方关系($O(n^2)$),这使得即使在最先进的硬件上部署长上下文模型也极具挑战性。尽管FlashAttention等技术通过分块和内核融合优化了内存带宽利用率,但它们仍需计算完整的注意力矩阵,并未解决根本的二次复杂度问题。

研究目标: 本文旨在提出一种简单而有效的稀疏注意力方法,以动态地修剪(prune)注意力矩阵中贡献可忽略的部分,从而加速长上下文推理,同时避免现有稀疏方法中存在的预计算开销大、依赖不准确的代理分数或模式僵化等问题。

创新点 (BLASST): 本文提出了BLASST(BLocked Attention Sparsity via Softmax Thresholding),一种无需任何预计算开销的插入式稀疏注意力方法。
1. 动态剪枝机制: BLASST的核心思想是在FlashAttention的逐块在线softmax计算过程中,利用已经计算出的信息来动态识别并跳过那些对最终输出贡献可忽略的注意力块。具体而言,它维护一个行方向的“运行最大值”(running maximum),如果一个新处理的块的局部最大值远小于这个运行最大值(由一个阈值$\lambda$控制),那么这个块在经过softmax归一化后的值将趋近于零。因此,可以安全地跳过对该块的三项昂贵操作:(1) softmax的指数计算,(2) 从HBM加载相应的Value块,以及(3) 后续的矩阵乘法。这个过程仅需每次块比较一次,几乎无延迟开销。

图1. BLASST概览。注意力矩阵的一行中的块被顺序处理。我们(1)像FlashAttention一样更新运行中的行最大值(m(j)),(2)为每个Sj块(QK⊤j)计算块最大值(m˜ (j)),以及(3)如果块最大值比运行最大值小超过输入阈值ln(λ),则跳过后续工作。完整细节见算法1。
  1. 为Prefill和Decode阶段优化的CUDA内核: 针对prefill(计算密集型)和decode(内存密集型)的不同特性,开发了专门的CUDA内核。

    • Prefill内核: 减少CUDA核心和张量核心的使用。
    • Decode内核: 减少内存带宽消耗。
      这些内核在现代GPU上(H200, B200)实现了显著的加速效果,prefill阶段在74.7%的稀疏度下加速1.62倍,decode阶段在73.2%的稀疏度下加速1.48倍。
  2. 自动化部署与性能增强技术:

    • 自动校准程序: 提出了一种自动校准程序,该程序揭示了最优阈值与上下文长度之间存在简单的反比关系($\lambda = a/L$),从而能够在不同场景下稳健部署,无需手动调优。
    • 稀疏感知训练: 作为一种自然扩展,本文探索了稀疏感知训练,证明模型可以通过训练来适应稀疏注意力模式,从而在保持高精度的同时达到更高的稀疏度。

本文贡献总结:

  1. 提出了一种无需预计算开销和代理分数的插入式方法,实现了最小的准确率损失。
  2. 为稳健、灵活和可扩展的部署提供了自动化的超参数选择和稀疏感知训练。
  3. 为prefill和decode阶段开发了基于FlashAttention的优化CUDA内核,性能卓越。

A3 相关工作

有效利用稀疏性的挑战。有效利用注意力稀疏属性需要在不引入昂贵的选择开销或重新训练的情况下,减少对不重要交互的计算或减少内存占用(例如KV缓存)。与以下相关工作相比,BLASST以一种无需训练的方式同时解决了计算和内存两个方面的问题。

2.1 计算优化的稀疏性

现有方法通过选择重要交互来减少注意力计算
* 静态模式方法:如Sparse Transformer 【【5】Child, R., Gray, S., Radford, A., and Sutskever, I. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509, 2019.】、LongFormer 【【4】Beltagy, I., Peters, M. E., and Cohan, A. Longformer: The long-document transformer. arXiv preprint arXiv:2004.05150, 2020.】 和 BigBird 【【36】Zaheer, M., Guruganesh, G., Dubey, K. A., Ainslie, J., Alberti, C., Ontanon, S., Pham, P., Ravula, A., Wang, Q., Yang, L., et al. Big bird: Transformers for longer sequences. Advances in neural information processing systems, 33:17283–17297, 2020.】通过局部或块状注意力降低了复杂度。
* 基于检索头的方法:如【【26】Wu, W., Wang, Y., Xiao, G., Peng, H., and Fu, Y. Retrieval head mechanistically explains long-context factuality. arXiv preprint arXiv:2404.15574, 2024.】和【【29】Xiao, G., Tang, J., Zuo, J., Guo, J., Yang, S., Tang, H., Fu, Y., and Han, S. Duoattention: Efficient long-context llm inference with retrieval and streaming heads. arXiv preprint arXiv:2410.10819, 2024b.】通过将计算集中在关键的检索头上,加速模型解码。
* 动态稀疏方法:MInference 【【14】Jiang, H., Li, Y., Zhang, C., Wu, Q., Luo, X., Ahn, S., Han, Z., Abdi, A. H., Li, D., Lin, C.-Y., et al. Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention. Advances in Neural Information Processing Systems, 37:52481–52515, 2024.】使用预计算的重要性分数,XAttention 【【30】Xu, R., Xiao, G., Huang, H., Guo, J., and Han, S. Xattention: Block sparse attention with antidiagonal scoring. arXiv preprint arXiv:2503.16428, 2025.】对反斜对角线上的块进行排序,FlexPrefill 【【15】Lai, X., Lu, J., Luo, Y., Ma, Y., and Zhou, X. Flexprefill: A context-aware sparse attention mechanism for efficient long-sequence inference. arXiv preprint arXiv:2502.20766, 2025.】提供编译器支持的灵活块模式。这些方法对prefill阶段有效,但其预计算和调度开销可能会限制实际的加速效果。
* 训练辅助的稀疏性:SeerAttention 【【10】Gao, Y., Guo, S., Cao, S., Xia, Y., Cheng, Y., Wang, L., Ma, L., Sun, Y., Ye, T., Dong, L., et al. Seerattention-r: Sparse attention adaptation for long reasoning. arXiv preprint arXiv:2506.08889, 2025.】通过(预)训练门控机制来诱导高稀疏性,提高了效率但增加了训练成本,并且在下游任务上的表现不一。

与SpargeAttention的比较。SpargeAttention 【【38】Zhang, J., Xiang, C., Huang, H., Wei, J., Xi, H., Zhu, J., and Chen, J. Spargeattn: Accurate sparse attention accelerating any model inference. arXiv preprint arXiv:2502.18137, 2025.】的设计与BLASST最为接近。但我们有三个关键区别:(1) BLASST通过专门的内核优化了prefill和decode两个阶段,而SpargeAttention仅针对prefill;(2) 我们直接使用已计算的统计数据进行零开销的跳过决策,而SpargeAttention使用一个独立的预测步骤;(3) 我们的decode内核跳过了从HBM加载Value的操作,在节省计算的同时解决了内存密集型瓶颈。此外,我们还提供了自动校准和稀疏感知训练。

2.2 内存优化的稀疏性

Token/KV稀疏性专注于减少内存占用和解码时成本。H2O 【【39】Zhang, Z., Sheng, Y., Zhou, T., Chen, T., Zheng, L., Cai, R., Song, Z., Tian, Y., Re, C., Barrett, C., et al. H2o: ´ Heavy-hitter oracle for efficient generative inference of large language models. Advances in Neural Information Processing Systems, 36:34661–34710, 2023.】、TOVA 【【19】Oren, M., Hassid, M., Yarden, N., Adi, Y., and Schwartz, R. Transformers are multi-state rnns. arXiv preprint arXiv:2401.06104, 2024.】 和 InfLLM 【【27】Xiao, C., Zhang, P., Han, X., Xiao, G., Lin, Y., Zhang, Z., Liu, Z., and Sun, M. Infllm: Training-free long-context extrapolation for llms with an efficient context memory. Advances in Neural Information Processing Systems, 37: 119638–119661, 2024a.】根据查询模式丢弃token。StreamingLLM 【【28】Xiao, G., Tian, Y., Chen, B., Han, S., and Lewis, M. Efficient streaming language models with attention sinks. arXiv preprint arXiv:2309.17453, 2023.】保留初始和最近的token以保持一致的延迟和内存使用。Quest 【【25】Tang, J., Zhao, Y., Zhu, K., Xiao, G., Kasikci, B., and Han, S. Quest: Query-aware sparsity for efficient long-context llm inference. arXiv preprint arXiv:2406.10774, 2024.】根据当前查询条件剪枝token,Rectified Sparse Attention 【【24】Sun, Y., Ye, T., Dong, L., Xia, Y., Chen, J., Gao, Y., Cao, S., Wang, J., and Wei, F. Rectified sparse attention. arXiv preprint arXiv:2506.04108, 2025.】自适应地选择token以在高稀疏度下保持准确性,RocketKV 【【3】Behnam, P., Fu, Y., Zhao, R., Tsai, P.-A., Yu, Z., and Tumanov, A. Rocketkv: Accelerating long-context llm inference via two-stage kv cache compression. arXiv preprint arXiv:2502.14051, 2025.】通过选择性驱逐来压缩KV缓存,最近的KV压缩技术 【【16】Łancucki, A., Staniszewski, K., Nawrot, P., and Ponti, E. M. ´ Inference-time hyper-scaling with kv cache compression. arXiv preprint arXiv:2506.05345, 2025.】进一步扩展了有效上下文;TidalDecode 【【32】Yang, L., Zhang, Z., Chen, Z., Li, Z., and Jia, Z. Tidaldecode: Fast and accurate llm decoding with position persistent sparse attention. arXiv preprint arXiv:2410.05076, 2024.】通过位置持久的模式稳定了解码效率。这些方法主要通过在解码阶段进行KV剪枝/压缩来优化内存,而BLASST则直接在prefill和decode阶段减少计算量,并且无需训练。

2.3 新的注意力变体

除了上述方法,还有其他注意力机制。包括滑动窗口注意力(Sliding Window Attention) 【【4】Beltagy, I., Peters, M. E., and Cohan, A. Longformer: The long-document transformer. arXiv preprint arXiv:2004.05150, 2020.】、线性或门控注意力(Linear or Gated Attention) 【【20】Qiu, Z., Wang, Z., Zheng, B., Huang, Z., Wen, K., Yang, S., Men, R., Yu, L., Huang, F., Huang, S., et al. Gated attention for large language models: Non-linearity, sparsity, and attention-sink-free. arXiv preprint arXiv:2505.06708, 2025.】以及状态空间模型(SSM) 【【11】Gu, A. and Dao, T. Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752, 2023.】。原生稀疏注意力(Native Sparse Attention, NSA) 【【34】Yuan, J., Gao, H., Dai, D., Luo, J., Zhao, L., Zhang, Z., Xie, Z., Wei, Y., Wang, L., Xiao, Z., et al. Native sparse attention: Hardware-aligned and natively trainable sparse attention. arXiv preprint arXiv:2502.11089, 2025.】和DeepSeek稀疏注意力(DSA) 【【8】DeepSeek-AI. Deepseek-v3.2-exp: Boosting long-context efficiency with deepseek sparse attention, 2025.】虽然在某些情况下有效,但通常需要架构更改或重新训练。相比之下,BLASST是一种后训练方法,它无需代理分数或复杂的预计算即可加速prefill和decode,并能与FlashAttention实现无缝集成。

A2 方法细节

3.1 通过运行最大值进行注意力剪枝

BLASST的核心洞察。在于观察到在FlashAttention计算注意力分数的过程中,许多块经过softmax归一化后对最终输出的贡献微乎其微。我们的方法在正向传播过程中动态地识别并跳过这些块,无需预计算或代理分数。

3.1.1 关键洞察

标准注意力机制的Softmax计算。在标准的注意力机制中,softmax操作的计算公式如下:

$$\text{Attention}(Q, K, V) = \text{softmax} \left( \frac{QK^\top}{\sqrt{d_k}} \right) V$$

剪枝条件的提出。在FlashAttention的逐块计算过程中,我们维护一个跨块的运行最大值$m_i^{(j)}$。如果某个块的局部最大值$\tilde{m}_i^{(j)}$显著小于当前的运行最大值,即对于某个阈值$\lambda$,满足$\tilde{m}_i^{(j)} - m_i^{(j)} < \ln(\lambda)$,那么在指数化之后:

$$\exp(\tilde{m}_{i}^{(j)}-m_{i}^{(j)}) < \lambda \approx 0$$

由于最大值被$\lambda$所限制,该块对最终注意力输出的贡献将微不足道,因此我们可以完全跳过其计算。

剪枝条件的直观解释。这个标准遵循一个三步近似法。首先,每个分数$S_{ij}$的理想重要性是其相对于(未知的)全局最大值的值。其次,动态计算真实的最大值成本太高,所以我们使用运行最大值作为一个易于处理的代理,并将$S_{ij}$与之比较。第三,为了在内核内实现高效的块级决策,我们将token级的$S_{ij}$替换为块的局部最大值,从而得到廉价的条件(块最大值 - 运行最大值)< $\ln(\lambda)$。

3.1.2 算法设计

算法1展示了我们修改后的FlashAttention前向传播过程。关键的修改是引入了一个动态剪枝条件,从而节省了计算和内存带宽。当$\tilde{m}_i^{(j)} - m_i^{(j)} < \ln(\lambda)$(第7行)时,我们跳过以下操作,从而实现开销节省:
1. 计算节省(CUDA核心):计算$\tilde{P}_{ij}$所需的昂贵的exp(·)操作,每个元素需要多个指令:MUFU.EX2(指数)、FMUL(乘法)和FADD(加法)。我们还跳过了用于归一化注意力权重的行求和规约操作(FADD指令)。对于一个典型的块,这可以节省数千条CUDA核心指令。
2. 计算节省(张量核心):矩阵乘法$\tilde{P}_{ij}V_j$。在prefill阶段,内核是计算密集型的,避免这些MMA操作可以提供显著的加速。
3. 内存带宽节省:从HBM向SRAM加载Value块$V_j$。这在decode阶段尤其关键,因为此时注意力是内存密集型的。

算法 1 FlashAttention with BLASST
输入: 查询块 {Qi} i=1..Tr, 键块 {Kj} j=1..Tc, 值块 {Vj} j=1..Tc, 阈值 λ
输出: 输出块 {Oi} i=1..Tr
1: for i = 1 to Tr do
2:    初始化 mi(0) = -∞, O(0)i = 0, l(0)i = 0
3:    for j = 1 to Tc do
4:       计算 Sij = Qi * K⊤j   ▷ 注意力分数
5:       m˜i(j) = rowmax(Sij)  ▷ 局部最大值
6:       mi(j) = max(m(j-1)i, m˜(j)i) ▷ 运行最大值
7:       if m˜i(j) - mi(j) < ln(λ) then
8:          continue           ▷ 跳过此块
9:       end if
10:      P˜ij = exp(Sij - mi(j)) ▷ 计算注意力权重
11:      l(j)i = e^(m(j-1)i - m(j)i) * l(j-1)i + rowsum(P˜ij)
12:      O(j)i = e^(m(j-1)i - m(j)i) * O(j-1)i + P˜ij * Vj
13:   end for
14:   Oi = O(Tc)i / l(Tc)i       ▷ 最终归一化
15: end for
16: return {Oi} i=1..Tr

方法的直接效果。我们的方法通过在正向传播过程中动态识别并跳过可忽略的注意力块,直接减少了总计算量。这种简单而有效的修改对现有的FlashAttention实现只需做最小的改动,同时提供了显著的计算节省。

3.2 最佳稀疏度的校准

选择合适的阈值$\lambda$以平衡稀疏度与准确性是一个关键挑战。为了理解这种关系,我们在Llama-3.1-8B模型上,使用RULER基准测试中具有挑战性的子集(NIAH MULTI, VT, FWE),对8K到64K token的上下文长度进行了实验。

稀疏度决定准确性。图2(左)显示了相对准确率随稀疏率变化的下降情况。为了公平比较,我们将每条曲线都用全注意力结果进行了归一化。值得注意的是,所有曲线都表现出一致的退化模式:在稀疏度达到60-70%之前,性能保持稳定,超过此范围后准确性急剧下降。这种在不同任务和序列长度上的一致性表明,准确性的下降主要由稀疏率本身决定,而不是数据集类型或序列长度。

阈值校准至关重要。为了获得一致的性能,我们应该保持固定的稀疏率而不是固定的阈值。然而,图2(右)显示,要达到75%的稀疏度,8K上下文需要$\lambda \approx 1e-4$,而64K上下文则仅需$1e-5$。这使得自适应校准成为必要。重要的是,通过校准来设定固定的稀疏度,用户可以控制和预测计算加速效果,因为性能增益与达到的稀疏度水平成可预测的比例关系。

图2. (左) 不同数据集和上下文长度下的相对准确率下降情况显示出一致的退化模式。所有曲线都已对其初始准确率进行归一化。(右) 不同序列长度下,阈值与达到的稀疏度水平之间的关系,表明需要进行阈值校准以在变化的上下文中保持固定的稀疏度。
图2. (左) 不同数据集和上下文长度下的相对准确率下降情况显示出一致的退化模式。所有曲线都已对其初始准确率进行归一化。(右) 不同序列长度下,阈值与达到的稀疏度水平之间的关系,表明需要进行阈值校准以在变化的上下文中保持固定的稀疏度。

阈值与上下文长度的逆比例关系。通过实证分析,我们发现最佳阈值与上下文长度$L$遵循反比关系:

图片描述
图片描述

其中$a$是一个模型特定的常数。这种反比关系有其理论基础:由于注意力分数在行上归一化后总和为1,较长的序列每个token的平均分数较低,因此需要相应较小的阈值。若不进行校准,固定的阈值会在不同序列长度上导致差异巨大的稀疏度。

自动校准流程。为了找到给定目标稀疏度$S$的最佳$a$值,我们提出了算法2中详述的校准程序。该过程包括为几个上下文长度$\{L_k\}$经验性地找到能达到目标稀疏度$S$(在容差$\delta$范围内)的最佳拟合阈值$\lambda_{best}$。然后,我们对转换后的数据点$(1/L_k, \lambda_{best})$进行线性回归,以找到斜率$a$,从而定义我们的校准函数$\lambda(L) = a/L$。

算法 2 BLASST 校准
输入: 目标稀疏度 S, 校准数据集 D, 上下文长度 {Lk} k=1..K, lambda集合 Λ, 容差 δ
输出: 校准参数 a
1: 初始化数据点 P = ∅
2: for each 上下文长度 Lk do
3:    从 D 中采样长度为 Lk 的序列
4:    初始化 λ_best = None, min_gap = ∞
5:    for each λ ∈ Λ do
6:       s = MeasureSparsity(λ, Lk)
7:       gap = |s - S|
8:       if gap < min_gap then
9:          λ_best = λ
10:         min_gap = gap
11:      end if
12:   end for
13:   if min_gap < δ then  ▷ 仅当稀疏度足够接近时保留
14:      将 (1/Lk, λ_best) 添加到 P
15:   end if
16: end for
17: 使用 P 拟合线性回归: λ = a · (1/L)
18: return 回归系数 a

可预测的计算加速。更重要的是,通过设定固定的稀疏度水平,我们的校准确保了在不同上下文长度下可预测的计算加速。这对于要求性能一致的生产部署来说是一个至关重要的特性。

3.3 稀疏感知训练

作为后训练推理优化的扩展。虽然BLASST主要设计为一种后训练的推理优化方法,我们探索了稀疏感知训练作为一个简单的扩展,以进一步改善准确性与稀疏度的权衡。其动机很直接:如果模型在训练期间学会将重要信息集中在高分数的注意力块中,那么在推理时剪枝这些块时,它们应该能保持更高的准确性。

实现方法简单。我们的方法很简单:在微调期间,我们在前向传播中应用BLASST,根据阈值标准跳过可忽略的注意力块。在反向传播中,被跳过的块自然不会接收到梯度,因为它们在前向传播中没有被计算。这鼓励模型调整其注意力模式,使其更兼容稀疏性,将重要信息集中在能通过阈值测试的块中。这种方法不需要更改架构或引入辅助损失——它只是在训练时使用与推理时相同的稀疏注意力。

4 内核设计

BLASST内核的设计有两个主要目标:(1)对现有的FlashAttention内核接口和实现结构做最小的改动,以及(2)为块跳过决策逻辑引入最小的开销。我们的关键洞察是重用标准FlashAttention算法中已经计算的统计数据——具体来说,是在线softmax期间每个线程中都存在的局部最大值和运行最大值。

跳过决策的实现。决策过程(算法1中的第7行)每个块只需要几个额外的指令:(1)根据阈值比较为每个线程设置一个谓词(predicate),(2)发出一个VOTE指令来确定一个warp内的所有线程是否都同意跳过,以及(3)由每个warp中的一个线程向共享内存发出一个单一的ATOMIC指令,以协调softmax warpgroup内的块级决策。我们精心设计了内核,使得决策指令能够被现有操作隐藏,从而增加了可忽略的延迟开销。

针对prefill和decode的专门优化。由于prefill和decode阶段具有根本不同的性能特征,我们为每个阶段实现了专门的优化。

(a) 正常的FlashAttention prefill流水线调度。
图3 (a). 正常的FlashAttention prefill流水线调度。

(b) BLASST prefill流水线调度,T0和T1都跳过了循环1和3。
图3 (b). BLASST prefill流水线调度,其中T0和T1都跳过了循环1和3。

图3. FlashAttention和BLASST在4个循环迭代(L0-L3)中以50%稀疏度的prefill流水线调度。行根据warp/warpgroup的专门化进行分隔。较深和较浅的色调对应于不同瓦片行(T0/T1)的操作。MMA warp的BMM1和BMM2操作用B1和B2表示。softmax warpgroups主要受限于指数运算(EX2),但它们也执行跳过检查、行求和和softmax缩放(未显示)。主循环迭代由实线框起。

4.1 Prefill内核:计算密集型优化

Prefill内核通常是计算密集型的。其瓶颈在于CUDA核心(用于softmax)和张量核心(用于矩阵乘法)的吞吐量,而不是内存带宽。因此,我们的prefill内核旨在为被剪枝的块跳过softmax计算和MMA操作(注意力-值乘法)。

流水线调度的变化。图3展示了我们对BLASST prefill内核流水线调度的改动,该内核通过重叠不同的计算任务来为计算密集型场景进行优化。流水线调度了张量核心(数学warp/矩阵乘法)和CUDA核心(softmax和校正逻辑)上的操作。图3b显示,即使所有$QK^T$(BMM1)操作都被计算,内核也会动态跳过被识别为可忽略的块的计算密集型softmax和注意力-值乘法(BMM2)(例如图3b中的循环1和循环3)。通过跳过这些计算操作,内核释放了执行单元,使得后续操作可以更早地被调度。这压缩了整个调度,将总运行时间从图3a中的18个时间单位减少到图3b中的14个单位。

Value块在prefill内核中仍然从HBM加载。原因有三:(1) 内存带宽不是瓶颈;(2) 预取流水线得益于可预测的内存访问模式;以及(3) 条件性Value加载的延迟将超过节省的时间。通过专注于消除计算操作,我们在计算密集型场景中实现了几乎与稀疏度成线性比例的加速。我们当前的设计优先考虑了现代GPU上prefill是计算密集型的常见情况;然而,如果未来的工作负载或硬件架构转向内存带宽密集型,也可以在prefill中跳过Value加载。

(a) 正常的FlashAttention decode流水线调度。
图4 (a). 正常的FlashAttention decode流水线调度。

(b) BLASST decode流水线调度,跳过了循环1、2和4。
图4 (b). BLASST decode流水线调度,当跳过循环1、2和4时的情况。

图4. FlashAttention和BLASST在跳过循环1、2和4时的decode流水线调度。未显示序言部分,我们关注前6个循环迭代(L0-L5)的稳态。我们分开了TMA warp的流水线阶段,以显示如何一次性发出多个TMA加载。图4b中的加载完成得更快,因为同时进行的加载较少。箭头表示BMM1之后跳过检查的记分板依赖关系。请注意,MMA warp的BMM1和BMM2操作用B1和B2表示。

4.2 Decode内核:内存密集型优化

Decode内核通常是内存密集型的。其瓶颈在于获取KV缓存所需的HBM带宽,而不是计算,因为注意力只涉及单个Query与所有Key的交互。因此,我们的内核专注于为被剪枝的块跳过内存密集型的Value矩阵$V_j$加载,直接解决这个HBM瓶颈。此优化根据稀疏度水平按比例减少内存流量,同时将阈值和Key操作与剩余的Value加载重叠,以实现显著的加速,这反映了decode与prefill的不同性能特征。

流水线调度的变化。图4展示了我们对BLASST decode内核流水线调度的改动。加载K和V瓦片的长时间显示了内核在这种内存密集型场景下的行为。通过为循环1、2和4跳过V瓦片加载和BMM2,GPU可以更快地完成来自其他TMA流水线阶段的未完成加载。结果,图4a需要30个时间单位来完成所有V加载,而图4b则需要23个单位。

针对计算密集型解码的额外优化。对于像多头潜在注意力(Multi-head Latent Attention, MLA)【【17】Liu, A., Feng, B., Wang, B., Wang, B., Liu, B., Zhao, C., Dengr, C., Ruan, C., Dai, D., Guo, D., et al. Deepseek-v2: A strong, economical, and efficient mixture-of-experts language model. arXiv preprint arXiv:2405.04434, 2024a.】这样即使在解码阶段也更偏向计算密集型的注意力机制,我们还为被剪枝的块跳过softmax操作,从而在内存节省之外提供进一步的加速。

实验环境

5.1 实验设置

实验结果

5.2 主要结果

总体性能。表1展示了BLASST在约50%和约75%稀疏度下,在Llama-3.1-8B和Qwen3-8B模型及多种基准测试上的性能。BLASST不仅以极小的退化保持了准确性,甚至在某些情况下超越了密集基线。例如,在Qwen3-8B上,50%稀疏度时MATH500(96.23 vs 95.87)和AIME 2024(76.50 vs 75.00)的性能有所提升。这可能是因为在信息本身稀疏的长上下文任务中,剪枝起到了隐式去噪的作用;而在长文本生成推理任务中,跳过不重要的块有助于模型专注于核心推理链。

表1. BLASST在不同稀疏度下,跨所有模型和基准的性能。我们在Llama-3.1-8B和Qwen3-8B上评估了三种部署场景:仅prefill优化(长上下文任务:RULER, LongBench)、仅decode优化(推理任务:MATH500, AIME 2024, GPQA)和prefill+decode联合优化。结果显示,即使在约75%的稀疏度下,准确率下降也极小,偶尔还优于密集基线。
表1. BLASST在不同稀疏度下,跨所有模型和基准的性能。我们在Llama-3.1-8B和Qwen3-8B上评估了三种部署场景:仅prefill优化(长上下文任务:RULER, LongBench)、仅decode优化(推理任务:MATH500, AIME 2024, GPQA)和prefill+decode联合优化。结果显示,即使在约75%的稀疏度下,准确率下降也极小,偶尔还优于密集基线。

Prefill阶段比较。表2将BLASST与最先进的prefill优化稀疏注意力方法在Llama-3.1-8B上进行了比较。在RULER(4K-64K上下文)和LongBench上,BLASST在所有稀疏方法中取得了最佳的综合性能(RULER平均92.87,LongBench 31.8),与密集注意力(93.21,31.4)非常接近,且无需预计算。BLASST显著优于MInference(RULER 84.15)和FlexPrefill(RULER 87.72)。

表2. Llama-3.1-8B-Instruct在RULER和LongBench上的Prefill阶段比较。BLASST在所有稀疏注意力方法中表现最佳,与密集注意力非常接近,且无需预计算或代理分数。
表2. Llama-3.1-8B-Instruct在RULER和LongBench上的Prefill阶段比较。BLASST在所有稀疏注意力方法中表现最佳,与密集注意力非常接近,且无需预计算或代理分数。

Decode阶段比较。表3在Qwen3-8B上评估了BLASST在推理密集型任务上的表现。在约50%的稀疏度下,BLASST在所有基准测试上的性能均持平或超过了密集基线,同时保持了长上下文能力。与RocketKV等专注于KV缓存压缩的方法相比,BLASST在保持高准确性的前提下实现了计算优化。

表3. Qwen3-8B在多种推理和生成任务上的Decode阶段比较。BLASST在所有基准上,包括数学推理(MATH500, AIME 2024)、研究生水平科学(GPQA)和代码生成(LiveCodeBench),均持平或超过了密集基线,同时保持了长上下文性能(RULER, LongBench)。
表3. Qwen3-8B在多种推理和生成任务上的Decode阶段比较。BLASST在所有基准上,包括数学推理(MATH500, AIME 2024)、研究生水平科学(GPQA)和代码生成(LiveCodeBench),均持平或超过了密集基线,同时保持了长上下文性能(RULER, LongBench)。

5.3 GPU内核性能

我们在Blackwell (B200) 和 Hopper (H200) GPU架构上实现了高度优化的内核。表4和图5显示了在prefill和decode阶段,性能随稀疏度增加的扩展情况。所有加速比均与FlashAttention-3 BF16基线进行比较。

关键结果

表4. BLASST在B200 (Blackwell) GPU上,随稀疏度增加的prefill和decode加速情况。我们通过改变阈值(λ)来展示在不同稀疏度下的性能。Prefill配置:批大小148,1个Q头,1个KV头,32K序列长度,128头维度。Decode配置:批大小148,32个Q头,4个KV头,32K序列长度,128头维度。
表4. BLASST在B200 (Blackwell) GPU上,随稀疏度增加的prefill和decode加速情况。我们通过改变阈值(λ)来展示在不同稀疏度下的性能。Prefill配置:批大小148,1个Q头,1个KV头,32K序列长度,128头维度。Decode配置:批大小148,32个Q头,4个KV头,32K序列长度,128头维度。
BLASST Prefill在H200 (Hopper)上的加速情况
BLASST Prefill在H200 (Hopper)上的加速情况
图5. BLASST prefill在Hopper GPU (H200)上的加速情况
图5. BLASST prefill在Hopper GPU (H200)上的加速情况

5.4 校准结果

校准方法的有效性。表5证明了我们校准方法的有效性。对于50%的目标稀疏度,固定阈值方法产生的稀疏度从4K上下文的23%到64K上下文的75%不等,波动极大。相比之下,我们校准的$\lambda = a/L$方法将稀疏度维持在一个很小的范围内,与目标的平均误差仅为1.2%。这证实了我们的校准能够在不同序列长度下实现可靠、可预测的稀疏度控制。

表5. Llama-3.1-8B上,校准阈值与固定阈值在不同上下文长度下的稀疏度稳定性对比。我们的校准方法在不同上下文长度下保持了一致的稀疏度水平,而固定阈值则产生高方差。括号中的数值表示实际稀疏度与目标的偏差。
表5. Llama-3.1-8B上,校准阈值与固定阈值在不同上下文长度下的稀疏度稳定性对比。我们的校准方法在不同上下文长度下保持了一致的稀疏度水平,而固定阈值则产生高方差。括号中的数值表示实际稀疏度与目标的偏差。

5.5 稀疏感知训练结果

提升准确性-稀疏度权衡。图6表明,稀疏感知训练改善了RULER基准测试上的准确性-稀疏度权衡。在50%-75%的目标稀疏度范围内,与在后训练阶段应用稀疏性相比,经过稀疏训练的模型实现了显著更高的准确性,准确性下降幅度减少了高达1.7倍。这证实了模型可以通过训练来适应稀疏注意力模式。

图6. 稀疏感知训练推动了准确性-稀疏度的前沿。在训练期间激活BLASST进行微调的模型,在激进的稀疏度水平下比后训练应用稀疏性的模型保持更高的准确性。通过使用稀疏注意力进行训练,模型学会了将信息集中在高分值的块中,使其对剪枝更具鲁棒性。
图6. 稀疏感知训练推动了准确性-稀疏度的前沿。在训练期间激活BLASST进行微调的模型,在激进的稀疏度水平下比后训练应用稀疏性的模型保持更高的准确性。通过使用稀疏注意力进行训练,模型学会了将信息集中在高分值的块中,使其对剪枝更具鲁棒性。

5.6 消融研究

稀疏度分布分析。图7显示了Llama-8B在8K上下文上的稀疏度在不同层和注意力头之间的分布情况。我们观察到显著的异质性:不同层和头表现出不同的稀疏度水平。BLASST通过在所有层和头上应用相同的阈值,自然地适应了这种异质性,自动地在注意力更集中的地方进行更激进的剪枝,而在注意力更分散的地方保留更多的块。

图7. Llama-8B在8K上下文上的层和头之间的稀疏度分布。数据来自NIAH基准样本,阈值λ = 0.03。头级别和层级别的巨大差异为自适应阈值策略提供了动力。
图7. Llama-8B在8K上下文上的层和头之间的稀疏度分布。数据来自NIAH基准样本,阈值λ = 0.03。头级别和层级别的巨大差异为自适应阈值策略提供了动力。

与其他稀疏方法的组合。表6探讨了将BLASST与其他注意力稀疏技术相结合的效果。我们发现BLASST可以与prefill优化方法(XAttention)和KV缓存压缩方法(RocketKV)有效组合,且准确性下降很小,证明了BLASST作为端到端优化流程中灵活构建块的潜力。

表6. Qwen 8b上,BLASST与其他稀疏方法组合的性能。BLASST可以与prefill优化方法(XAttention)和KV缓存压缩方法(RocketKV)有效组合,提供了灵活的部署选项。括号中的数字表示与密集基线相比的变化。
表6. Qwen 8b上,BLASST与其他稀疏方法组合的性能。BLASST可以与prefill优化方法(XAttention)和KV缓存压缩方法(RocketKV)有效组合,提供了灵活的部署选项。括号中的数字表示与密集基线相比的变化。

极长序列长度。表7展示了BLASST在RepoQA基准上处理极长序列(16K和200K)的性能。在200K token时,BLASST以极小的准确率下降实现了很高的prefill稀疏度(约58%)。更长的上下文表现出更高的自然稀疏性,使得我们的方法在密集注意力变得不切实际的极端长度场景中越来越有效。

表7. RepoQA基准上极长序列的性能。我们评估了BLASST在16K和200K上下文长度的代码库理解任务上的表现,显示了prefill(P)和decode(D)阶段的稀疏度。
表7. RepoQA基准上极长序列的性能。我们评估了BLASST在16K和200K上下文长度的代码库理解任务上的表现,显示了prefill(P)和decode(D)阶段的稀疏度。

分块行重排序。我们研究了改变分块行处理顺序是否能提高剪枝准确性。通过优先处理包含最近token(局部窗口)的分块,可以更快地建立一个更接近全局最大值的运行最大值,从而做出更准确的跳过决策。图8显示,重排序的效果依赖于数据集(对FWE有提升,对VT影响不大),这证明了BLASST算法对不同处理顺序的鲁棒性,并展示了进行特定数据集优化的潜力。

图8. Llama 3.1 8B (ctx=8192)上,分块行重排序对准确性-稀疏度权衡的影响。我们比较了标准累积最大值(顺序处理分块)和重排序累积最大值(逆序处理分块)。VT和FWE基准的图表显示,在给定的稀疏度水平下,重排序对模型准确性的影响可以忽略不计。
图8. Llama 3.1 8B (ctx=8192)上,分块行重排序对准确性-稀疏度权衡的影响。我们比较了标准累积最大值(顺序处理分块)和重排序累积最大值(逆序处理分块)。VT和FWE基准的图表显示,在给定的稀疏度水平下,重排序对模型准确性的影响可以忽略不计。

极端稀疏度分析。图9显示了BLASST在更高稀疏度(70-90%)下的行为。与使用代理分数的XAttention相比,BLASST使用真实的softmax统计数据进行剪枝,表现出更平稳的准确性退化,使其更适合于对计算效率要求极高的激进稀疏度设置。

图9. RULER-16K上高稀疏度水平下的准确性-稀疏度权衡(Qwen3-8B)。与XAttention相比,BLASST显示出更稳定的退化,在激进的稀疏度设置下保持了更好的准确性。这显示了使用实际softmax统计数据与基于代理的重要性分数的有效性。
图9. RULER-16K上高稀疏度水平下的准确性-稀疏度权衡(Qwen3-8B)。与XAttention相比,BLASST显示出更稳定的退化,在激进的稀疏度设置下保持了更好的准确性。这显示了使用实际softmax统计数据与基于代理的重要性分数的有效性。

A5 结论

我们提出了BLASST,一种简单而有效的稀疏注意力方法,它通过重用在线softmax的统计数据来动态地剪枝注意力计算,无需预计算或代理分数。BLASST在几乎不损失准确性的情况下实现了超过50%的稀疏度,并在现代GPU上取得了高达1.6倍的加速,显著提高了长上下文推理的实用性。我们的自动校准和稀疏感知训练进一步增强了其鲁棒性和灵活性,为高效的长上下文Transformer提供了一个实用的基础。

展望未来,我们相信硬件感知的稀疏模式、通过训练学习稀疏性以及自适应混合方法的结合,将是释放未来智能体AI系统全部潜力的关键。