FlexPrefill: A Context-Aware Sparse Attention Mechanism for Efficient Long-Sequence Inference
FlexPrefill: A Context-Aware Sparse Attention Mechanism for Efficient Long-Sequence Inference
作者/机构: Xunhao Lai (Peking University), Jianqiao Lu (The University of Hong Kong), Yao Luo (ByteDance Inc.), Yiyuan Ma (ByteDance Inc.), Xun Zhou (ByteDance Inc.)
A1 主要贡献
大型语言模型(LLM)在处理长序列推理时面临计算挑战,特别是在注意力预填充(pre-filling)阶段,其计算复杂度随提示长度呈二次方增长。为了解决这一瓶颈,研究者们提出了稀疏注意力机制。然而,先前的方法存在局限性,它们要么依赖固定的稀疏注意力模式(如BigBird、LongLora),要么基于有限情况识别稀疏模式(如StreamingLLM、MInference)。这些方法缺乏灵活性,无法高效适应不同输入的需求,当输入数据的复杂性变化时,其注意力模式往往是次优的。
为了应对这些挑战,本文提出了FlexPrefill,一种新颖的灵活稀疏预填充注意力机制,旨在根据每个输入和注意力头的具体需求,实时地、动态地调整稀疏注意力模式和计算预算。该方法的核心创新体现在两个关键组件中:
- 查询感知稀疏模式确定(Query-Aware Sparse Pattern Determination):该组件首先将注意力头分为需要根据查询进行特定估计的多样化模式(Diverse pattern)和在不同查询间保持一致的结构化模式(Structured pattern)。通过测量估计的注意力得分分布与真实分布之间的Jensen-Shannon散度,该组件能够自适应地为每个注意力头选择最合适的稀-疏模式。
- 基于累积注意力的索引选择(Cumulative-Attention Based Index Selection):该组件根据不同的注意力模式,动态地选择需要计算的查询-键(query-key)索引,确保所选索引的注意力分数总和达到预定义的阈值。这使得计算预算能够根据不同注意力头的重要性进行自适应分配,同时保持模型的有效性。
通过这种方式,FlexPrefill能够根据输入提示自适应地优化每个注意力头的稀疏模式和稀疏率,从而在计算效率和模型性能之间取得有效平衡。在包括Meta-Llama-3.1-8B-Instruct、GLM-4-9B-Chat、Yi-9B-200K和Qwen2-7B-Instruct在内的多个先进LLM上的大量实验,以及在RULER和InfiniteBench等具有挑战性的长上下文基准测试上的评估,都表明FlexPrefill在速度和准确性方面均优于先前的方法,为长序列LLM推理提供了一个更灵活、更高效的解决方案。
图1:不同推理方法的性能和速度分析。(a) 不同推理方法在LLaMA、GLM、Yi和Qwen模型上的性能比较。(b) 不同方法在各种上下文长度下的注意力加速比比较。
A3 背景知识与关键观察
问题设定
在稀疏注意力机制中,给定序列长度为$L$,查询矩阵为$Q$,键矩阵为$K$。每个位置的注意力分数通过查询矩阵$Q$和键矩阵$K$之间的缩放点积计算得出,并由头维度$d$的平方根进行归一化。为了提高计算效率,稀疏注意力将计算限制在一个查询-键对的子集上,其索引构成集合$S = \bigcup_{i=1}^{L} S_i$,其中$S_i \subseteq \{(i, j) | 1 \le j \le i, j \in Z\}$。稀疏注意力机制的形式化表示如下:
$$A(Q, K, V, S) = \text{Softmax} \left( \frac{1}{\sqrt{d}} (Q \cdot K^\top + M_S) \right) \cdot V$$此处,$M_S$是一个基于$S$的稀疏注意力掩码,定义为:
$$\begin{aligned} M_{\boldsymbol{S}}[i, j]= \begin{cases}0, & \text { if }(i, j) \in \boldsymbol{S}, \\ -\infty, & \text { otherwise. }\end{cases} \end{aligned}$$目标
动态稀疏注意力系统的目标是在计算效率和注意力效果之间取得平衡。这可以被形式化地表达为一个多目标优化问题:
$$\min_{\boldsymbol{S}}\ | \boldsymbol{A}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}) - \boldsymbol{A}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}, \boldsymbol{S}) | \ ,$$其中,$A(Q, K, V)$表示全注意力计算结果,而$A(Q, K, V, S)$表示稀疏注意力的结果。公式(1)中的第一个目标旨在最小化稀疏注意力与全注意力结果之间的差异。第二个目标则力求最小化所选子集$S$的大小,这直接关系到计算效率。
注意力稀疏模式的可变性
在LLM的注意力机制中,我们观察到不同注意力头之间的注意力模式可能存在显著差异。图2a揭示了不同查询位置存在一种多样化模式(Diverse pattern)。这种可变性表明,采用“一刀切”的稀疏注意力方法可能是次优的。相反,图2b展示了一种在不同查询中更为一致和结构化的模式,呈现出清晰的结构化模式(Structured pattern),这与【索引28, Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention, 2024a, CoRR】的报告类似,也被称为垂直-斜线模式(Vertical-Slash pattern)。在具有此类模式的注意力头中,注意力图的一个子集可能足以估算整个稀疏索引集。
(a) 表现出多样化模式的注意力图,其中被关注的键令牌在不同的查询位置上散布着独立的块。
(b) 表现出结构化模式的注意力图,例如注意力集中在垂直和斜线状结构上,这在查询中是一致的。
图2:不同注意力头中注意力模式的比较。多样化模式(a)显示了在查询位置上散布着独立块的注意力,而结构化模式(b)则表现出注意力集中在某些结构上。
不同样本需要自适应的稀疏率
先前的工作(如【索引28, Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention, 2024a, CoRR】)利用离线搜索来确定一个固定的稀疏率,并将其统一应用于所有输入案例。然而,如图3所示,不同的输入提示表现出不同程度的复杂性,需要不同的稀疏率才能达到最佳性能。具有显著长程依赖的样本可能从较低的稀疏率中受益,而那些主要具有局部依赖的样本则可以用较高的稀疏率进行高效处理。为了解决上述问题,我们提出了一种动态调整每个样本的稀疏模式和稀疏率的方法,具体将在第3节介绍。
图3:针对不同样本复杂度的不同注意力头和层的自适应稀疏率。每个热图显示了在给定固定注意力分数覆盖率下,不同上下文长度和任务类型的稀疏率,颜色越深表示注意力计算越多。不同注意力头的稀疏度分布随着样本类型(a,b)和上下文长度(b,c)的不同而变化。
A2 方法细节
本节介绍我们用于优化稀疏注意力机制的方法。我们的方法包含两个关键组成部分:(1)查询感知稀疏模式确定 和(2)基于累积注意力的索引选择。然后,我们在(3)算法部分展示了FlexPrefill的整体稀疏注意力算法。
查询感知稀疏模式确定
动机与分类。我们对多样化模式(图2a)和结构化模式(图2b)的观察启发了一种自适应的稀疏注意力估计方法。我们将注意力头分为两种类型:查询感知(Query-Aware)头,旨在捕捉依赖于查询位置的可变稀疏模式;以及垂直-斜线(Vertical-Slash)头,用于处理LLM中常见的结构化稀疏注意力模式。我们提出一种方法,该方法能够动态决定是使用查询感知的注意力估计还是更一致的模式,从而平衡计算效率与准确性。
实现方法。考虑到计算完整全局注意力图的挑战,我们采用了一种务实的方法:
1. 代表性查询选择:我们选择最后block_size个查询向量,记为$\hat{Q}$,作为后续估计的代表性子集。
2. 分块注意力估计:通过对序列应用池化操作,我们计算两个分块注意力分布:
$$
\begin{aligned}
\bar{\boldsymbol{a}} & = \text{softmax} \left( \text{avgpool}(\hat{\boldsymbol{Q}}) \text{avgpool}(\boldsymbol{K})^\top / \sqrt{d} \right), \\
\hat{\boldsymbol{a}} & = \text{sumpool} \left( \text{softmax} \left( \hat{\boldsymbol{Q}} \boldsymbol{K}^\top / \sqrt{d} \right) \right),
\end{aligned}
$$
其中池化操作的核大小等于`block_size`,$\bar{a}$代表估计的分布,而$\hat{a}$代表真实的分布。
-
分布比较:我们使用Jensen-Shannon散度的平方根来量化这两个分布之间的差异:
$$D_{JS}(\bar{\boldsymbol{a}}, \hat{\boldsymbol{a}}) = \sqrt{JSD(\bar{\boldsymbol{a}} \| \hat{\boldsymbol{a}})} = \sqrt{\frac{1}{2}\left(D_{KL}(\bar{\boldsymbol{a}} \| m) + D_{KL}(\hat{\boldsymbol{a}} \| m)\right)} ,$$其中$m = \frac{1}{2}(\bar{a} + \hat{a})$是平均分布,而$D_{KL}(\cdot||\cdot)$表示Kullback-Leibler散度。我们采用Jensen-Shannon散度是因为它的对称性和有界性,这确保了更稳定和可靠的模式确定。
4. 自适应决策:我们将$D_{JS}$与一个预定义的阈值$\tau$进行比较。如果$D_{JS} < \tau$,则认为分块注意力估计能够充分近似真实分布。因此,我们为所有查询计算$\bar{a}$,从而实现查询感知的稀疏索引选择。否则,如果近似不充分,我们则退回到一种更保守的方法。具体来说,我们仅使用一部分查询进行后续的稀疏索引搜索,并使用垂直-斜线模式扩展到全局索引集。
优势。这种自适应方法使我们能够在适当的时候利用查询感知的注意力模式,从而可能捕捉到更多样化的模式。同时,当注意力估计不可靠时,它提供了一种回退机制,转而使用更一致的垂直-斜线模式。
基于累积注意力的索引选择
优化目标。在确定一个注意力头属于查询感知型还是垂直-斜线型之后,我们试图在保证效果的同时最小化计算量。我们的稀疏注意力机制旨在为每个查询位置$i$选择尽可能小的子集$S_i$,同时确保每个子集内归一化注意力分数的总和满足一个预定义的累积注意力阈值$\gamma$。这个目标可以形式化地表示为:
$$\min_{\boldsymbol{S}_{i}} \sum_{i=1}^{n}\left|\boldsymbol{S}_{i}\right| \quad \text { subject to } \sum_{(i, j) \in \boldsymbol{S}_{i}} \frac{\exp \left(\boldsymbol{Q}_{i} \cdot \boldsymbol{K}_{j}^{\top} / \sqrt{d}\right)}{\sum_{j^{\prime}=1}^{i} \exp \left(\boldsymbol{Q}_{i} \cdot \boldsymbol{K}_{j^{\prime}}^{\top} / \sqrt{d}\right)} \geq \gamma, \forall i \in[n],$$其中$S = \cup_i S_i$,而$|S_i|$表示所选索引子集的大小,$Q_i$和$K_j$分别是位置$i$的查询向量和位置$j$的键向量。关于优化公式(3)的详细理由在附录A中提供。
实现策略。为了实现这一目标,我们根据注意力头的类型采用不同的策略:
1. 查询感知(Query-Aware)头:
* 首先对查询和键进行平均池化,然后计算注意力,以此来估计分块注意力分数。
* 将这些块按降序排序,并依次选择块,直到它们归一化的注意力分数之和超过阈值$\gamma$,从而形成稀疏索引集$S$。
- 垂直-斜线(Vertical-Slash)头:
- 选择一个代表性的查询子集$\hat{Q}$并计算注意力分数,然后计算垂直线和斜线的平均分数。
- 对这些线进行排序,并按降序选择垂直线和斜线,直到它们累积的归一化注意力分数超过$\gamma$。
- 将选定的线扩展到整个注意力矩阵,形成稀疏索引集$S$。
总结。这种自适应方法根据每个注意力头的注意力模式来定制选择过程,使我们能够高效地确定用于稀疏注意力计算的稀疏索引集$S$。通过这样做,我们在计算效率和注意力准确性之间取得了平衡,确保我们将计算集中在最相关的令牌交互上,同时保持预定义水平的累积注意力分数。
算法
算法1:稀疏注意力。算法1展示了我们提出的高效稀疏注意力计算的整体流程。该算法以查询矩阵Q、键矩阵K、值矩阵V、稀疏模式阈值$\tau$和累积注意力阈值$\gamma$作为输入。它分为以下三个部分:
- 稀疏模式确定:算法2根据输入决定是使用查询感知模式还是回退到垂直-斜线模式。
- 稀疏索引选择:基于(i)中获得的注意力模式和给定的累积注意力阈值$\gamma$,通过算法4(查询感知)或算法3(垂直-斜线)获得每个注意力头需要计算的稀疏索引集S。
- 稀疏注意力计算:算法使用获得的稀疏索引为每个注意力头执行稀疏注意力计算,并返回最终的注意力结果。
算法 1 稀疏注意力
输入: Q, K, V ∈ R^(S×d_h), τ ∈ [0, 1], γ ∈ (0, 1)
# 根据阈值τ确定稀疏模式
pattern ← 稀疏模式搜索(Q, K, τ)
# 根据pattern和γ决定稀疏索引集S
if pattern == query_specific then
S ← 查询感知索引(Q, K, γ)
else if pattern == vertical_slash then
S ← 垂直斜线索引(Q, K, γ)
end if
# 计算最终的稀疏注意力输出
y ← A(Q, K, V, S)
return y
算法 2 稀疏模式搜索
输入: Q, K, τ
# 取一个代表性的查询子集
选择 Q̂ = Q[-block_size:]
# 计算估计的分块池化注意力ā和真实的分块池化注意力â
ā ← softmax(pool(Q̂)pool(K)ᵀ / √d)
â ← pool(softmax(Q̂Kᵀ / √d))
# 计算Jensen-Shannon散度
d_JS ← √JSD(ā||â)
# 决定是否使用查询特定注意力模式
if d_JS < τ then
pattern ← query_specific
else
pattern ← vertical_slash
end if
return pattern
算法 3 垂直斜线索引搜索
输入: Q, K, γ
# 计算完整注意力图的一个子集
 ← softmax(Q̂Kᵀ / √d),其中 Q̂ ⊂ Q
