MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention
MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention
作者/机构: Huiqiang Jiang†, Yucheng Li♢†, Chengruidong Zhang†, Qianhui Wu, Xufang Luo, Surin Ahn, Zhenhua Han, Amir H. Abdi, Dongsheng Li, Chin-Yew Lin, Yuqing Yang, Lili Qiu. Microsoft Corporation, ♢University of Surrey
A1 主要贡献
核心问题: 随着大型语言模型(LLM)的提示长度不断增加,其推理的计算挑战成为广泛部署的重大障碍。由于注意力计算的二次方复杂度,一个8B参数的LLM在单张A100 GPU上处理1M token的提示(即预填充阶段)需要30分钟,其中自注意力计算的开销超过总预填充延迟的90%。现有的预填充加速方法在应用于长上下文LLM时,往往无法保持可接受的准确性或效率。
研究目标: 本文旨在通过一种动态稀疏注意力计算方法,显著加速长上下文LLM推理的预填充阶段,同时保持模型的高准确率,且无需重新训练或微调。
创新点:
本文提出了MInference(Million-tokens Inference),一种专为加速长序列处理预填充阶段设计的稀疏计算方法。其核心创新点如下:
1. 识别通用稀疏模式:通过广泛分析,在长上下文LLM的注意力矩阵中识别出三种独特的稀疏模式:A-shape(A形)、Vertical-Slash(垂直-斜线)和Block-Sparse(块稀疏)。这些模式可以被利用来进行高效的GPU稀疏计算。
2. 核函数感知的离线模式分配:提出一种核函数感知的搜索方法,离线为每个注意力头分配最优的稀疏模式及其参数,确保在给定的计算成本(FLOPs)下实现最高的注意力分数召回率。
3. 高效的在线动态稀疏索引构建:与使用固定掩码的先前研究不同,MInference在推理过程中根据分配的模式和具体输入,动态地构建稀疏索引。例如,对于Vertical-Slash头,使用部分查询和键向量来估计全局垂直线和斜线的重要位置;对于Block-Sparse头,则通过对查询和键向量进行均值池化来确定最重要的块。
4. 优化的GPU核函数:为上述三种稀疏模式开发了三个优化的GPU核函数。这些核函数基于动态稀疏编译器PIT、Triton和FlashAttention,实现了极其高效的动态稀疏注意力计算。
主要成果: MInference在处理1M token上下文时,可在单张A100上将预填充阶段的延迟最多降低10倍(从30分钟缩短至3分钟),同时在包括InfiniteBench、RULER、PG-19和Needle In A Haystack在内的多个下游任务上保持甚至超过了基线模型的准确率。
图 1: (a) 在LLaMA-3-8B-1M模型上,MInference在Needle In A Haystack测试中达到或超过基线水平。(b) 注意力权重,尤其是在长上下文LLM中,在128K上下文中表现出高达96.8%的稀疏性。我们提出的MInference利用动态稀疏注意力来加速长上下文LLM推理的预填充阶段,在1M上下文的单张A100上实现了高达10倍的加速。
A3 背景知识/关键Observation/设计原则
2.1 注意力是动态稀疏的
注意力权重的稀疏性。在预训练的LLM中,尤其是在长上下文场景下,注意力权重的稀疏性已被广泛记录【【索引28,Dynamic sparse attention for scalable transformer acceleration+2022+IEEE Transactions on Computers】,【索引46,Sparq attention: Bandwidth-efficient LLM inference+2024+Forty-first International Conference on Machine Learning】,【索引30,Deja vu: Contextual sparsity for efficient LLMs at inference time+2023+Proceedings of the 40th ICML】,【索引65,Efficient streaming language models with attention sinks+2024+The Twelfth International Conference on Learning Representations】】。如图2b所示,对于一个大小为128k × 128k的注意力矩阵,仅保留前4k列就能召回96.8%的总注意力。换言之,尽管正在处理长序列,但每个token只关注有限数量的token。
图 2: (a) 预填充阶段的延迟分解。(b) 在128k上下文中,top-k (k=4096) 列能覆盖多少注意力分数。(c) 当重用另一个样本的top-k索引时,检索到的注意力分数减少,表明其动态性。可视化基于单张A100上的LLaMa-3-8B。
稀疏模式的动态性。另一方面,尽管注意力矩阵的稀疏性在不同输入中普遍存在,但稀疏模式的精确分布是高度动态的。也就是说,在自注意力中,给定位置的token只关注序列的一个子集,而它具体关注的token是高度依赖于上下文的,并且在不同提示之间有显著差异。这种动态性已在先前的研究中得到了数学证明【【索引26,On the expressive power of self-attention matrices+2021+ArXiv preprint】,【索引27,On the expressive flexibility of self-attention matrices+2023+Proceedings of the AAAI Conference on Artificial Intelligence】】。如图2c所示,如果我们将图2b中找到的前4k列应用到另一个128k的提示上,注意力的召回率将大幅下降至83.7%。
2.2 注意力稀疏性呈现模式
注意力稀疏模式的分类。尽管注意力矩阵的稀疏性分布是动态的,但先前的工作【【索引65,Efficient streaming language models with attention sinks+2024+The Twelfth International Conference on Learning Representations】,【索引20,LM-infinite: Zero-shot extreme length generalization for large language models+2024+Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)】】已经表明它们在二维空间中表现出某些模式,例如空间聚类。通过我们对各种长度和任务的长上下文提示的分析,我们将这些注意力稀疏模式分为A-shape、Vertical-Slash(VS)和Block-Sparse模式,如图3a和图4所示。表1详细说明了这三种模式的特点和差异。
表 1: 不同稀疏模式的比较。
图 3: (a) 来自不同注意力头的注意力权重可视化。对于不同的提示和任务,同一头的模式相对一致,但稀疏索引是动态变化的。(b) 注意力矩阵中top-10最近的非零元素的距离。(c) 使用我们识别的模式的注意力召回率分布,其中核函数中的FLOPs指的是在GPU上进行稀疏注意力计算所需的实际FLOPs。这里,Vertical-Slash模式使用1x64的块大小,其他模式在GPU上使用64x64的块大小。所有可视化均基于LLaMA-3-8B-Instruct-262K【索引17,Llama-3 8b instruct gradient 4194k (v0.1)+2024+None】。
A-shape 模式。这类头的注意力权重集中在初始token和局部窗口上【【索引65,Efficient streaming language models with attention sinks+2024+The Twelfth International Conference on Learning Representations】,【索引20,LM-infinite: Zero-shot extreme length generalization for large language models+2024+Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)】】,表现出相对较高的稳定性。
Vertical-Slash (VS) 模式。注意力权重集中在特定的token(垂直线)【索引35,Random-access infinite context length for transformers+2023+Thirty-seventh Conference on Neural Information Processing Systems】和固定间隔的token(斜线)上。此模式中垂直线和斜线的位置会随着上下文内容动态变化,并表现出一定的稀疏性,使其难以被局部窗口和A-shape模式所包含。
Block-Sparse 模式。这种稀疏模式是动态性最强的,分布也更分散。尽管具有动态性,但注意力权重仍保持着一些空间聚类的特性,我们将其识别为块稀疏模式。我们分析了128k提示中非零注意力权重与其top-k最近非零邻居之间的距离,如图3b所示。结果表明,跨层和跨头,最近非零值之间的距离通常集中在5左右,这表明注意力权重具有很强的空间聚类性。
模式的效率。这三种模式的关键在于我们可以利用它们在长上下文LLM中对注意力矩阵进行高效的稀疏计算。在图3c中,我们测试了我们识别的模式在有限的GPU计算成本(FLOPs)下检索注意力分数的效率。首先,为每个注意力头标记一种稀疏模式(详见§3.2)。然后,我们证明了我们的模式比其他稀疏方法【【索引46,Sparq attention: Bandwidth-efficient LLM inference+2024+Forty-first International Conference on Machine Learning】,【索引65,Efficient streaming language models with attention sinks+2024+The Twelfth International Conference on Learning Representations】,【索引42,Fast attention over long sequences with dynamic sparse flash attention+2024+Advances in Neural Information Processing Systems】】明显更高效。具体来说,在相同的FLOPs下,我们的模式在注意力分数上实现了显著更高的召回率,这可能带来更好的准确性。例如,先前的Top-K方法【【索引46,Sparq attention: Bandwidth-efficient LLM inference+2024+Forty-first International Conference on Machine Learning】,【索引65,Efficient streaming language models with attention sinks+2024+The Twelfth International Conference on Learning Representations】,【索引42,Fast attention over long sequences with dynamic sparse flash attention+2024+Advances in Neural Information Processing Systems】】在处理块稀疏模式时遇到困难,因为它们全局关注特定的token,而我们的模式能更高效、更准确地检索注意力分数。我们在§3中详细说明了如何将这些模式应用于长上下文LLM,以及如何为这些模式实现优化的GPU核函数。
A2 方法细节
根据第2节的分析,我们提出了MInference来加速长上下文LLM的预填充阶段,该方法包括三个步骤:1) 离线识别每个头的注意力模式;2) 根据模式动态构建稀疏索引;3) 使用优化的GPU核函数进行稀疏注意力计算。
图 4: MInference中的三种稀疏方法。
3.1 问题形式化
稀疏注意力计算公式。当使用稀疏注意力计算加速长上下文LLM的预填充阶段时,注意力矩阵可以形式化如下:
$$A(M)=\text{Softmax}(\frac{1}{\sqrt{d}}QK^{\top}-c(1-M)),$$其中$M_{i,j} \in \{0, 1\}$表示注意力矩阵中(i, j)项的动态稀疏掩码。这里,$c$是一个大常数,例如1e5,确保对于$M_{i,j} = 0$的不重要注意力权重,在softmax之后其值趋近于零,即$A_{i,j} \approx 0$。
优化目标。动态稀疏注意力系统的目标是在尽可能多地保留注意力权重的同时,以最小的开销实现更大的加速。形式上,这可以表示为:
$$\min \quad |\boldsymbol{A}(\boldsymbol{M})-\boldsymbol{A}_{\text{dense}}|,$$其中$t_{sparse}$和$t_{overhead}$分别代表动态稀疏注意力计算和近似动态稀疏模式估计所花费的时间。
3.2 通过动态稀疏注意力加速长上下文LLM推理
核函数感知的最优稀疏模式搜索。为了在有限的FLOPs预算下达到最佳准确性,我们提出了一种离线的核函数感知的最优稀疏模式搜索方法。在这一步中,我们确定每个注意力头将使用哪种稀疏模式,以及该模式在实际计算中的最优设置(例如,VS模式中的垂直/斜线条数;或BS模式中的top-k块数)。如算法1所示,我们首先基于目标FLOPs为每种模式创建搜索空间,确保所有潜在候选者(即具有不同设置的不同模式)具有相似的计算成本。这里的“核函数感知”意味着计算成本反映了GPU核函数中的真实FLOPs,而不是概念上的估计,这对于实现最优加速至关重要。
搜索过程。接下来,我们使用一个参考样本遍历搜索空间,以确定最优的模式和设置。具体来说,我们使用注意力输出的召回率作为搜索最佳模式的目标标准。这种方法利用FlashAttention【索引11,Flashattention-2: Faster attention with better parallelism and work partitioning+2024+The Twelfth International Conference on Learning Representations】来减少GPU内存开销,并融合了V矩阵的信息,从而能够端到端地选择最佳模式,进一步提升性能。
算法 1 核函数感知的稀疏模式搜索
输入: Q, K, V ∈ R^(S×d_h), 模式 p,
搜索空间 ρ, 目标 FLOPs t,
初始化的搜索空间 σ
# 构建核函数感知的搜索空间
for i ← 1 to |σ| do
t_i ← FLOPs_in_kernel(σ_i)
while |t_i − t| > ϵ do
σ_i ← ChangeSpace(σ_i, p_i)
t_i ← FLOPs_in_kernel(σ_i)
end while
ρ ← ρ ∪ σ_i
end for
# 搜索最优的头模式
p_best ← ϕ
y ← Softmax(QK^⊤ / √d)
for i ← 1 to |ρ| do
y_i ← SparseAttention(QK^⊤ / √d, ρ_i)
p_best ← argmin(y_i − y, p_best)
end for
return p_best
稀疏索引近似与动态稀疏注意力计算。在推理阶段,我们将根据分配的模式和具体输入,对注意力矩阵进行在线估计,以动态确定我们稀疏索引的空间分布。之后,我们使用优化的GPU核函数进行稀疏注意力计算。我们的核函数实现细节可以在附录C.4中找到。需要注意的是,对于A-shape头,稀疏掩码是静态的,因此没有构建动态掩码的开销,只需要进行稀疏计算。
(i)Vertical-Slash 头。如算法2所示,由于垂直线和斜线的连续性,我们对最后的查询向量$Q[-last\_q:]$和键向量$K$进行矩阵乘法,生成估计的注意力矩阵$A_b$,该矩阵反过来用于确定垂直线$i_v$和斜线$i_s$的索引。在获得垂直线和斜线的稀疏索引后,我们将它们转换为稀疏格式$i_{vs}$。利用这些稀疏索引,我们进行注意力权重和注意力输出的块稀疏计算。
(ii)Block-Sparse 头。根据算法3,对$Q$和$K$应用均值池化以分别获得$\bar{Q}$和$\bar{K}$。将这两个矩阵相乘得到估计的块级注意力权重$\bar{A}_b$。由于均值池化和矩阵乘法运算是可交换的,因此得到的注意力权重近似等于实际注意力权重在均值池化后的结果。这使我们能够以最小的开销近似实际注意力权重的块稀疏模式。同样,我们构建一个稀疏索引$i_b$,并用它来计算稀疏的注意力权重和注意力输出。
算法 2 Vertical-Slash 头
输入: Q, K, V ∈ R^(S×d_h), k_v, k_s ∈ N
# 近似垂直和斜线模式 (last_q = 64)
A_b ← softmax( Q[-last_q:]K^⊤ / √d + m_casual )
# top k_v 个垂直线的索引,垂直方向求和
i_v ← argtopk( sum_v(A_b), k_v )
# top k_s 个斜线的索引,斜线方向求和
i_s ← argtopk( sum_s(A_b), k_s )
# 构建稀疏注意力索引
i_vs ← sparseformat(i_v, i_s)
# 最终的动态稀疏注意力分数 (仅索引块)
A ← softmax( sparse(QK^⊤, i_vs) / √d )
# 稀疏混合分数和值
y ← sparse(AV, i_vs)
return y
算法 3 Block-Sparse 头
输入: Q, K, V ∈ R^(S×d_h), k_b ∈ N
# 近似块稀疏模式 (block_size = 64)
Q̄ ← MeanPooling(Q, block_size)
K̄ ← MeanPooling(K, block_size)
Ā_b ← softmax( Q̄K̄^⊤ / √d + m_casual )
# top k_b 个块的索引
i_b ← argtopk( Ā_b, k_b )
# 构建稀疏注意力索引
i_b ← sparseformat(i_b)
# 最终的动态稀疏注意力分数 (仅索引块)
A ← softmax( sparse(QK^⊤, i_b) / √d )
# 稀疏混合分数和值
y ← sparse(AV, i_b)
return y
A4 实验环境
模型架构:
实验使用了四种先进的长上下文LLM:
* LLaMA-3-8B-Instruct-262k
* LLaMA-3-8B-Instruct-1048k
* GLM-4-9B-1M 【索引19,Chatglm: A family of large language models from glm-130b to glm-4 all tools+2024+ArXiv preprint】
* Yi-9B-200K 【索引67,Yi: Open foundation models by 01. ai+2024+ArXiv preprint】
此外,还在Phi-3-Mini-128K【索引2,Phi-3 technical report: A highly capable language model locally on your phone+2024+None】和Qwen2-7B-128K【索引5,Qwen technical report+2023+ArXiv preprint】上测试了Needle in A Haystack任务。
在所有实验中均使用贪心解码(greedy decoding)以保证结果的稳定性。
硬件配置:
* GPU: 单张NVIDIA A100 GPU。
软件配置:
* 代码实现: 使用PyTorch提供了方法的简单自定义实现。
* 依赖库: 基于FlashAttention【索引11,Flashattention-2: Faster attention with better parallelism and work partitioning+2024+The Twelfth International Conference on Learning Representations】、Triton【索引55,Triton: an intermediate language and compiler for tiled neural network computations+2019+Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages】和动态稀疏编译器PIT【索引72,Pit: Optimization of dynamic sparse deep learning models via permutation invariant transformation+2023+Proceedings of the 29th Symposium on Operating Systems Principles】。
* 数据格式: 使用bfloat16进行延迟实验。
实现细节:
* A-shape模式:目标FLOPs t设置为1k个全局token和4k个局部窗口。
* Vertical-Slash模式:last_q设置为64。
* Block-Sparse模式:block_size设置为64。
数据集与评估指标:
* InfiniteBench【索引69,∞Bench: Extending long context evaluation beyond 100K tokens+2024+Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】: 包含10个任务,如PassKey检索、问答、编码、对话和摘要等,平均上下文长度约为214K。
* RULER【索引21,Ruler: What’s the real context size of your long-context language models?+2024+ArXiv preprint】: 一个具有挑战性的长上下文基准,包含4大类13个复杂任务,如检索、多跳追踪、聚合和问答,提示长度可达128k。
* Needle In A Haystack【索引23,Needle in a haystack - pressure testing llms+2023+None】: 一个长上下文检索基准,测试LLM在高达1M token的上下文窗口中、信息位于不同位置时的性能。
* PG-19【索引48,Compressive transformers for long-range sequence modelling+2020+8th International Conference on Learning Representations, ICLR 2020】: 用于长上下文语言建模任务,提示长度可达100k。
基线方法:
实验包括了五种无需训练的稀疏注意力方法作为基线:
1. StreamingLLM【索引65,Efficient streaming language models with attention sinks+2024+The Twelfth International Conference on Learning Representations】: 对应A-shape模式,使用1k全局token和4k局部窗口。
2. StreamingLLM w/ dilated【索引6,Longformer: The long-document transformer+2020+ArXiv preprint】: 使用1k全局token和8k带间隔的扩张注意力窗口。
3. StreamingLLM w/ strided【索引8,Generating long sequences with sparse transformers+2019+ArXiv preprint】: 使用1k全局token、2k局部窗口和4k带间隔的扩张注意力。
4. InfLLM【索引66,Infllm: Unveiling the intrinsic capacity of llms for understanding extremely long sequences with training-free memory+2024+ArXiv preprint】: 使用内存单元处理流式长序列,设置128个全局token和8k局部窗口。
5. Ours w/ static: 在Vertical-Slash和Block-Sparse头中使用静态稀疏索引。
A4 实验结果
任务性能评估
- InfiniteBench:如表2所示,MInference在InfiniteBench上取得了最佳的整体性能,与基线方法相比表现更优。值得注意的是,MInference在某些任务上的性能与原始的全注意力基线相当,甚至略有超越。它不仅在摘要、问答和代码等自然语言任务上表现良好,还在检索相关任务中保持了原始模型的性能,而StreamingLLM等基线方法在这些检索任务上表现不佳。
表 2: 不同方法在不同基础模型上于InfiniteBench上的性能【索引69,∞Bench: Extending long context evaluation beyond 100K tokens+2024+Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】。
- RULER:如表3所示,MInference在RULER基准测试中,即使在复杂的多跳或聚合任务中,也能有效保持长上下文性能。在超过32K的测试长度上,它甚至优于原始的全注意力模型,在LLaMA-3-8B-262K和GLM-4-9B-1M模型上实现了32K和64K的有效上下文窗口(性能超过85%的上下文被认为是有效的)。
表 3: 不同模型和方法在RULER上的性能(%)【索引21,Ruler: What’s the real context size of your long-context language models?+2024+ArXiv preprint】,评估长度从4k到128k。
- 语言建模:在基于PG-19数据集的语言建模任务中,如图5所示,MInference的结果优于其他稀疏方法,并且与全注意力基线的差异最小。对于100K token的提示,其困惑度仅比全注意力高0.2,但分别比StreamingLLM在Yi-9B-200K和LLaMA-3-262K模型上低0.25和0.75。
图 5: 使用不同模型和方法在PG-19上的困惑度结果【索引48,Compressive transformers for long-range sequence modelling+2020+8th International Conference on Learning Representations, ICLR 2020】。
- Needle In A Haystack:如图1a和图6所示,MInference在1k到1M token的不同上下文窗口中,能有效保留处理不同位置信息的能力。相比之下,StreamingLLM和InfLLM等方法一旦关键信息超出了全局token和局部窗口的范围,性能就会急剧下降。
图 6: 在 LLaMA-3-8B-1M 上使用 StreamingLLM 【索引65,Efficient streaming language models with attention sinks+2024+The Twelfth International Conference on Learning Representations】的 Needle In A Haystack 测试结果。
消融研究
- 组件贡献:如表2、3和4所示,消融研究证明了MInference中不同组件的贡献。
- 静态 vs. 动态:使用静态索引(
Ours w/ static)会导致LLM性能显著下降,尤其是在像KV检索这样的高度动态任务中,准确率几乎降至零,凸显了动态策略的必要性。 - 模式组合:移除三种模式中的任何一种都会导致不同程度的性能下降。例如,“仅A-shape”只能捕获局部窗口内的信息;“仅block-sparse”也会导致性能大幅下降;而“仅vertical-slash”虽然保留了大部分性能,但仍落后于完整版的MInference。
- 静态 vs. 动态:使用静态索引(
表 4: 使用LLaMA-3-8B-Instruct-262K在InfiniteBench上的不同消融方法的性能【索引69,∞Bench: Extending long context evaluation beyond 100K tokens+2024+Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】。
效率与集成
- 延迟:如图1b和图10所示,在单张A100上,MInference在100K、300K、500K和1M tokens的上下文长度下,分别实现了1.8倍、4.1倍、6.8倍和10倍的加速。对于1M token的提示,它将预填充延迟从30分钟减少到3分钟。动态稀疏索引构建的开销约占5%-20%。
- 与KV缓存压缩方法集成:如表5所示,MInference与先进的KV缓存压缩方法SnapKV【索引29,Snapkv: Llm knows what you are looking for before generation+2024+ArXiv preprint】结合使用,证明了其兼容性。在大多数任务上,性能几乎没有变化,平均得分甚至略有提高,进一步展示了其作为长上下文LLM服务优化的实用价值。
表 5: 在解码阶段使用SnapKV【索引29,Snapkv: Llm knows what you are looking for before generation+2024+ArXiv preprint】时,不同方法在InfiniteBench【索引69,∞Bench: Extending long context evaluation beyond 100K tokens+2024+Proceedings of the 62nd Annual meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】上的性能。
- 在更大型LLM上的扩展:如表6所示,MInference在LLaMA-3-70B等更大型LLM上也保持了强大的性能。在KV检索等动态任务中,其性能与全注意力相当甚至略有提升,而InfLLM等基线则表现不佳。
表 6: 使用LLaMA-3-70B-Instruct-262K在InfiniteBench【索引69,∞Bench: Extending long context evaluation beyond 100K tokens+2024+Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】上不同方法的性能。
A7 补充细节:相关工作
稀疏注意力。由于注意力机制的二次方复杂度,许多先前的工作都集中在稀疏注意力上以提高Transformer的效率。这些方法包括静态稀疏模式、基于聚类的稀疏方法和动态稀疏注意力。
* 静态稀疏模式:包括滑动窗口【【索引22,Mistral 7b+2023+ArXiv preprint】,【索引2,Phi-3 technical report: A highly capable language model locally on your phone+2024+None】】、扩张注意力【【索引8,Generating long sequences with sparse transformers+2019+ArXiv preprint】,【索引52,Sparsebert: Rethinking the importance analysis in self-attention+2021+Proceedings of the 38th International Conference on Machine Learning, ICML 2021】,【索引13,Longnet: Scaling transformers to 1,000,000,000 tokens+2023+ArXiv preprint】】和混合稀疏模式【【索引6,Longformer: The long-document transformer+2020+ArXiv preprint】,【索引71,Big bird: Transformers for longer sequences+2020+Advances in Neural Information Processing Systems 33】,【索引25,Block pruning for faster transformers+2021+Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing】】。
* 基于聚类的稀疏方法:包括基于哈希【索引24,Reformer: The efficient transformer+2020+8th International Conference on Learning Representations, ICLR 2020】和基于kNN的方法【【索引49,Efficient contentbased sparse attention with routing transformers+2021+Transactions of the Association for Computational Linguistics】,【索引36,Dynamic memory compression: Retrofitting LLMs for accelerated inference+2024+Forty-first International Conference on Machine Learning】】。
以上所有方法都需要从头开始预训练模型,这使得它们无法直接作为插件用于现成的LLM。最近,有工作【【索引12,Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality+2024+Forty-first International Conference on Machine Learning】,【索引68,A unified implicit attention formulation for gated-linear recurrent sequence models+2024+ArXiv preprint】】将状态空间模型【【索引16,Efficiently modeling long sequences with structured state spaces+2022+The Tenth International Conference on Learning Representations, ICLR 2022】,【索引15,Mamba: Linear-time sequence modeling with selective state spaces+2023+ArXiv preprint】】和线性注意力【【索引24,Transformers are rnns: Fast autoregressive transformers with linear attention+2020+Proceedings of the 37th International Conference on Machine Learning, ICML 2020】,【索引51,Retentive network: A successor to transformer for large language models+2023+ArXiv preprint】】统一为结构化掩码注意力。此外,一些工作【【索引63,Spatten: Efficient sparse attention architecture with cascade token and head pruning+2021+2021 IEEE International Symposium on High-Performance Computer Architecture (HPCA)】,【索引28,Dynamic sparse attention for scalable transformer acceleration+2022+IEEE Transactions on Computers】,【索引46,Sparq attention: Bandwidth-efficient LLM inference+2024+Forty-first International Conference on Machine Learning】】利用注意力的动态性来动态预测稀疏模式。然而,这些方法通常在动态模式近似期间关注低秩隐藏状态或使用后统计方法来获得稀疏掩码,这在估计步骤中引入了大量开销,使其对长上下文LLM的用处不大。
扩展LLM的上下文窗口。最近的研究集中在扩展预训练LLM的上下文窗口,使LLM能够处理更复杂的现实应用【【索引22,Swe-bench: Can language models resolve real-world github issues?+2023+The Twelfth International Conference on Learning Representations】,【索引41,Generative agents: Interactive simulacra of human behavior+2023+Proceedings of the 36th Annual ACM Symposium on User Interface Software and Technology】】。这些方法可分为:1) 分阶段预训练【【索引37,Xgen-7b technical report+2023+ArXiv preprint】,【索引14,Data engineering for scaling language models to 128k context+2024+Forty-first International Conference on Machine Learning】】;2) 修改或插值位置嵌入【【索引44,Train short, test long: Attention with linear biases enables input length extrapolation+2022+The Tenth International Conference on Learning Representations, ICLR 2022】,【索引10,Extending context window of large language models via positional interpolation+2023+ArXiv preprint】,【索引43,Yarn: Efficient context window extension of large language models+2024+The Twelfth International Conference on Learning Representations】,【索引14,LongroPE: Extending LLM context window beyond 2 million tokens+2024+Forty-first International Conference on Machine Learning】】;3) 利用外部内存模块进行上下文存储【【索引4,Unlimiformer: Long-range transformers with unlimited length input+2023+Thirty-seventh Conference on Neural Information Processing Systems】,【索引56,Focused transformer: Contrastive training for context scaling+2023+Thirty-seventh Conference on Neural Information Processing Systems】,【索引66,Infllm: Unveiling the intrinsic capacity of llms for understanding extremely long sequences with training-free memory+2024+ArXiv preprint】】;4) 以分布式方式在多个设备上扩展计算【索引32,Ringattention with blockwise transformers for near-infinite context+2024+The Twelfth International Conference on Learning Representations】。然而,这些方法并未缓解长上下文处理中的高昂推理成本。
长上下文LLM推理。最近的研究【索引15,Challenges in deploying long-context transformers: A theoretical peak performance analysis+2024+ArXiv preprint】从预填充和解码两个角度解决了注意力的高计算成本和大量的KV缓存存储问题。预填充优化主要分为状态空间模型【【索引16,Efficiently modeling long sequences with structured state spaces+2022+The Tenth International Conference on Learning Representations, ICLR 2022】,【索引15,Mamba: Linear-time sequence modeling with selective state spaces+2023+ArXiv preprint】】、线性注意力方法【【索引51,Retentive network: A successor to transformer for large language models+2023+ArXiv preprint】,【索引38,RWKV: Reinventing RNNs for the transformer era+2023+Findings of the ACL: EMNLP 2023】】、基于内存的方法【【索引34,Leave no context behind: Efficient infinite context transformers with infini-attention+2024+ArXiv preprint】,【索引19,Block transformer: Global-to-local language modeling for fast inference+2024+ArXiv preprint】】、混合方法【【索引29,Jamba: A hybrid transformer-mamba language model+2024+ArXiv preprint】,【索引47,Samba: Simple hybrid state space models for efficient unlimited context language modeling+2024+ArXiv preprint】】和提示压缩方法【【索引27,Compressing context to enhance inference efficiency of large language models+2023+Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing】,【索引22,LLMLingua: Compressing prompts for accelerated inference of large language models+2023+Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing】,【索引21,LongLLMLingua: Accelerating and enhancing LLMs in long context scenarios via prompt compression+2024+Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】,【索引45,LLMLingua-2: Data distillation for efficient and faithful task-agnostic prompt compression+2024+Findings of the ACL 2024】】。然而,这些方法需要从头开始训练或有额外的开销,难以直接在预训练的长上下文LLM中实现。最近,一些研究【【索引33,Iceformer: Accelerated inference with longsequence transformers on CPUs+2024+The Twelfth International Conference on Learning Representations】,【索引66,Infllm: Unveiling the intrinsic capacity of llms for understanding extremely long sequences with training-free memory+2024+ArXiv preprint】,【索引24,Retrievalattention: Accelerating long-context llm inference via vector retrieval+2024+ArXiv】】专注于使用kNN或基于聚类的稀疏注意力来加速LLM推理。然而,这些方法通常会导致准确性降低、加速有限或仅限于CPU场景。
相比之下,解码阶段的优化分为:1) 重用注意力KV以减少KV缓存存储【【索引53,Fast transformer decoding: One write-head is all you need+2019+ArXiv preprint】,【索引2,Gqa: Training generalized multi-query transformer models from multi-head checkpoints+2023+None】,【索引52,You only cache once: Decoder-decoder architectures for language models+2024+ArXiv preprint】,【索引10,Deepseek-v2: A strong, economical, and efficient mixture-of-experts language model+2024+None】,【索引36,Dynamic memory compression: Retrofitting LLMs for accelerated inference+2024+Forty-first International Conference on Machine Learning】】;2) 静态KV缓存丢弃【【索引65,Efficient streaming language models with attention sinks+2024+The Twelfth International Conference on Learning Representations】,【索引20,LM-infinite: Zero-shot extreme length generalization for large language models+2024+Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)】】;3) 动态KV缓存丢弃【【索引73,H2o: Heavy-hitter oracle for efficient generative inference of large language models+2024+Advances in Neural Information Processing Systems】,【索引28,Scissorhands: Exploiting the persistence of importance hypothesis for llm kv cache compression at test time+2024+Advances in Neural Information Processing Systems】,【索引18,Model tells you what to discard: Adaptive kv cache compression for llms+2024+The Twelfth International Conference on Learning Representations】,【索引38,Transformers are multi-state rnns+2024+ArXiv preprint】,【索引29,Snapkv: Llm knows what you are looking for before generation+2024+ArXiv preprint】,【索引3,Dynamic context pruning for efficient and interpretable autoregressive transformers+2024+Advances in Neural Information Processing Systems】】;4) 动态KV缓存卸载【【索引46,Sparq attention: Bandwidth-efficient LLM inference+2024+Forty-first International Conference on Machine Learning】,【索引12,Sequence can secretly tell you what to discard+2024+ArXiv preprint】,【索引57,QUEST: Query-aware sparsity for efficient long-context LLM inference+2024+Forty-first International Conference on Machine Learning】,【索引24,Retrievalattention: Accelerating long-context llm inference via vector retrieval+2024+ArXiv】,【索引9,Magicpig: Lsh sampling for efficient llm generation+2024+ArXiv preprint】,【索引50,Shadowkv: Kv cache in shadows for high-throughput long-context llm inference+2024+ArXiv preprint】】;5) 因KV缓存压缩导致性能损失的恢复方法【【索引1,Keyformer: Kv cache reduction through key tokens selection for efficient generative inference+2024+Proceedings of Machine Learning and Systems】,【索引14,Get more with LESS: Synthesizing recurrence with KV cache compression for efficient LLM inference+2024+Forty-first International Conference on Machine Learning】】;6) 分层推测解码方法【【索引50,Triforce: Lossless acceleration of long sequence generation with hierarchical speculative decoding+2024+First Conference on Language Modeling】,【索引9,Magicdec: Breaking the latency-throughput tradeoff for long context generation with speculative decoding+2024+ArXiv】】;7) KV缓存量化【索引31,KIVI: A tuning-free asymmetric 2bit quantization for KV cache+2024+Forty-first International Conference on Machine Learning】。然而,这些方法并未解决预填充阶段注意力计算的沉重负担。
A5 结论
本文解决了长上下文LLM预填充阶段中注意力计算成本高昂和延迟不可接受的问题。我们提出了MInference,一种通过利用具有空间聚合模式的动态稀疏注意力来加速预填充阶段的方法。具体来说,我们将注意力头分为三种类型:A-shape、Vertical-Slash和Block-Sparse。使用核函数感知的最优稀疏模式搜索方法,我们为每个头确定了最优模式。随后,我们利用快速近似方法为不同输入构建动态稀疏掩码,然后应用这些掩码执行稀疏注意力计算。在InfiniteBench、RULER、语言建模和Needle In A Haystack等基准上的实验结果表明,我们的方法有效保持了LLM的长上下文能力,同时实现了高达10倍的加速,将100万token提示在单个A100 GPU上的延迟从30分钟减少到3分钟。此外,我们发现类似的多模态LLM【索引62,Look-m: Look-once optimization in kv cache for efficient multimodal long-context inference+2024+ArXiv preprint】和编码器-解码器LLM【索引49,Exploring the limits of transfer learning with a unified text-to-text transformer+2020+J. Mach. Learn. Res.】中也存在类似的动态稀疏注意力模式。使用MInference进行预填充阶段的推理加速具有巨大的潜力。
A6 附录
A 局限性
短上下文下的开销。随着上下文长度的减少,构建动态索引所需的时间变得更加显著,因为注意力计算时间减少了。例如,在10k上下文中,用于构建索引的时间从5%增加到30%,导致整体端到端延迟接近FlashAttention。然而,随着提示变长,这个开销比例会逐渐减小。此外,当使用更高的稀疏率时,模型性能可能会明显下降。
B 更广泛的影响
推动长上下文LLM的应用。MInference有效加速了长上下文LLM的推理,促进了它们的部署和应用。通过实现更低的延迟,它可以降低LLM的部署成本,特别是对于长上下文LLM,有助于普及先进AI技术。它还促进了相关领域的进一步研究和发展。
C 实验细节
C.1 数据集细节
- InfiniteBench【索引69,∞Bench: Extending long context evaluation beyond 100K tokens+2024+Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】:包括10个任务,旨在测试长上下文处理的各个方面。具体任务包括整本小说摘要、基于小说的开放式问答、小说多项选择问答、长剧本问答、中文文本问答、大型代码库调试、识别数组中最大/最小数字,以及不同模式长度的检索任务。这些任务的平均token长度为214k,共包含3,992个样本。
- RULER【索引21,Ruler: What’s the real context size of your long-context language models?+2024+ArXiv preprint】: 最近推出的用于长上下文评估的合成基准套件,包含4个类别下的13个复杂任务。检索类别包括单针大海捞针(S-NIAH)、多键大海捞针(MK-NIAH)、多值大海捞针(MV-NIAH)和多查询大海捞针(MQ-NIAH)。多跳追踪类别包括变量追踪(VT)。聚合类别引入了常见词提取(CWE)和高频词提取(FWE)。问答(QA)类别通过添加干扰段落来扩展现有的短上下文QA数据集。这些任务全面评估了长上下文建模能力。我们按照【索引21,Ruler: What’s the real context size of your long-context language models?+2024+ArXiv preprint】的方法,在4K、8K、16K、32K、64K和128K的上下文长度上测试模型,每个长度包含2,600个样本。
- Needle In A Haystack【索引23,Needle in a haystack - pressure testing llms+2023+None】:通过在大量复杂文本(“草堆”)中嵌入特定的目标信息(“针”),评估检索增强生成(RAG)系统的性能。该测试评估语言模型在海量数据中识别和利用特定信息的能力。我们在这里将此任务扩展到1M上下文长度,包括750个样本。
- PG-19【索引48,Compressive transformers for long-range sequence modelling+2020+8th International Conference on Learning Representations, ICLR 2020】:长文本的困惑度也常被研究人员用来评估长上下文LLM的语言建模性能。PG-19是一个合适的测试集,因为它包含长达500K token的文本。我们的实验在PG-19中随机抽取的1000个长度超过100K token的样本上进行。
C.2 更多实现细节
-
模型:实验基于多个先进的长上下文LLM:
- LLaMA-3-8B-Instruct-262k:一个LLaMA-3变体,经过进一步的NTK-aware插值和使用Ring Attention的少量微调。
- LLaMA-3-8B-Instruct-1048k:与262k版本类似,但支持高达1M token的上下文。
- Yi-9B-200K【索引67,Yi: Open foundation models by 01. ai+2024+ArXiv preprint】:平衡了长上下文性能和通用能力的SOTA LLM。
- Phi-3-Mini-128K【索引2,Phi-3 technical report: A highly capable language model locally on your phone+2024+None】:由LongRoPE【索引14,LongroPE: Extending LLM context window beyond 2 million tokens+2024+Forty-first International Conference on Machine Learning】支持,提供高达128K上下文窗口。
- Qwen2-7B-128K【索引5,Qwen technical report+2023+ArXiv preprint】:Qwen系列最近发布的更新,支持高达128K上下文窗口。
- GLM-4-9B-1M【索引19,Chatglm: A family of large language models from glm-130b to glm-4 all tools+2024+ArXiv preprint】:上下文窗口提升至1M。
-
核函数实现:我们的核函数实现是基于动态稀疏编译器PIT【索引72,Pit: Optimization of dynamic sparse deep learning models via permutation invariant transformation+2023+Proceedings of the 29th Symposium on Operating Systems Principles】在Triton语言【索引55,Triton: an intermediate language and compiler for tiled neural network computations+2019+Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages】中开发和优化的。
- 搜索空间:我们将目标FLOPs t设置为与A-shape模式中1k全局token和4k局部窗口token相同。ChangeSpace的步长设置为50,相应的搜索空间如表7所示。此外,我们仅使用一个来自KV检索合成数据的30k token输入的样本作为验证集,该样本在不同长度和领域表现出强大的泛化性和稳定性。搜索时间在单个A100上大约为15分钟。我们对LLaMA-3-8B-Instruct-262K模型和LLaMA-3-8B-Instruct-1M模型使用了相同的最优稀疏模式配置,具体分布如图11所示。
表 7: 核函数感知的最优头模式搜索空间。在此背景下,A-shape代表全局token和局部窗口数,Vertical-Slash代表垂直线和对角线的Top-K数量,Block-Sparse代表保留的Top-K块数。
C.3 单A100实现细节
内存优化。原始的PyTorch LLaMA模型实现,在提示超过50k token时,会在单张A100(80G)上导致内存不足错误。为了能够在单张A100上运行1M提示推理,我们实现了以下优化:
1. 张量拆分:我们按头拆分Attention,按序列维度拆分MLP。在计算是瓶颈的长上下文场景中,这种拆分使GPU利用率保持在100%,且拆分的开销可忽略不计。
2. 减少中间变量:我们通过移除注意力掩码并在核函数内直接实现因果掩码逻辑,最小化了中间变量的分配。
3. 消除不必要的计算:在长上下文场景中,只有提示阶段最后一个token对应的logits是有意义的。因此,我们只保留了最后一个token的LM Head线性层的计算。
C.4 核函数实现
C.4.1 块稀疏Flash Attention
实现。我们的Block-Sparse核函数实现基于FlashAttention核函数的Triton版本【索引56,Triton implementation of the flash attention v2+2023+OpenAI】。以选定的块索引作为额外输入,每个线程块循环遍历一行中的top-K个块。正如FlashAttention【索引11,Flashattention-2: Faster attention with better parallelism and work partitioning+2024+The Twelfth International Conference on Learning Representations】中所讨论的,块稀疏FlashAttention核函数的延迟与块的数量呈线性关系,其加速比(与密集FlashAttention核函数相比)近似为:
$$s_{p}=\frac{S}{2B\times k_{b}}$$C.4.2 Vertical-Slash Attention
双核函数实现。Vertical-Slash注意力包括两个自定义核函数:Vertical-Slash稀疏索引核函数和Vertical-Slash稀疏FlashAttention核函数。
图 7: 在摘要任务中使用LLaMA-3-8B的vertical-slash模式的动态稀疏掩码【索引69,∞Bench: Extending long context evaluation beyond 100K tokens+2024+Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】。黄色区域表示计算部分。斜线使用64×64的块,而垂直线使用1×64的块。
稀疏索引核函数。算法4中的Vertical-Slash稀疏索引核函数为每一行块构建索引。由于一个斜线段可以被一个方块掩盖,我们的注意力掩码是块和列的混合,如图7所示。我们应用了一种点-范围双向合并算法,其中垂直索引被视为点,斜线索引根据行索引被转换为范围。输出包括两部分:合并的范围和独立的列索引,其中范围由块索引表示。为一行构建索引的时间复杂度为$O(k_v + k_s)$。
稀疏FlashAttention核函数。算法5中的Vertical-Slash稀疏FlashAttention核函数是块稀疏注意力核函数和PIT【索引72,Pit: Optimization of dynamic sparse deep learning models via permutation invariant transformation+2023+Proceedings of the 29th Symposium on Operating Systems Principles】稀疏注意力核函数的混合。PIT是一种通过置换不变变换将稀疏数据加载到密集计算块中的技术。一个线程块首先循环遍历前一节中描述的块索引(块部分),然后循环遍历按块大小分组的列索引(PIT部分)。这种混合核函数的延迟与块和列的总面积呈线性关系。
D 附加实验结果
D.1 Needle In A Haystack
更多模型结果。除了第4节中展示的LLaMA-3-Instruct-1M的结果外,我们还在图8中展示了使用InfLLM的LLaMA-3-Instruct-1M的结果,并在图9中展示了GLM4-9B-1M、Yi-9B-200K、Phi-3-Mini-128K和Qwen2-7B-128K的结果。与全注意力相比,使用MInference对理解不同上下文窗口和文档深度的语义信息的能力影响最小。在使用Yi-9B-200K和Phi-3-Mini-128K时,在100k上下文长度附近甚至有轻微的性能提升。
图 8: 在LLaMA-3-8B-Instruct-1M中使用InfLLM的Needle In A Haystack测试结果。
图 9: 使用GLM-4-9B-1M【索引19,Chatglm: A family of large language models from glm-130b to glm-4 all tools+2024+ArXiv preprint】、Yi-9B-200K【索引67,Yi: Open foundation models by 01. ai+2024+ArXiv preprint】、Phi-3-Mini-128K【索引2,Phi-3 technical report: A highly capable language model locally on your phone+2024+None】和Qwen2-7B-128K【索引5,Qwen technical report+2023+ArXiv preprint】的Needle In A Haystack【索引23,Needle in a haystack - pressure testing llms+2023+None】测试结果。
D.2 延迟分解
各模式的性能。图10显示了本文提出的三种注意力模式以及FlashAttention的微基准测试结果。可以看出,Vertical-Slash是三种模式中最慢的,但在1M上下文窗口下,它仍然比FlashAttention快13倍。A-shape比Vertical-Slash稍快,但在1M时,A-shape比Vertical-Slash慢50%。Block-Sparse是最快的,在1M上下文窗口下比FlashAttention快30倍。
图 10: 单个注意力核函数在单A100上三种模式和FlashAttention【索引11,Flashattention-2: Faster attention with better parallelism and work partitioning+2024+The Twelfth International Conference on Learning Representations】在不同上下文窗口下的延迟分解,包括动态稀疏近似和构建动态稀疏性的索引时间。在10k tokens时,四个核函数的延迟非常接近,都小于1ms。在1M tokens时,A-shape的延迟为164ms。
开销分析。动态稀疏模式的估计和索引构建时间分别约占Vertical-Slash和Block-Sparse模式总时间的5%-15%和25%。Block-Sparse的索引构建开销更高,主要是因为耗时的MeanPooling和块级矩阵乘法计算。此外,稀疏索引的内存开销相对较小,在1M上下文的LLaMA-3-8B模型中保持在160MB以内。
D.3 附加消融研究
Vertical-Slash模式的内部组件分析。为了进一步分析动态垂直线和斜线在Vertical-Slash模式中对稀疏计算的作用,我们引入了一组新的消融研究:1) Ours w/ only vertical,仅在Vertical-Slash模式中使用垂直线和top-1的斜线。2) Ours w/ only slash,仅使用斜线和top-1的垂直线。相应的top-K数量是根据核函数中的FLOPs转换后设置的。
如表8所示,仅使用垂直线会导致性能显著下降,特别是在检索任务中,性能类似于仅使用块稀疏。相比之下,仅使用斜线保留了大部分性能,但在像KV检索这样的高度动态任务中,性能进一步下降,平均性能比Ours下降了2.9%。
表 8: 使用LLaMA-3-8B-Instruct-262K在InfiniteBench【索引69,∞Bench: Extending long context evaluation beyond 100K tokens+2024+Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】上的不同消融方法的性能。需要注意的是,由于核函数限制,我们必须保留至少一条垂直线和一条斜线。因此,“ours w/ only vertical”保留了top-1的斜线,“ours w/ only slash”保留了top-1的垂直线。
E 模式分布
模式分配的观察。图11显示了通过我们的搜索获得的最优头配置的分布。首先,大多数模式是Vertical-Slash模式(>90%)。然而,根据消融研究,仅使用Vertical-Slash模式会显著影响像KV检索这样的高度动态任务的性能。其次,Block-Sparse模式主要分布在几个中间到后期的层中,而A-shape模式则在中间层中发现。尽管最优模式在不同模型之间略有不同,但它们通常与这些观察结果一致。
模式的泛化性。此外,我们在实验中对两个版本的LLaMA使用了相同的配置,结果表明1M模型也表现得非常好,在Needle In A Haystack任务中取得了近乎完美的结果。这证明了最优稀疏模式的泛化性。
图 11: 三种稀疏头模式在不同模型中的分布。我们对LLaMA-3-8B-Instruct-262K和LLaMA-3-8B-Instruct-1M使用了相同的最优稀疏模式配置。
F 核函数中的稀疏度分布
实际计算稀疏度。如图12所示,显示了三种模式在实际核函数计算过程中的稀疏度分布。可以看出,当上下文窗口超过200k时,所有三种模式的实际稀疏度都超过了90%。即使考虑到20%的索引构建开销,这也确保了核函数实现了超过8倍的加速。此外,当上下文窗口超过500k时,相对于FlashAttention的稀疏度超过95%,理论加速比超过15倍。
图 12: 核函数在不同上下文窗口中的稀疏度分布,指的是块覆盖后实际计算的核函数比例,与使用带因果掩码的FlashAttention时的稀疏率进行比较。
G 这种动态稀疏注意力模式只存在于自回归LLM或基于RoPE的LLM中吗?
模式的普适性。在BERT【索引52,Sparsebert: Rethinking the importance analysis in self-attention+2021+Proceedings of the 38th International Conference on Machine Learning, ICML 2021】和多模态LLM【索引62,Look-m: Look-once optimization in kv cache for efficient multimodal long-context inference+2024+ArXiv preprint】中也发现了类似的垂直和斜线稀疏模式。此外,如图13所示,我们分析了T5中不同头的注意力模式分布。很明显,即使在双向注意力中也存在垂直和斜线稀疏模式。
多模态模型的潜力。最近的研究【索引62,Look-m: Look-once optimization in kv cache for efficient multimodal long-context inference+2024+ArXiv preprint】分析了多模态LLM中的稀疏注意力模式,揭示了在LLaVA【索引29,Visual instruction tuning+2024+Advances in neural information processing systems】和InternVL【索引10,Internvl: Scaling up vision foundation models and aligning for generic visual-linguistic tasks+2024+In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition】等模型中也存在垂直和斜线模式。使用MInference进行预填充阶段的推理加速具有巨大的潜力。
图 13: 使用Flan-UL2【索引54,UL2: Unifying language learning paradigms+2023+The Eleventh International Conference on Learning Representations】在摘要数据集【索引69,∞Bench: Extending long context evaluation beyond 100K tokens+2024+Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】上T5风格的编码器注意力中的稀疏模式。
H 案例研究
摘要任务案例。表9展示了基于LLaMA-3-8B-262K模型,在InfiniteBench的EN.SUM任务(200K输入长度)上各种方法的生成性能比较。原始摘要提供了一个全面连贯的叙述。StreamingLLM的摘要虽然看起来连贯,但引入了原文中不存在的元素,导致了严重的事实错误。相比之下,我们提出的方法生成的摘要详细而连贯,与原始摘要相当,清晰地描绘了故事的主要事件和主题。
KV检索任务案例。表10比较了在Retrieve.KV任务(200K输入长度)上使用LLaMA-3-8B-262K模型的各种方法的性能。原始方法展示了完美的检索能力。StreamingLLM再次生成了看起来连贯真实但事实错误的预测。其他基线方法也显著失败,产生了重复或无意义的字符串。然而,我们的方法表现与原始方法相当,准确地检索并预测了两个样本的键值对。这展示了我们的方法在处理KV检索任务方面的卓越能力。
表 9: 使用LLaMA-3-8B-Instruct-262K在摘要任务【索引69,∞Bench: Extending long context evaluation beyond 100K tokens+2024+Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】中不同方法的生成结果比较。
表 10: 使用LLaMA-3-8B-Instruct-262K在KV检索任务【索引69,∞Bench: Extending long context evaluation beyond 100K tokens+2024+Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)】中不同方法的生成结果比较。
算法 4:Vertical-Slash 稀疏索引核函数
输入: 垂直索引 i_v ∈ N^(k_v), 斜线索引 i_s ∈ N^(k_s)
# 排序垂直和斜线索引
i_v ← IncrementalSort(i_v)
i_s ← DescendingSort(i_s)
# 计算块数 (块大小 B)
N ← ⌈S/B⌉
# 初始化输出
块计数 c_blk ∈ N^N, 块索引 i_blk ∈ N^(N×k_s),
列计数 c_col ∈ N^N, 列索引 i_col ∈ N^(N×k_v)
# # 在 GPU 中并行化
for i ← 1 to N do
# 找到与该行相交的第一条斜线
j_s ← biset_left(i_s, i × B)
# 通过斜线索引定义范围
r_start ← (i − 1) × B − i[j_s]
r_end ← i × B − i[j_s]
$$ \text{\textbf{while} } s_v \le k_s \text{ \textbf{do}} $$
# ... 合并点(垂直索引)和范围(斜线索引)的逻辑 ...
$$\begin{aligned} \begin{array}{l} s \leftarrow r_{\text{start}} \\ \text{\textbf{while }} s < r_{\text{end}} \text{ \textbf{do}} \\ \quad \boldsymbol{c}_{\text{blk}}^{i} \leftarrow \boldsymbol{c}_{\text{blk}}^{i} + 1 \\ \quad \boldsymbol{i}_{\text{blk}}^{i, \boldsymbol{c}_{\text{blk}}^{i}} \leftarrow s \\ \quad s \leftarrow s + B \\ \text{\textbf{end while}} \end{array} \end{aligned}$$end for return c_blk, i_blk, c_col, i_col
算法 5:Vertical-Slash 稀疏 FlashAttention 核函数
输入: Q, K, V ∈ R^(S×d_h), 块计数 c_blk ∈ N^N, 块索引 i_blk ∈ N^(N×k_s),
列计数 c_col ∈ N^N, 列索引 i_col ∈ N^(N×k_v)
缩放 τ ← 1/√d_h
初始化 O ← (0)^(S×d_h) ∈ R^(S×d_h)
# # 在 GPU 中并行化
for i ← 1 to N do
Load Q_chip ← Q[i×B:(i+1)×B] ∈ R^(B×d_h)
初始化 O_chip ← (0)^(B×d_h) ∈ R^(B×d_h)
初始化 m^i ← (−inf)^B ∈ R^B
初始化 l^i ← (0)^B ∈ R^B
# # 循环遍历块索引:块稀疏 flash attention
for j ← 1 to c_blk^i do
块起始 s ← i_blk^[i,j]
Load K_chip ← K[s:s+B]
Load V_chip ← V[s:s+B]
S ← τ Q_chip K_chip^⊤ ∈ R^(B×B)
m_new^i ← max(m^i, rowmax(S)) ∈ R^B
S ← S - m_new^i
P ← exp(S)
l_new^i ← rowsum(P)
α ← exp(m^i − m_new^i)
l^i ← αl^i + l_new^i
O_chip ← αO_chip + P V_chip
end for
$$\begin{aligned}
\begin{array}{l}
j \leftarrow 0 \\
\textbf{while } j < c_{\text{col}}^j \textbf{ do} \\
\quad \boldsymbol{cols} \leftarrow \boldsymbol{i}_{\text{col}}^{i, j:j+B} \in \mathbb{N}^B \\
\quad \text{Load } \boldsymbol{K}_{\text{chip}} \leftarrow \boldsymbol{K}^{\boldsymbol{cols}} \in \mathbb{R}^{B \times d_h} \\
\quad \text{Load } \boldsymbol{V}_{\text{chip}} \leftarrow \boldsymbol{V}^{\boldsymbol{cols}} \in \mathbb{R}^{B \times d_h} \\
\quad \boldsymbol{S} \leftarrow \tau \boldsymbol{Q}_{\text{chip}} \boldsymbol{K}_{\text{chip}}^T \\
\quad \boldsymbol{S} \leftarrow \text{mask}(\boldsymbol{S}) \\
\quad \boldsymbol{m}_{new}^i \leftarrow \max(\boldsymbol{m}^i, \text{rowmax}(\boldsymbol{S})) \in \mathbb{R}^B \\
\quad \boldsymbol{S} \leftarrow \boldsymbol{S} - \boldsymbol{m}_{new}^i \\
\quad \boldsymbol{P} \leftarrow \exp(\boldsymbol{S}) \\
\quad \boldsymbol{l}_{new}^i \leftarrow \text{rowsum}(\boldsymbol{S})) \\
\quad \boldsymbol{\alpha} \leftarrow \exp(\boldsymbol{m}^i - \boldsymbol{m}_{new}^i) \\
\quad \boldsymbol{l}^i \leftarrow \boldsymbol{\alpha} \boldsymbol{l}^i + \boldsymbol{l}_{new}^i \\
\quad \boldsymbol{O}_{\text{chip}} \leftarrow \boldsymbol{\alpha} \boldsymbol{O}_{\text{chip}} + \boldsymbol{P} \boldsymbol{V}_{\text{chip}} \\
\quad j \leftarrow j + B \\
\textbf{end while}
\end{array}
\end{aligned}$$
# ... 循环遍历列索引: PIT 稀疏 flash attention 的逻辑 ... end for
💬 评论讨论
欢迎在这里分享您的想法和见解!