# 沿垂直和斜线方向求和并归一化注意力分数
a_v ← sum_vertical(Â) / Σ_{i,j} Â[i, j]
a_s ← sum_slash(Â) / Σ_{i,j} Â[i, j]
# 对垂直和斜线注意力分数进行排序
I_v ← argsort(a_v)
I_s ← argsort(a_s)
# 获取使分数总和超过γ的最小计算预算
K_v ← min{k : Σ_{i∈I_v[1:k]} a_v[i] ≥ γ}
K_s ← min{k : Σ_{i∈I_s[1:k]} a_s[i] ≥ γ}
# 选择垂直和斜线索引
S_v ← I_v[1 : K_v], S_s ← I_s[1 : K_s]
S ← S_v ∪ S_s
return S
算法 4 查询感知索引搜索
输入: Q, K, γ
# 使用池化后的Q和K计算估计的注意力分数
Q̄ ← pool(Q), K̄ ← pool(K)
Ā ← softmax(Q̄K̄ᵀ / √d)
# 展平并归一化注意力图
Ā ← flatten(Ā / Σ_{i,j} Ā[i,j])
# 对注意力分数进行排序
I_a ← argsort(Ā)
# 获取使分数总和超过γ的最小计算预算
K ← min{k : Σ_{i∈I_a[1:k]} Ā[i] ≥ γ}
# 获得最终索引集
S ← I_a[1:K]
return S
A4 实验环境
模型
实验中使用了四种在处理长上下文任务方面表现出色的先进大语言模型:
* LLaMA: Meta-Llama-3.1-8B-Instruct-128k (【索引17, The llama 3 herd of models, 2024, CoRR】)
* GLM: GLM-4-9B-Chat-1024k (【索引79, Chatglm: A family of large language models from GLM-130B to GLM-4 all tools, 2024, CoRR】)
* Yi: Yi-9B-200K (【索引76, Yi: Open foundation models by http://01.ai, 2024, CoRR】)
* Qwen: Qwen2-7B-Instruct-128k (【索引75, Qwen2 technical report, 2024, CoRR】)
其中,Yi是预训练模型,其他均为指令微调模型,并使用默认的聊天模板。
数据集
在两个为长上下文理解提供独特挑战的数据集上对模型进行评估:
* RULER (【索引26, RULER: what’s the real context size of your long-context language models?, 2024, CoRR】): 一个合成基准数据集,用于评估具有可定制序列长度和任务复杂度的长上下文LLM。它扩展了基本的“大海捞针”测试,并引入了多跳追踪和聚合等新任务类别。
* InfiniteBench (【索引83, ∞bench: Extending long context evaluation beyond 100k tokens, 2024c, CoRR】): 一个旨在测试LLM在广泛上下文中理解长依赖关系的基准数据集,平均令牌数为214k。它包含10个跨越不同领域的合成和真实世界任务。
硬件与软件配置
- 硬件: 实验在配备单张NVIDIA A100 GPU(80GB显存)的计算环境中进行。
- 软件与实现细节:
- 使用PyTorch实现了一个自定义流水线,并基于FlashAttention (【索引10, Flashattention-2: Faster attention with better parallelism and work partitioning, 2024, http://OpenReview.net】)来确保高效的长上下文注意力机制。
- 利用Triton (【索引63, Triton: an intermediate language and compiler for tiled neural network computations, 2019, ACM】)优化GPU加速计算的性能。
- 所有实验中
block_size均设为128。 - 选择最后的
block_size个查询向量作为代表性查询集$\hat{Q}$,用于稀疏模式确定和垂直-斜线稀疏索引选择,这允许代表性注意力分数只计算一次并在两个阶段中复用。 - 稀疏模式阈值$\tau$对所有模型均设为0.1。
- 累积注意力阈值$\gamma$根据模型调整:Yi-9B-200k和Qwen2-7B-Instruct模型设为0.9,其他模型设为0.95。
- 为确保所有注意力头正常工作,保留了每个查询块的第一个和最后一个键块,并要求每个注意力头至少计算1024个令牌。
- 所有实验均使用贪心解码以保持结果的一致性。
基线方法
将本文方法与三个强大的基线进行比较:
- FlashAttention (【索引10, Flashattention-2: Faster attention with better parallelism and work partitioning, 2024, http://OpenReview.net】): 一种高度优化的注意力机制,利用硬件优化和高效的内存访问模式。
- StreamingLLM (【索引70, Efficient streaming language models with attention sinks, 2024b, http://OpenReview.net】): 一种稀疏注意力机制,结合了全局注意力汇(attention sink)令牌和扩张的局部注意力窗口。配置为1000个全局令牌和8000个局部窗口(扩张间隔为1)。
- MInference (【索引28, Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention, 2024a, CoRR】): 一种高效的稀疏注意力方法,采用三种不同的稀疏注意力模式,并根据原始论文推荐设置,使用离线搜索的注意力模式和动态生成的索引。
所有基线方法在预填充阶段使用稀疏计算,在解码阶段切换到密集计算,以确保公平比较。
A4 实验结果
主要结果
RULER基准测试
在RULER基准上的评估结果(见表1)显示,StreamingLLM的性能随着上下文长度的增加而显著下降,而MInference在某些模型上表现次优。相比之下,FlexPrefill在多种上下文长度下都能稳定地保持甚至提升模型的性能,同时还能加速计算过程。
InfiniteBench基准测试
在InfiniteBench挑战上的评估结果(见表2)进一步证实了FlexPrefill的有效性。该方法在检索和问答任务中保留了模型的大部分性能,同时在复杂的数学和编码任务中也保持了效能。
性能与延迟权衡
FlexPrefill利用在线搜索的特性,可以无缝应用于各种模型,并能灵活地在计算速度和性能之间进行调整。通过调整$\gamma$参数,可以平衡模型效果和输出延迟:降低$\gamma$会加速处理,而增加$\gamma$则能保持模型质量。如图4所示,与MInference和固定预算的垂直-斜线注意力相比,FlexPrefill在更低的延迟下实现了更好的性能。
消融实验
固定预算 vs. 动态预算
实验对比了固定和动态计算预算分配。结果(见图4)表明,动态设置能够带来更好的性能,并在推理速度和模型效果之间实现更优的平衡。
查询感知(Query-Aware)头阈值$\tau$
对不同$\tau$值的影响进行了评估。如图5所示,在适当的$\tau$值下启用查询感知注意力头可以在不增加计算量的情况下增强模型性能。然而,如果$\tau$设置得过大,一些注意力估计不准确的头会被错误地识别为查询感知头,可能导致模型性能下降。
最小预算限制
实验分析了每个注意力头所需的最小计算预算。研究发现,设置一个最小预算阈值可以提升性能,并防止具有极高稀疏率的注意力头失效。
其他发现
* 稀疏注意力块大小对模型效果影响不显著(见附录F.2)。
* 代表性查询集$\hat{Q}$的选择对查询感知模式的确定影响很小,但对垂直-斜线头的索引选择影响显著,这证明了使用最后block_size个查询向量作为$\hat{Q}$的合理性(见附录F.3)。
* 查询感知索引搜索的替代实现对模型性能影响可忽略不计(见附录F.4)。
* 限制注意力头的最大计算预算对不同模型的影响各异,表明某些模型的性能在计算量增加到一定程度后会饱和,存在进一步加速的潜力(见附录F.6)。
通过可视化不同层的稀疏模式分布(附录H)、查询感知头和垂直-斜线头的稀疏掩码(附录I)以及不同层和头的稀疏率分布(附录J),进一步分析了我们方法的灵活性和各组件的重要性。
A7 补充细节
长上下文大语言模型
长上下文LLM。持续的工程和算法设计进步不断推动长上下文大语言模型(LLM)的边界,使其能够处理日益复杂的任务。一些方法通过收集大量长文本数据集并持续进行预训练或微调来扩展模型的上下文长度(如【索引20, Data engineering for scaling language models to 128k context, 2024b, http://OpenReview.net】、【索引8, Longlora: Efficient fine-tuning of long-context large language models, 2024, http://OpenReview.net】、【索引71, Effective long-context scaling of foundation models, 2024, Association for Computational Linguistics】)。鉴于许多现代LLM采用基于RoPE的位置嵌入(【索引60, Roformer: Enhanced transformer with rotary position embedding, 2021, CoRR】),研究者引入了各种创新的位置嵌入技术来扩展模型长度,例如【索引51, Yarn: Efficient context window extension of large language models, 2024, http://OpenReview.net】、【索引15, Longrope: Extending LLM context window beyond 2 million tokens, 2024, http://OpenReview.net】和【索引85, Found in the middle: How language models use long contexts better via plug-and-play positional encoding, 2024d, CoRR】。此外,还提出了利用外部存储或检索增强的方法来提升长上下文处理能力(如【索引43, Landmark attention: Random-access infinite context length for transformers, 2023, CoRR】、【索引64, Focused transformer: Contrastive training for context scaling, 2023, URL】、【索引72, Retrieval meets long context large language models, 2024, http://OpenReview.net】)。
LLM推理加速
LLM推理加速。鉴于自注意力机制的时间复杂度与序列长度呈二次方关系,为长上下文加速推理对LLM至关重要。一些策略通过与硬件特性协同的算法来优化原始注意力计算,包括FlashAttention(【索引11, Flashattention: Fast and memory-efficient exact attention with io-awareness, 2022, URL】、【索引10, Flashattention-2: Faster attention with better parallelism and work partitioning, 2024, http://OpenReview.net】、【索引56, Flashattention-3: Fast and accurate attention with asynchrony and low-precision, 2024, CoRR】)和RingAttention(【索引38, Ring attention with blockwise transformers for nearinfinite context, 2023, CoRR】、【索引6, Striped attention: Faster ring attention for causal transformers, 2023, CoRR】)。其他方法通过减少上下文长度来加速注意力计算,如基于检索的技术(【索引43, Landmark attention: Random-access infinite context length for transformers, 2023, CoRR】、【索引66, Augmenting language models with long-term memory, 2023, URL】、【索引72, Retrieval meets long context large language models, 2024, http://OpenReview.net】、【索引44, Leave no context behind: Efficient infinite context transformers with infini-attention, 2024, CoRR】)和基于压缩的方法(【索引34, Compressing context to enhance inference efficiency of large language models, 2023, Association for Computational Linguistics】、【索引29, Longllmlingua: Accelerating and enhancing llms in long context scenarios via prompt compression, 2024b, Association for Computational Linguistics】、【索引48, Llmlingua2: Data distillation for efficient and faithful task-agnostic prompt compression, 2024, Association for Computational Linguistics】)。解决注意力二次方复杂度的策略还包括使用循环机制来聚合信息(【索引82, Soaring from 4k to 400k: Extending llm’s context with activation beacon, 2024b, CoRR】、【索引7, Scaling transformer to 1m tokens and beyond with RMT, 2023, CoRR】、【索引42, ∞-former: Infinite memory transformer, 2022, Association for Computational Linguistics】),使用状态空间模型(【索引23, Efficiently modeling long sequences with structured state spaces, 2022, http://OpenReview.net】、【索引22, Mamba: Linear-time sequence modeling with selective state spaces, 2023, CoRR】、【索引36, Jamba: A hybrid transformer-mamba language model, 2024, CoRR】),以及探索创新的模型架构(【索引61, Retentive network: A successor to transformer for large language models, 2023, CoRR】、【索引50, RWKV: reinventing rnns for the transformer era, 2023, Association for Computational Linguistics】、【索引4, xlstm: Extended long short-term memory, 2024, CoRR】)。鉴于注意力机制的内在稀疏性,许多稀疏注意力方法被提出。许多方法关注固定的注意力模式,如移位稀疏注意力(【索引8, Longlora: Efficient fine-tuning of long-context large language models, 2024, http://OpenReview.net】)、注意力汇(sink attention)(【索引70, Efficient streaming language models with attention sinks, 2024b, http://OpenReview.net】)及其他方法(【索引9, Generating long sequences with sparse transformers, 2019, URL】、【索引5, Longformer: The long-document transformer, 2020, URL】、【索引77, Big bird: Transformers for longer sequences, 2020, URL】、【索引59, Sparsebert: Rethinking the importance analysis in self-attention, 2021, PMLR】、【索引14, Longnet: Scaling transformers to 1, 000, 000, 000 tokens, 2023, CoRR】)。现代LLM如Mistral(【索引27, Mistral 7b, 2023, CoRR】)和Phi-3(【索引1, Phi-3 technical report: A highly capable language model locally on your phone, 2024, CoRR】)也采用固定的稀疏注意力模式,包括滑动窗口注意力。【索引24, Lm-infinite: Zero-shot extreme length generalization for large language models, 2024, Association for Computational Linguistics】和【索引69, Infllm: Unveiling the intrinsic capacity of llms for understanding extremely long sequences with training-free memory, 2024a, CoRR】进一步扩展了注意力汇方法。此外,其他研究,如【索引37, On the expressive power of selfattention matrices, 2021, URL】、【索引39, Transformer acceleration with dynamic sparse attention, 2021, URL】和【索引28, Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention, 2024a, CoRR】的研究,强调了模型中动态稀疏模式的存在。
解码阶段加速
解码阶段,存储过去键和值的KV缓存技术非常普遍,这也为LLM推理加速带来了独特的挑战。一些方法更高效地优化计算过程或KV缓存管理,如FlashDecoding(【索引12, Flash-decoding for longcontext inference, 2023, URL】、【索引25, Flashdecoding++: Faster large language model inference on gpus, 2023, CoRR】)和PagedAttention(【索引31, Efficient memory management for large language model serving with pagedattention, 2023, ACM】)。其他方法旨在通过量化(【索引40, Intactkv: Improving large language model quantization by keeping pivot tokens intact, 2024a, Association for Computational Linguistics】、【索引30, GEAR: an efficient KV cache compression recipe for near-lossless generative inference of LLM, 2024, CoRR】、【索引41, KIVI: A tuning-free asymmetric 2bit quantization for KV cache, 2024b, http://OpenReview.net】)、令牌合并(【索引45, Dynamic memory compression: Retrofitting llms for accelerated inference, 2024, http://OpenReview.net】、【索引35, Snapkv: LLM knows what you are looking for before generation, 2024, CoRR】)或KV丢弃(【索引84, H2O: heavy-hitter oracle for efficient generative inference of large language models, 2023, URL】、【索引16, Get more with LESS: synthesizing recurrence with KV cache compression for efficient LLM inference, 2024, http://OpenReview.net】)来减小KV缓存大小。此外,稀疏注意力方法也广泛应用于解码阶段。一些技术保留整个KV缓存,但仅使用一个子集进行稀疏注意力计算(【索引54, Sparq attention: Bandwidth-efficient LLM inference, 2024, http://OpenReview.net】、【索引62, QUEST: query-aware sparsity for efficient long-context LLM inference, 2024, http://OpenReview.net】、【索引81, Pqcache: Product quantization-based kvcache for long context LLM inference, 2024a, CoRR】),而其他方法则通过特定的稀疏注意力模式来减小KV缓存(【索引21, Model tells you what to discard: Adaptive KV cache compression for llms, 2024, http://OpenReview.net】、【索引70, Efficient streaming language models with attention sinks, 2024b, http://OpenReview.net】、【索引84, H2O: heavy-hitter oracle for efficient generative inference of large language models, 2023, URL】)。
A5 结论
本文介绍了FlexPrefill,一种新颖的、灵活的稀疏注意力机制,旨在为LLM中的长序列预填充(pre-filling)提供实时自适应能力。我们的方法由查询感知稀疏模式确定和基于累积注意力的索引选择两个核心组件构成,通过根据输入动态优化每个注意力头的稀疏模式和稀疏率,解决了以往方法的局限性。在多个前沿LLM和具有挑战性的长上下文基准上的广泛实验表明,FlexPrefill在显著提高计算效率的同时,能够持续保持甚至增强模型性能。我们方法的自适应特性使得在速度和准确性之间取得了更好的平衡,在各种场景下都优于先前的方法。随着LLM不断发展并处理日益复杂的长上下文任务,FlexPrefill为在有效管理计算资源的同时保持高性能提供了一个有前景的解决方案。未来的工作可以探索这种自适应方法在不同模型架构中的进一步优化和应用。
A6 附录
A 理论依据
优化目标的转换。在实践中,我们可以通过为所选子集$|S|$的大小设置一个容忍率$\gamma_S$来实现多目标优化(公式(1))。在此约束下,目标是最小化注意力差异$|A(Q, K, V) - A(Q, K, V, S)|$,以在降低计算成本的同时保持稀疏注意力机制的有效性。这可以表述为以下约束优化问题:
$$\begin{aligned} \begin{aligned} \min_{\boldsymbol{S}} \quad & |\boldsymbol{A}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}) - \boldsymbol{A}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}, \boldsymbol{S})| \\ \text{subject to} \quad & |\boldsymbol{S}| \le \gamma_S \\ & \boldsymbol{S} \subseteq \{(i, j) | 1 \le j \le i, i, j \in \mathbb{Z}\} \end{aligned} \end{aligned}$$从误差上界到优化目标。我们现在更深入地解释优化稀疏注意力目标(公式3)如何实现目标(公式4)。为简洁起见,我们关注单个查询位置的优化问题,省略下标$i$。该查询位置的标准注意力计算为:$A(Q, K, V) = \sum_{j=1}^L \frac{e^{x_j}}{\sum_{j'=1}^L e^{x_{j'}}} v_j$;稀疏注意力计算为:$A(Q, K, V, S) = \sum_{j \in S} \frac{e^{x_j}}{\sum_{j' \in S} e^{x_{j'}}} v_j$。设注意力分数为$a_j = \frac{e^{x_j}}{\sum_{j'=1}^L e^{x_{j'}}}$。全注意力与稀疏注意力计算之间的差异可以表示为:
$$\begin{aligned} \begin{aligned} \boldsymbol{A}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V})-\boldsymbol{A}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}, \boldsymbol{S}) & =\sum_{j=1}^{L} \frac{e^{x_{j}} v_{j}}{\sum_{j^{\prime}=1}^{L} e^{x_{j^{\prime}}}}-\sum_{j \in S} \frac{e^{x_{j}} v_{j}}{\sum_{j^{\prime} \in S} e^{x_{j^{\prime}}}} \\ & =\sum_{j=1}^{L} a_{j} v_{j}-\frac{1}{\sum_{j \in S} a_{j}} \sum_{j^{\prime} \in S} a_{j^{\prime}} v_{j^{\prime}} \\ & =\sum_{j \notin S} a_{j} v_{j}-\left(\frac{1}{\sum_{j \in S} a_{j}}-1\right) \sum_{j^{\prime} \in S} a_{j^{\prime}} v_{j^{\prime}} \end{aligned} \end{aligned}$$令$a_S = \sum_{j \in S} a_j$。那么,对于$k \in S$有$a_k \le a_S$,对于$k \notin S$有$a_k \le 1 - a_S$。将这些界限代入公式5,我们得到全注意力与稀疏注意力之间绝对差异的一个上界:
$$\begin{aligned} \begin{aligned} \boldsymbol{A}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V})-\boldsymbol{A}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}, \boldsymbol{S}) & \leq \sum_{j \notin S} a_{j}\left|v_{j}\right|+\left(\frac{1}{a_{S}}-1\right) \sum_{j \in S} a_{j}\left|v_{j}\right| \\ & \leq\left(1-a_{S}\right) \sum_{j \notin S}\left|v_{j}\right|+\left(\frac{1}{a_{S}}-1\right) a_{S} \sum_{j \in S}\left|v_{j}\right| \\ & =\left(1-a_{S}\right) \sum_{j=1}^{L}\left|v_{j}\right| \end{aligned} \end{aligned}$$原始问题与对偶问题。为了最小化误差的上界,我们可以在子集大小的约束下,最大化所选子集中归一化注意力分数$a_S$的总和:
$$\max_{S \subseteq[L]} \sum_{j \in S} \frac{\exp \left(x_{j}\right)}{\sum_{j^{\prime}=1}^{L} \exp \left(x_{j^{\prime}}^{\prime}\right)}, \quad \text { subject to } \quad|S| \leq \gamma_{S}$$通过应用对偶原理(详细推导见附录B),我们可以将原始目标(公式(7))转换为其对偶形式,即在归一化注意力分数总和的约束下最小化子集大小。给定一个容忍度$0 < \gamma_a < 1$,单个查询位置的对偶优化目标变为:
$$\min_{S \subseteq[L]}|S|, \quad \text { subject to } \quad \sum_{j \in S} \frac{\exp \left(x_{j}\right)}{\sum_{j^{\prime}=1}^{L} \exp \left(x_{j^{\prime}}\right)} \geq \gamma_{a}$$通过聚合所有查询位置$i \in [n]$的目标(公式(8)),即可得到所有查询位置的通用目标,如公式(3)所示。
B 对偶形式:平衡子集大小与注意力分数
推导过程。为了更好地理解原始目标与其对偶形式之间的关系,我们逐步推导对偶目标。为简洁起见,我们关注单个查询位置的优化问题,并省略下标$i$。
-
步骤1:二元变量。如果$j \in S$,则令$z_j=1$,否则$z_j=0$。原始目标是:
$$\max_{z_{j} \in\{0,1\}} \sum_{j=1}^{L} z_{j} \frac{\exp \left(x_{j}\right)}{\sum_{j^{\prime}=1}^{L} \exp \left(x_{j^{\prime}}\right)}, \quad \text { subject to } \quad \sum_{j=1}^{L} z_{j} \leq \gamma_{S}$$
-
步骤2:松弛二元变量。将变量松弛为连续变量:
$$\max_{0 \leq z_j \leq 1} \sum_{j=1}^{L} z_j \frac{\exp(x_j)}{\sum_{j'=1}^{L} \exp(x_{j'})}, \quad \text{subject to} \quad \sum_{j=1}^{L} z_j \leq \gamma_S$$
-
步骤3:拉格朗日公式。我们为原始问题构建拉格朗日函数:
$$L(z, \lambda, \mu, \nu)=-\sum_{j=1}^{L} z_{j} \frac{\exp \left(x_{j}\right)}{\sum_{j^{\prime}=1}^{L} \exp \left(x_{j^{\prime}}\right)}+\lambda\left(\sum_{j=1}^{L} z_{j}-\gamma_{S}\right)-\sum_{j=1}^{L} \mu_{j} z_{j}+\sum_{j=1}^{L} \nu_{j}\left(z_{j}-1\right)$$
其中,$\lambda \ge 0$是大小约束的乘子,$\mu_j \ge 0$是$z_j \ge 0$的乘子,$\nu_j \ge 0$是$z_j \le 1$的乘子。 -
步骤4:KKT条件。KKT条件给出:
- 平稳性:
$$\frac{\partial L}{\partial z_j} = -\frac{\exp(x_j)}{\sum_{j'=1}^{L} \exp(x_{j'})} + \lambda - \mu_j + \nu_j = 0$$ - 互补松弛性:$$\lambda(\sum_{j=1}^{L} z_{j}-\gamma_{S})=0$$$$\mu_{j} z_{j}=0$$$$\nu_{j}(z_{j}-1)=0$$
- 平稳性:
-
步骤5:最优解结构。从平稳性条件得出:
$$\begin{aligned} z_{j}=\begin{cases}1 & \text { if } \frac{\exp \left(x_{j}\right)}{\sum_{j^{\prime}=1}^{L} \exp \left(x_{j^{\prime}}\right)}>\lambda \\ 0 & \text { if } \frac{\exp \left(x_{j}\right)}{\sum_{j^{\prime}=1}^{L} \exp \left(x_{j^{\prime}}\right)}<\lambda\end{cases} \end{aligned}$$ -
步骤6:对偶性。对偶问题变为找到最小的$|S|$(或等价地$\sum_{j=1}^L z_j$)以达到所需的注意力质量$\gamma_a$。这给了我们:
$$\min_{0\le z_j \le 1} \sum_{j=1}^{L} z_j \quad \text{subject to } \sum_{j=1}^{L} z_j \frac{\exp(x_j)}{\sum_{j'=1}^{L} \exp(x_{j'})} \ge \gamma_a$$
对偶转换的理解。从原始形式到对偶形式的转换自然地源于拉格朗日函数的最优性条件。与大小约束相关的拉格朗日乘子$\lambda$实际上成为了注意力分数的一个阈值。在最优状态下,当且仅当一个位置的归一化注意力分数超过$\lambda$时,该位置才被选择(即$z_j=1$)。这自然地导出了对偶公式,其中$\gamma_a$对应于原始问题中实现的最优注意力质量。由于原始问题的凸松弛确保了零对偶间隙,因此等价性由强对偶性保证。对偶目标在确保所选子集$S$中归一化注意力分数之和大于或等于阈值$\gamma_a$的同时,最小化了该子集的大小。这与公式(3)中定义的目标完全相同。通过推导对偶形式并证明其与我们的主要目标等价,我们在公式(7)和公式(3)的表述之间建立了强有力的联系。
C 不同方法的详细延迟
延迟对比。为了突显我们方法的效率和性能优势,我们展示了在RULER数据集上的详细注意力计算延迟比较。具体来说,我们评估了长序列(64k和128k令牌)的单次注意力函数调用的平均延迟,以及所有序列长度的总体平均延迟。如表3所示,我们的FlexPrefill方法在保持较低延迟的同时实现了更好的性能,从而优于竞争方法。
Table 3: 不同方法在RULER数据集上各种模型和序列长度的延迟比较。
D 与其他基线的比较
扩展对比。为了进一步评估我们提出方法的效率,我们提供了与其他基线的性能比较。具体来说,我们使用Llama-3-8B-Instruct-262k(【索引49, Llama 3 gradient: A series of long context models, 2024, URL】)模型评估了LM-Infinite(【索引24, Lm-infinite: Zero-shot extreme length generalization for large language models, 2024, Association for Computational Linguistics】)、InfLLM(【索引69, Infllm: Unveiling the intrinsic capacity of llms for understanding extremely long sequences with training-free memory, 2024a, CoRR】)和MoA(【索引18, Moa: Mixture of sparse attention for automatic large language model compression, 2024a, CoRR】)。此外,我们使用Meta-Llama-3.1-8B-Instruct-128k(【索引17, The llama 3 herd of models, 2024, CoRR】)模型评估了HIP(【索引32, Hip attention: Sparse sub-quadratic attention with hierarchical attention pruning, 2024, CoRR】)。所有评估均在RULER数据集上进行。如表4所示,我们的方法在性能和推理速度上都持续优于这些基线。
Table 4: 在RULER数据集上,其他方法在各种模型和序列长度上的性能比较。
E 不同$\gamma$值下的性能-延迟权衡
权衡分析。在图4中,我们绘制了模型性能与单个注意力头平均预填充时间的关系图。全注意力模型的平均延迟报告如下:4.5毫秒(所有输入长度的平均值),20.58毫秒(输入长度为128k时),以及1.19毫秒(输入长度为32k时)。我们提出的方法根据参数$\gamma$表现出不同的延迟,同时保持了可比的性能。我们为不同的$\gamma$值提供了全面的比较,以便于经验性地选择。如表5所示,降低$\gamma$会导致更快的速度,但性能上会有轻微的权衡。在实践中,可以灵活调整$\gamma$以满足特定的性能或速度要求。
图4:我们的方法与MInference和固定预算的垂直-斜线注意力的比较,显示了模型性能和注意力延迟之间的权衡。在不同的注意力延迟下,我们的方法始终优于MInference和固定预算的垂直-斜线方法。更多细节见附录E。
Table 5: Llama-3.1-8B模型在不同γ值下的性能比较。
F 补充消融研究
F.1 查询感知头阈值
阈值$\tau$的影响。我们评估了不同$\tau$对模型性能的影响,并在表6中提供了模型在不同上下文长度下的详细性能。
Table 6: 在RULER数据集上,不同模型使用不同查询感知头阈值τ的性能比较。
F.2 Triton块大小
块大小的影响。我们探讨了调整Triton块大小的影响,具体比较了64和128的块大小。表7中的结果显示,块大小对模型性能没有显著影响,因此可以根据不同的硬件灵活选择不同的块大小。
Table 7: 在RULER数据集上,不同模型使用不同Triton block_size的性能比较。
F.3 代表性查询子集
查询子集位置的影响。我们对方法中使用的代表性查询子集进行了消融实验。我们将稀疏模式确定和垂直-斜线索引选择中使用的子集位置从序列末尾替换到中间。表8中的结果显示,替换用于稀疏模式确定的子集对模型性能没有显著影响。然而,替换用于垂直-斜线稀疏索引选择的子集会显著降低性能。这证明了选择最后block_size个查询向量作为代表性查询子集的合理性。
Table 8: Llama-3.1-8B-Instruct模型在RULER数据集上使用不同代表性查询集的性能比较,其中last表示使用最后的block_size个查询向量,middle表示使用中间的block_size个查询向量。
F.4 查询感知索引搜索的替代实现
实现方式对比。在算法4中,我们对估计的注意力图进行展平(flatten)和归一化,然后确定使分数总和超过给定阈值$\gamma$所需的最小计算预算。我们还探索了一种替代实现,该实现执行逐查询的索引选择。在这种方法中,每个查询块所选键块的累积注意力分数都超过$\gamma$。如表9所示,这种替代方法取得了可比的性能。我们选择带有展平操作的全局方法主要是因为它实现效率高,简化了注意力机制,同时保持了性能上的对等。
Table 9: 查询感知索引搜索不同实现的性能比较。
F.5 最小预算限制
最小预算的影响。我们进行了消融实验来分析最小计算预算的影响。表10表明,当$\gamma$较高时,模型能有效捕捉大部分重要令牌,使得最小预算限制变得不必要。然而,当$\gamma$较小时,一些注意力头选择的令牌数量可能不足,导致它们功能失常。在这种情况下,实施一个最小预算阈值可以显著提升模型性能。此外,这个最小预算限制不需要过高,因为进一步增加预算并不会带来更好的结果。
Table 10: Llama-3.1-8B-Instruct和GLM-4-9B-Chat模型在RULER数据集上使用不同最小预算的性能比较。
F.6 最大预算限制
最大预算的影响。我们进行了消融实验,分析每个注意力头的最大计算预算对模型性能和延迟的影响。表11显示,限制最大计算预算对不同模型的影响各不相同。对LLaMA等模型施加最大预算约束会导致负面影响。相反,对于GLM等模型,这样的约束能维持甚至提升性能。这些发现表明,不同的模型为处理不同长度的上下文进行了优化。
Table 11: 不同模型在RULER数据集上使用不同最大预算的性能比较。
G 延迟分解
计算复杂度分析。FlexPrefill的计算复杂度包括以下几个部分:
* 代表性注意力分数计算:大约为$O(bnd)$,其中$b$是块大小,$d$是隐藏维度,$n$是序列长度。
* 模式搜索:大约为$O(bn)$,包括分块池化注意力分数估计和JS散度计算。
* 稀疏索引构建:大约为$O(n \log n)$,用于排序代表性注意力分数。
* 稀疏注意力计算:大约为$O(\alpha n^2 d)$,其中$\alpha$是稀疏因子。
相比之下,标准密集注意力的计算复杂度为$O(n^2 d)$。FlexPrefill引入的开销(大约$O(\alpha n^2 d) + O(n \log n) + O(bnd)$)被稀疏性带来的计算节省显著抵消。
实际延迟测量。图6展示了FlexPrefill各组件的实际延迟测量,包括代表性注意力分数计算、注意力模式搜索、稀疏索引构建和稀疏注意力计算。在较短的输入长度下,非注意力计算的开销较高。随着输入长度增加,索引搜索构建时间增长,但其所占百分比逐渐下降。
图6:不同上下文长度下的稀疏注意力延迟分解比较。该图显示了不同组件(稀疏注意力、索引搜索、模式搜索和代表性注意力)对不同输入长度下总延迟的贡献。随着输入长度的增加,稀疏注意力计算所花费的时间比例增加,而其他组件的相对贡献减少。
H 稀疏模式分布
模式分布分析。我们分析了Jensen-Shannon距离和稀疏注意力模式分布,重点关注查询感知头的配置。图7表明,分块注意力分数估计随任务类型和上下文长度而变化。图8显示,大多数注意力头使用垂直-斜线模式,而较少的查询感知模式主要出现在模型的第一层。这些消融研究为我们方法的组件和配置提供了深入见解,使得能够做出明智决策以优化稀疏注意力机制的性能和效率。
图7:比较不同注意力头和层之间稀疏注意力模式分布的Jensen-Shannon(JS)距离热图。比较显示了不同上下文长度(128k vs. 32k)和任务类型(任务A vs. 任务B)的情况。颜色越深表示JS距离越低,表明查询感知模式的注意力估计更准确。
图8:不同上下文长度(128k vs. 32k)和任务类型(任务A vs. 任务B)下,各层之间不同注意力模式数量分布的比较。
I 稀疏注意力掩码
动态掩码可视化。我们提出的算法为不同注意力头搜索的稀疏模式是高度动态的,我们在Llama-3.1-8B-Instruct模型上对其进行了可视化。图9展示了典型的垂直-斜线和查询感知注意力头,并表明不同注意力头所需的稀疏率存在巨大差异。
图9:Llama-3.1-8B-Instruct模型中不同注意力头的稀疏掩码可视化。(a) 显示了垂直-斜线头的稀疏掩码。(b) 显示了查询感知头的稀疏掩码,其中有许多偏离特定模式的多样化块。(c) 显示了查询感知头仍可能表现出垂直-斜线模式。
J 稀疏率
稀疏率可视化。我们可视化了Llama-3.1-8B-Instruct模型中不同样本在各种注意力头上的稀疏率。图10显示,不同难度的样本需要不同的稀疏率,并且在注意力头之间的稀疏分布不一致。此外,不同的输入长度表现出不同的稀疏率,较长的输入显示出更高的稀疏率。
图10:Llama-3.1-8B-Instruct模型中不同注意力头的稀疏率可视化。热图显示了不同样本类型(任务A vs. 任务B)和上下文长度(64k vs. 256k)下不同的稀疏度分布。颜色越深表示稀疏度越低。较长的输入(c, d)与较短的输入(a, b)相比,表现出更高的总体稀疏率。
💬 评论讨论
欢迎在这里分享您的想法和见解!