MOMENTKV: Closing the Directional Gap in KV Cache Eviction for Long-Context Inference
MOMENTKV: Closing the Directional Gap in KV Cache Eviction for Long-Context Inference
发表时间: 2026-06 · arXiv:2606.01563
作者/机构: Yu Li1, Binxu Li2, Tian Lan1 (1George Washington University, 2Princeton University)
A1 主要贡献
大型语言模型的自回归生成依赖于KV缓存来存储和重用先前计算出的键(key)和值(value)投影。然而,随着序列长度的增加,KV缓存的内存占用呈线性增长,成为长上下文推理的主要瓶颈。KV缓存驱逐(eviction)通过保留固定大小的键值对子集并丢弃其余部分来解决此问题。
本文识别出,现有驱逐方法导致输出质量下降的主要原因,并非已被现有方法最小化的、被驱逐token上的残余注意力权重,而在于保留token集与被驱逐token集之间的方向不匹配。具体来说,被驱逐的token在实践中通常与被保留的token近乎正交。因此,即使是很小的被驱逐注意力权重,也可能对最终的注意力输出方向分布产生不成比例的巨大影响,并放大为显著的输出错误。这揭示了现有策略的一个根本性限制。
为解决此问题,本文提出了MOMENTKV,该方法对被驱逐的token集维护一组紧凑、小尺寸的矩统计量(moment statistics),包括计数、键均值、值均值以及值-键协方差。
- 在驱逐阶段,这些矩统计量被用来识别那些其值向量已经能够被累积的摘要信息很好地对齐和捕捉的token,从而使被驱逐的集合在几何上保持规整。
- 在推理阶段,这些统计量可以给出一个关于被驱逐部分注意力输出的闭式一阶近似,从而实现对注意力输出的校正。
这种选择性驱逐和精确校正之间形成了一个相互强化的循环。如图1(a)所示,被驱逐的token会增量式地更新矩统计量,这些统计量反过来指导驱逐决策,并用于近似被驱逐的注意力输出以进行校正。图1(b)的PCA分析显示,在一个注意力头中,保留子集和驱逐子集的输出向量近乎正交(余弦相似度约0.10),这导致即使很小的被驱逐注意力权重也会产生巨大的方向性误差,而标准的重归一化无法恢复。
(a) MOMENTKV概览。被驱逐的token增量式地更新紧凑的矩统计量(计数、键均值、值均值和值-键协方差),这些统计量指导驱逐过程倾向于选择那些已经被摘要信息很好捕捉的token,并用于近似被驱逐的注意力输出以进行后驱逐校正。
(b) 一个注意力头中值向量的PCA分析(LLaMA-3-8B, H2O, N=4096, L=128)。保留子输出和驱逐子输出近乎正交(cos θ ≈ 0.10, 84°),导致即使很小的被驱逐注意力权重也会产生巨大的方向性误差,而标准重归一化无法恢复。
图 1: MOMENTKV 维护被驱逐 token 的矩统计量,以共同改善驱逐决策和校正后驱逐注意力输出。
主要贡献:
1. 识别了方向性差距问题:揭示了现有KV缓存驱逐方法中,保留集和驱逐集的注意力子输出之间的方向性不匹配(近乎正交)是导致性能下降的关键因素,而非被驱逐的注意力权重大小本身。
2. 提出了MOMENTKV方法:该方法通过维护关于被驱逐token的紧凑矩统计量(计数、键/值均值、值-键协方差),实现了对被驱逐信息的紧凑恢复。
3. 设计了双重机制:
* 矩感知驱逐(Moment-Informed Eviction):利用矩统计量计算每个token的“矩残差”,优先驱逐那些已经被摘要信息充分代表的token,从而使驱逐集在几何上更规整。
* 归一化校正推理(Normalization-Corrected Inference):利用矩统计量生成一个关于被驱逐部分注意力输出的闭式一阶近似,并用它来校正标准的、仅基于保留集的注意力输出。
- 实现了增强回路:证明了矩感知驱逐和归一化校正推理形成了一个相互增强的循环,前者通过规整化驱逐集来提高后者近似的准确性,后者则为前者提供更可靠的残差计算。
- 验证了有效性:在LongBench和RULER基准上,使用LLaMA-3.1-8B-Instruct和Qwen3-4B-Instruct模型进行的实验表明,MOMENTKV在所有缓存预算下均优于所有基线方法,尤其在激进压缩(小缓存预算)下增益最大。
A3 背景知识与关键观察
2. 预备知识
背景知识。KV缓存的推理过程分为两个阶段。在预填充(prefill)阶段,模型并行处理整个提示(prompt),并为所有输入token填充KV缓存。在解码(decoding)阶段,每个新生成的token会向缓存中追加一个键值对,并对所有缓存条目执行注意力计算以生成下一个token。随着解码的进行,每一步、每一层缓存都会增加一个条目,导致每一步的注意力计算成本相应增加。
KV缓存驱逐。KV缓存驱逐将缓存减少到固定的L个条目。这可以是一次性在预填充后完成,也可以是在解码过程中,当缓存大小超过L时,持续移除得分最低的条目。无论哪种模式,一个评分函数会为每个缓存的token分配一个重要性值,得分最低的token将被永久移除。现有的评分函数主要分为两类:滑动窗口方法保留一组初始的“sink” token和最近的一个窗口内的token,丢弃所有中间条目而不评估其内容;Top-k选择方法则根据累积注意力权重、预填充阶段的观察窗口、层或头之间的自适应预算分配,或考虑值向量几何属性的价值感知标准,对所有缓存的token进行排序。尽管评分方法不同,但所有方法都采用相同的后驱逐推理规则:注意力计算仅在保留的集合上进行重归一化,被驱逐的条目被永久删除。
符号与公式。为了形式化驱逐对注意力输出的影响,考虑一个单注意力头,其头部维度为d。在某个解码步骤,缓存中有N个键值对。给定当前查询$q \in \mathbb{R}^d$,缓存的键$\{k_i\}_{i=1}^N$,以及缓存的值$\{v_i\}_{i=1}^N$,注意力logits为$s_i = q^\top k_i / \sqrt{d}$,配分函数为$Z = \sum_{i=1}^N \exp(s_i)$。注意力权重$\alpha_i = \exp(s_i)/Z$产生完整的注意力输出,该输出是缓存值向量在注意力分布下的质心:
驱逐后的分解。驱逐后,令$R \subset [N]$表示大小为L的保留集,$E = [N] \setminus R$表示被驱逐集。将求和分解到R和E上,并用各自的配分函数$Z_R = \sum_{i \in R} \exp(s_i)$和$Z_E = \sum_{i \in E} \exp(s_i)$进行归一化,得到两个子输出:
子输出的几何意义。每个子输出是其对应token子集的值向量的注意力加权质心。当保留和被驱逐的token对应的值向量位于$\mathbb{R}^d$的不同区域时,这两个质心$f_R(q)$和$f_E(q)$可能指向截然不同的方向,这一性质对第3节的误差分析至关重要。
完整输出的精确分解。由于$Z = Z_R + Z_E$,完整的注意力输出可以精确地分解为以下凸组合:
其中$w_R = Z_R/Z$和$w_E = Z_E/Z$是混合权重,且$w_R + w_E = 1$。这些权重对应于分配给每个子集的总注意力权重的比例。标准的驱逐方法仅输出$f_R(q)$,这等同于设置$w_E=0$并将被驱逐的注意力权重完全重新分配给保留的token。这个分解将驱逐误差分为了一个控制丢失分量大小的标量因子$w_E$和一个控制其方向的向量值因子$f_E(q) - f_R(q)$。
3. Token驱逐的方向性成本
驱逐误差的分解。根据公式(3),完整注意力输出为$f(q) = w_R f_R(q) + w_E f_E(q)$。标准驱逐丢弃了被驱逐的分量,只返回$f_R(q)$。从完整输出中减去它,并使用$w_R = 1 - w_E$,得到的驱逐误差为:
这个误差分解为一个标量权重项$w_E \in [0, 1]$和一个向量值散度项$f_E(q) - f_R(q) \in \mathbb{R}^d$。权重项衡量了因驱逐而损失的注意力权重,而散度项衡量了两个子集在表示上下文方面的差异程度。
图 2: 在 LLaMA-3-8B 上使用 H2O 选择策略(N=4096)的驱逐误差分析。(a, e) 被驱逐的权重 $w_E$ 很小,在 L=128 时平均低于 0.08,并且随着预算的增加而减小。(b, c) 尽管权重很小,保留子输出和驱逐子输出在所有层中仍然近乎正交,cos($f_E$, $f_R$) ≈ 0。(d, f) 近乎正交性将残余权重放大到平均超过 20% 的相对误差,证实了方向性差距而非被驱逐的权重是主要的误差来源。
误差的平行与垂直分量。为了揭示这两个因素的不同作用,设$\theta = \angle(f_E, f_R)$,并将散度分解为平行于$f_R$和垂直于$f_R$的分量:
其中$\hat{n}$是在$f_E$和$f_R$所张成的平面上与$f_R$正交的单位向量。平行分量沿着$f_R$方向重新缩放输出,重归一化可以部分补偿它。然而,垂直分量将输出移向一个完全在$f_R$张成空间之外的方向。由于标准驱逐完全在保留集上操作,它没有任何机制来恢复这种正交位移,使得垂直误差成为仅重归一化范式下驱逐的不可约成本。
误差的平方分解。将上述分解代入公式(4)并利用两个分量的正交性,平方误差可以分解为:
平行误差在$\|f_E\| \cos\theta = \|f_R\|$时消失,但垂直误差仅在$\sin\theta = 0$时消失,即当两个子输出完全对齐时。通过$\|f\|$进行归一化,相对误差满足:
其中,放大系数$\gamma(\theta)$依赖于角度和子输出的范数,但与$w_E$无关。这种乘法结构揭示了一个关键洞见:即使$w_E$很小,一个大的$\gamma(\theta)$也能将残余权重放大为显著的输出误差。我们在LLaMA-3-8B上使用H2O选择策略(N=4096)对这两个因素进行了实证检验,结果总结在图2中。
现有选择策略的局限。现有选择策略控制$w_E$但不能控制$\gamma(\theta)$。图2(a)显示,在L=128时,大多数层的$w_E$低于0.05,全网络平均低于0.08。图2(e)确认$w_E$随着预算增加和上下文变短而单调减小。然而,角度$\theta$始终很大。图2(b,c)显示,$\cos\theta$在所有层中都在零附近振荡,且$1 - \cos\theta$在整个网络中都饱和在1.0附近,证实了$\theta \approx \pi/2$。
近正交性导致乘法放大。当$\theta = \pi/2$时,公式(6)中的平行项简化为$\|f_R\|^2$,垂直项等于$\|f_E\|^2$,得到$\gamma(\pi/2) = \sqrt{\|f_E\|^2 + \|f_R\|^2} / \|f\|$,这个值仍然是1的量级,因为两个子输出都是有界值向量的注意力加权平均。图2(d)凭经验证实了这种放大效应:全网络平均低于0.08的权重产生了平均超过20%的相对误差,在最浅的层中峰值超过60%。图2(f)显示这种模式在不同预算下都存在,即使在L=1024时,浅层仍然会产生超过10%的误差。
更好的选择策略加剧了方向性差距。近正交性的产生是因为top-k选择将保留集集中在那些其键与近期查询方向对齐的token上,而被驱逐集则累积了跨越更广泛、互补子空间的token。改进选择策略只会放大这种不对称性:一个更具选择性的保留集会缩小$f_R$的方向跨度,同时将$f_E$推向一个更加互补的子空间,这导致即使$w_E$减小,$\gamma(\theta)$反而会增加。因此,这两个误差因素在结构上是相互对立的,仅仅减少$w_E$会面临收益递减。解决这个瓶颈需要近似$f_E$的垂直分量并在推理时恢复它。
A2 方法细节
4. MOMENTKV:对被驱逐信息的紧凑恢复
4.1 通过中心化Softmax展开实现一阶近似
中心化Softmax展开。回顾被驱逐的注意力输出$f_E(q)$是驱逐集上值的softmax加权平均。由于原始的logits可能很大,我们围绕平均被驱逐键 $\bar{k} = n_e^{-1} \sum_{i \in E} k_i$ 进行中心化,并定义中心化logits $\delta_i = q^\top(k_i - \bar{k}) / \sqrt{d}$。根据构造,$\sum_{i \in E} \delta_i = 0$,并且其量级通常小得多。由于softmax具有平移不变性,记$s_i = \bar{s} + \delta_i$,其中$\bar{s} = q^\top \bar{k} / \sqrt{d}$,公共因子$\exp(\bar{s})$在比率中被消去:
这使得$f_E$完全用中心化logits表示。
一阶泰勒展开。应用一阶泰勒展开$\exp(\delta_i) \approx 1 + \delta_i$,$\sum_{i \in E} \delta_i = 0$的零和性质使得分母塌缩为$n_e$。在分子中,常数项产生平均被驱逐值$\bar{v} = n_e^{-1} \sum_{i \in E} v_i$,线性项产生一个涉及经验值-键协方差$\tilde{S} = \sum_{i \in E} v_i(k_i - \bar{k})^\top \in \mathbb{R}^{d \times d}$的矩阵向量乘积。除以$n_e$得到闭式近似:
近似的解释。第一项$\bar{v}$是一个与查询无关的基线,它恢复了被驱逐子输出的平均方向,从而减小了校正后输出与完整输出之间的角度$\theta$(见公式(5))。第二项$\tilde{S}q / (n_e \sqrt{d})$是一个查询自适应的校正,它捕捉了基线本身无法表示的垂直方向上的变化,直接缩小了公式(6)中的$\|f_E\| \sin\theta$因子。这是因为不同的查询会激活被驱逐键的不同子集;协方差矩阵编码了键和值之间按方向的这些相关性,使得近似能够转向与每个查询最相关的值向量。该近似的误差尺度为$O(\sigma^2)$,其中$\sigma = \max_{i \in E} |\delta_i|$,正式的界限在附录A中给出。
图 3: 在 LLaMA-3-8B 上使用 Qasper 数据集(L=128)的经验验证。(a) 单样本驱逐损失:H2O vs MOMENTKV。(b) 分层误差和降低百分比。(c) $f_E$ 和 $\hat{f}_E$ 之间的余弦相似度。
高效的统计量更新。至关重要的是,公式(9)只依赖于四个运行中的和:计数$n_e$、键的和$s_k = \sum_{i \in E} k_i$、值的和$s_v = \sum_{i \in E} v_i$以及外积的和$S = \sum_{i \in E} v_i k_i^\top$。每个统计量在驱逐时通过一次加法更新,无需存储或重新访问被驱逐的token。在查询时,均值和协方差通过恒等式$\tilde{S} = S - s_v s_k^\top / n_e$恢复,每个头的总存储为$O(d^2)$,与上下文长度无关。
4.2 矩感知驱逐与归一化校正推理
矩残差。标准驱逐仅根据注意力权重对token进行排序,忽略了其值的方向性内容。公式(9)的矩统计量提供了一种原则性的方法,来区分那些其值已被摘要信息很好地捕捉的token和那些携带新方向信息的token。对于每个保留的token j,我们计算一个矩残差:
该残差衡量了其真实值与在其自身键上评估的仿射模型预测之间的差异。范数$\|r_j\|$量化了token j携带的、超出已累积统计信息的新信息量:一个小的残差表明该token的值位于由$\bar{v}$和$\tilde{S}$的列空间张成的仿射子空间附近,而一个大的残差则表明其方向性内容若被驱逐将会丢失。
矩感知驱逐评分。驱逐分数将注意力权重与残差范数结合起来,即$score(j) = \alpha_j \cdot \|r_j\|$,得分最低的token被首先驱逐。这种乘法形式确保了高注意力权重的token即使其残差很小也会被保留,因为它对$f_R$贡献显著;而大残差的token即使其注意力权重很低也会被保留,因为驱逐它会降低矩近似的质量。因此,被驱逐的集合逐渐累积那些能被矩模型很好预测的token,这直接抑制了中心化logit的分布范围$\sigma$,并收紧了$O(\sigma^2)$的近似界。
归一化校正推理。转向推理阶段,我们将公式(9)中的$\hat{f}_E$代入公式(3)的分解中,以恢复对完整注意力输出的估计。保留部分的配分函数$Z_R$可从标准注意力计算中获得。对于被驱逐部分的配分函数,我们对凸函数指数函数应用Jensen不等式:
其中最后一个等式使用了零和性质$\sum_{i \in E} \delta_i = 0$。这产生了一个下界$\hat{Z}_E = n_e \cdot \exp(q^\top \bar{k} / \sqrt{d}) \le Z_E$,其相对误差为$O(\sigma^2)$,由logit方差控制(见附录B)。估计了两个分量后,校正后的输出形式为:
自调节偏置。由于$\hat{Z}_E$低估了$Z_E$,权重$\hat{w}_R$会略微超过$w_R$,从而将更多权重放在精确的保留集输出上。这种偏置是自调节的:当$\sigma$较大时,Jensen不等式的差距变大,$\hat{w}_E$缩小,限制了不太准确的近似的影响;当$\sigma \to 0$时,权重收敛到它们的真实值。与设置$w_E = 0$并产生$w_E \cdot \gamma(\theta)$相对误差的标准驱逐相比,公式(12)同时减小了有效角度$\theta$和放大系数$\gamma(\theta)$,将总误差缩小到任何仅改进$w_E$的方法都无法达到的程度。
增强回路。这两个机制形成了一个增强回路:矩感知驱逐使$\sigma$保持较小,从而收紧了近似$\hat{f}_E \approx f_E$和对$\hat{Z}_E$的Jensen界,而改进的校正为后续的驱逐决策产生了更可靠的残差。图3提供了经验证据:与H2O相比,MOMENTKV显著降低了单样本的驱逐损失,实现了持续的逐层误差减少,并在所有层中保持了真实与近似的被驱逐输出之间的高度余弦相似性。完整的流程在算法1中给出。
算法 1 注意力头的解码步骤
需要: 保留的缓存 $R = \{(k_i, v_i)\}$; 矩统计量 $(n_e, s_k, s_v, S)$; 新 token $x$; 预算 $L$
确保: 注意力输出 $\hat{f}(q)$; 更新后的 $R$ 和统计量
1: $q, k_{new}, v_{new} \leftarrow W_Q x, W_K x, W_V x$
2: $R \leftarrow R \cup \{(k_{new}, v_{new})\}$
阶段 1: 矩感知驱逐
3: while $|R| > L$ do
4: if $n_e > 0$: $\bar{k} \leftarrow s_k / n_e, \bar{v} \leftarrow s_v / n_e, \tilde{S} \leftarrow S - s_v s_k^\top / n_e$; else $\bar{v}, \tilde{S} \leftarrow 0$
5: 计算 $\alpha \leftarrow \text{softmax}(q^\top[k_j]_{j \in R} / \sqrt{d})$
6: for 每个 $j \in R$ do
7: $r_j \leftarrow v_j - \bar{v} - \tilde{S}(k_j - \bar{k}) / (n_e \sqrt{d})$ ▷ 矩残差
8: $score(j) \leftarrow \alpha_j \cdot \|r_j\|$ ▷ 分数低 = 可以安全驱逐
9: end for
10: 选择 $j^* \leftarrow \arg \min_{j \in R} score(j)$
11: 更新 $n_e \leftarrow n_e + 1$
12: 更新 $s_k \leftarrow s_k + k_{j^*}, s_v \leftarrow s_v + v_{j^*}$
13: 更新 $S \leftarrow S + v_{j^*} k_{j^*}^\top$ ▷ 秩一加法
14: 从 $R$ 中移除 $j^*$: $R \leftarrow R \setminus \{j^*\}$
15: end while
阶段 2: 归一化校正推理
16: $f_R, Z_R \leftarrow \text{Attention}(q, R)$
17: if $n_e = 0$ then return $f_R$ ▷ 尚无被驱逐的 token
18: $\hat{Z}_E \leftarrow n_e \cdot \exp(q^\top \bar{k} / \sqrt{d})$ ▷ 通过 Jensen 不等式得到下界
19: $\hat{f}_E \leftarrow \bar{v} + \tilde{S}q / (n_e \sqrt{d})$ ▷ 矩近似
20: $\hat{w}_R \leftarrow Z_R / (Z_R + \hat{Z}_E)$ ▷ 混合权重
21: return $\hat{w}_R f_R + (1 - \hat{w}_R) \hat{f}_E$
A4 实验环境
-
基准测试:
- LongBench: 一个双语多任务基准,涵盖单文档问答(SQA)、多文档问答(MQA)、摘要、少样本学习(Few-shot)、合成任务和代码补全。
- RULER: 补充LongBench,用于在受控上下文长度下探测检索、变量跟踪和多跳推理能力。
-
模型:
- LLaMA-3.1-8B-Instruct: 128K上下文窗口,32个查询头,8个KV头 (GQA)。
- Qwen3-4B-Instruct-2507: 256K上下文窗口,32个查询头,8个KV头 (GQA)。
-
硬件配置:
- GPU: 2台 NVIDIA H200 GPU。
-
软件配置与对比方法:
- 实现: 基于Ada-KV的代码实现。
- 协议: 采用SnapKV的预填充协议:在提示末尾的32个token观察窗口内累积注意力分数;保留第一个token作为注意力“sink”;在评分前将相邻token合并为大小为4的块;根据每个头的注意力权重,在层内自适应分配缓存预算。
- 缓存预算(L): {128, 256, 512, 1024}。
- 对比基线: H2O, SnapKV, PyramidKV, Ada-KV。
- 初始化: 每次生成开始时,所有矩统计量初始化为零。
A4 实验结果
主要结果
LongBench (Table 1): 在两个模型上,MOMENTKV在所有缓存预算水平上都取得了最高的平均分。在LLaMA-3.1-8B上,当L=128时,它比最强的基线Ada-KV高出1.35分,保留了全缓存性能的94.1%。增益在需要关注分散在整个上下文中的特定事实的问答(QA)和少样本任务上最为显著。摘要和代码任务的优势较小,因为它们更依赖于可能被保留在缓存中的局部上下文。在Qwen3-4B上,MOMENTKV同样持续领先,尤其在L=128时优势最大,证实了其益处可跨模型架构和规模泛化。
表 1: 在两个模型上,四种缓存预算下的LongBench结果。SQA:单文档问答,MQA:多文档问答,Few:少样本学习。Avg是在所有六个类别上的平均值。MOMENTKV在每个预算水平上都取得了最高的平均分,在激进压缩下增益最大(在LLaMA-3.1-8B上,L=128时比Ada-KV高+1.35)。最佳结果加粗,次佳结果加下划线。
预算扩展分析 (Figure 4): 如图4所示,随着缓存预算的增加,MOMENTKV的优势有所缩小,在L=1024时,其与Ada-KV的差距降至+0.59。这个趋势与公式(4)的分析一致:更大的预算减少了被驱逐的注意力权重$w_E$,留给方向性误差校正的空间也变小了。然而,即使在L=1024,所有基线方法得分都收敛到彼此相差1分以内时,MOMENTKV仍然保持着稳定的领先地位,这表明矩校正捕捉到了仅靠改进选择策略无法恢复的信息。
图 4: 在 LLaMA-3.1-8B 上,不同缓存预算下的 LongBench 平均分。优势在 L 较小时最大,并随着预算的增加而缩小。
RULER (Table 2): 在L=128时,MOMENTKV在RULER基准上的表现优于Ada-KV,在LLaMA-3.1-8B上高出+3.4分,在Qwen3-4B上高出+3.5分。这比在LongBench上的优势比例更大。RULER任务要求精确检索嵌入在受控位置的特定token,因此对被驱逐信息的丢失特别敏感。这验证了矩校正在检索密集型场景中价值最大。
表 2: RULER 分数 (16K, L=128)。MOMENTKV 比 Ada-KV 领先 +3.4 和 +3.5。
消融和进一步分析
组件消融 (Table 3): 表3分析了MOMENTKV的两个核心组件:矩感知驱逐(MI)和归一化校正推理(NC)。以SnapKV为基线,在L=128时,单独使用MI带来了+0.57的提升,而单独使用NC则贡献了+2.41的提升。这证实了恢复被驱逐输出的方向比优化选择策略更为重要,NC的贡献大约是MI的4倍,与第3节的分析一致,即主要误差源是放大系数$\gamma(\theta)$,而只有NC能解决这个问题。将两者结合使用,总提升达到+3.17,超过了两者独立增益之和(+0.19),这得益于4.2节中描述的增强回路。
表 3: 在 LongBench 平均分上对组件进行的消融实验 (LLaMA-3.1-8B)。MI:矩感知驱逐,NC:归一化校正推理。
近似阶数和开销分析 (Table 4): 表4比较了不同阶数近似的效果和开销。零阶变体(仅使用均值$\bar{v}$,每头O(d)存储)已经比SnapKV提升了+0.94,表明即使是与查询无关的校正也能恢复有用的被驱逐信息。一阶近似增加了每头O(d^2)的存储(总计4.1MB,占保留缓存的6.4%)和1.6ms的延迟,但通过捕捉与查询相关的方向变化,其增益是零阶的三倍多。这些成本都由头维度d决定,不随上下文长度增长,使得MOMENTKV对于任意长序列都实用,同时相比全缓存,每token延迟降低了超过65%。
表 4: 在 LongBench 平均分上的近似阶数和开销分析 (LLaMA-3.1-8B, L=128)。基线:SnapKV。
A5 结论
本文证明了KV缓存驱逐中质量损失的一个关键来源不是token选择不充分,而是保留和被驱逐子输出之间的方向性不匹配,而仅靠重归一化的推理无法完全解决这个问题。MOMENTKV通过维护被驱逐集上的紧凑矩统计量来弥合这一差距,这些统计量发挥双重作用:引导驱逐过程倾向于选择那些已经被累积摘要很好捕捉的token,并为后驱逐注意力输出提供一个闭式的一阶校正。这两个机制形成了一个增强回路,因为矩感知驱逐抑制了控制近似误差的中心化logit分布范围$\sigma$,而准确的校正为后续的驱逐决策产生了更可靠的残差。在LongBench和RULER上使用两种模型系列的实验证实了在每个缓存预算下都有一致的改进,尤其是在检索密集型任务中进行激进压缩时增益最大。未来的研究方向包括每头自适应的近似阶数,以及与量化和低秩分解等正交技术的集成。
A6 附录
A. 一阶近似的推导
A.1 中心化Softmax和一阶展开
推导细节。Softmax函数对于加性平移是不变的。定义平均被驱逐键$\bar{k} = \frac{1}{n_e} \sum_{i \in E} k_i$,平均logit $\bar{s}(q) = q^\top \bar{k} / \sqrt{d}$,以及中心化logit $\delta_i = (q^\top(k_i - \bar{k}))/\sqrt{d}$。根据构造,对于任何查询q,都有$\sum_{i \in E} \delta_i = 0$。这个零和性质是关键。由于$s_i = \bar{s} + \delta_i$,公共因子$\exp(\bar{s})$在softmax比率中被消去,得到$f_E(q) = \frac{\sum_{i \in E} v_i \exp(\delta_i)}{\sum_{i \in E} \exp(\delta_i)}$。
泰勒展开。我们应用泰勒展开$\exp(\delta_i) = 1 + \delta_i + \frac{1}{2}\delta_i^2 e^{\xi_i}$,只保留一阶项。设$\sigma = \max_{i \in E} |\delta_i|$。对于分母,零和性质消除了线性项:$\sum_{i \in E} \exp(\delta_i) = n_e + O(\sigma^2)$。对于分子,展开后为$\sum v_i(1+\delta_i) + \text{高阶项}$。常数项产生$n_e \bar{v}$(其中$\bar{v} = \frac{1}{n_e}\sum v_i$),线性项$\sum v_i \delta_i$保留下来,因为它由不同的值向量$v_i$加权。
闭式近似。将$\delta_i$代入线性项并分解,得到$\left(\sum_{i \in E} v_i(k_i - \bar{k})^\top\right) q / \sqrt{d} = \tilde{S}q/\sqrt{d}$,其中$\tilde{S} = \sum_{i \in E} v_i(k_i - \bar{k})^\top$是经验交叉协方差。将分子近似$\sum v_i \exp(\delta_i) \approx n_e\bar{v} + \tilde{S}q/\sqrt{d}$除以分母近似$n_e$,得到$f_E(q) \approx \bar{v} + \frac{\tilde{S}q}{n_e\sqrt{d}}$,即公式(9)。
统计量的高效计算。矩阵$\tilde{S}$可以从四个运行和中计算,而无需重新访问任何被驱逐的token。这些充分统计量是:$n_e$(计数),$s_k = \sum k_i$(键和),$s_v = \sum v_i$(值和),$S = \sum v_i k_i^\top$(外积和)。使用$s_v = n_e\bar{v}$和$\bar{k} = s_k/n_e$,我们有$\tilde{S} = S - s_v s_k^\top / n_e$。当一个token $j^*$被驱逐时,统计量通过简单的加法更新:$n_e \leftarrow n_e+1, s_k \leftarrow s_k+k_{j^*}, s_v \leftarrow s_v+v_{j^*}, S \leftarrow S + v_{j^*}k_{j^*}^\top$。每次更新成本为$O(d^2)$。
存储开销。存储$s_k, s_v$需要$2d$个值,主要成本是存储$S \in \mathbb{R}^{d \times d}$,需要$d^2$个值。如表5所示,对于d=128和16位精度,每个头的总开销约为128个KV对的等效存储,与上下文长度N无关。这个开销是固定的,因此当缓存预算L增加时,其相对占比会下降。
表 5: 矩统计量存储开销。保留缓存大小按 L=128 和 16 位精度计算。
A.2 定量误差界
误差界推导。我们旨在界定差距$\|\hat{f}_E(q) - f_E(q)\|$。将$f_E$写为$A/B$,其中$A$和$B$是精确的分子和分母。$\hat{f}_E$是近似的分子除以近似的分母。经过代数运算,误差可以表示为$\frac{A}{B} - \frac{A-R_{num}}{n_e} \approx \frac{R_{num}}{n_e} - \frac{A \cdot R_{den}}{n_e^2}$,其中$R_{num}$和$R_{den}$是泰勒展开的余项,均为$O(\sigma^2)$。因此,最终的误差界也为$O(\sigma^2)$。
几何解释。通过柯西-施瓦茨不等式,$\sigma \le \|q\| \cdot r_{max} / \sqrt{d}$,其中$r_{max} = \max_{i \in E} \|k_i - \bar{k}\|$是键偏离其质心的最大距离。因此,误差界由三个因素控制:1) 键的散布$r_{max}$,矩感知驱逐通过优先驱逐键接近$\bar{k}$的token来直接促进较小的$r_{max}$;2) 查询范数$\|q\|$;3) 头部维度$d$,其$1/\sqrt{d}$缩放因子减小了logits的量级。
与驱逐标准的联系。驱逐分数中使用的矩残差$r_j$直接与误差界相关。一个$r_j$小的token意味着其值能被仿射模型很好地预测。驱逐这样的token会增加$n_e$,但不会显著增加$V_{max}$或$r_{max}$,从而保持近似的高质量。因此,驱逐分数$\alpha_j \cdot \|r_j\|$联合优化了保留集输出质量(通过$\alpha_j$)和近似质量(通过$\|r_j\|$)。
B. 对被驱逐配分函数的Jensen界
B.1 推导与紧致性
下界推导。真实的被驱逐配分函数是$Z_E = \sum_{i \in E} \exp(s_i)$。将$s_i = \bar{s} + \delta_i$代入,得到$Z_E = \exp(\bar{s}) \sum_{i \in E} \exp(\delta_i)$。对凸函数$\exp(\cdot)$应用Jensen不等式:$\frac{1}{n_e}\sum \exp(\delta_i) \ge \exp(\frac{1}{n_e}\sum \delta_i)$。由于$\sum \delta_i = 0$,右边为$\exp(0)=1$。因此,$\sum \exp(\delta_i) \ge n_e$,这给出了下界$Z_E \ge n_e \exp(\bar{s}) = \hat{Z}_E$。
紧致性分析。通过对$\exp(\delta_i)$进行二阶泰勒展开并利用零和性质,可以表明$Z_E / \hat{Z}_E \approx 1 + \frac{1}{2} \text{Var}(\delta)$,其中$\text{Var}(\delta) = \frac{1}{n_e}\sum \delta_i^2$是中心化logits的方差。因此,相对误差由logit方差决定,尺度为$O(\sigma^2)$。这个方差又可以表示为$q$在被驱逐键协方差$\Sigma_k$上的投影:$\text{Var}(\delta) = \frac{1}{d} q^\top \Sigma_k q$。
B.2 对混合权重的影响
自调节机制。低估$Z_E$会导致估计的保留权重$\hat{w}_R$高于真实权重$w_R$。权重偏差是$\sigma$的二阶项,且与$w_R \cdot \hat{w}_E$成正比,在$w_R \gg w_E$的典型情况下这个偏差很小。这种机制是自调节的:
- 当$\sigma \to 0$(驱逐集紧凑)时,$\hat{f}_E \to f_E$, $\hat{Z}_E \to Z_E$, $\hat{w}_R \to w_R$。校正后的输出$\hat{f}(q)$收敛到真实输出$f(q)$。
- 当$\sigma \to \infty$(近似不可靠)时,Jensen不等式的差距变大,$\hat{Z}_E \ll Z_E$。这导致估计的被驱逐权重$\hat{w}_E \to 0$,从而$\hat{w}_R \to 1$。系统会逐渐减少对不可靠校正项$\hat{f}_E$的依赖,平滑地退化为标准驱逐,而不是用一个坏的校正来污染输出。
因此,Jensen下界充当了一个自然的正则化器,根据近似的可靠性自动校准其影响。
C. 计算复杂度和数值考虑
C.1 每步复杂度分析
复杂度对比。如表6所示,所有方法共享$O(Ld)$的注意力计算成本。基线方法的驱逐成本为$O(L)$或$O(wL)$。MOMENTKV引入了两个额外成本:1) 驱逐时为每个保留token计算矩残差,需要$O(d^2)$的矩阵-向量乘法,总计$O(Ld^2)$;2) 推理时计算校正项$\hat{f}_E$和更新统计量,成本为$O(d^2)$。尽管$O(Ld^2)$在渐近意义上更大,但实际中$d=128$的矩阵乘法在现代GPU上非常快,且驱逐循环每生成一个token最多运行一次。如表4所示,实际延迟增加是可控的(+1.6ms)。
表 6: 单个注意力头的每步复杂度。L:缓存预算,d:头部维度,w:观察窗口大小。存储指保留缓存之外的辅助开销。
存储扩展性。与基线方法需要$O(L)$的辅助存储不同,MOMENTKV的矩统计量存储是$O(d^2)$,与L和上下文长度N无关。这使得其在L较大时相对开销非常小。
C.2 数值稳定性
潜在问题与缓解。矩统计量使用半精度(FP16/BF16)存储。两个潜在问题是:1) 在累加外积和$S$时发生溢出;2) 在恢复协方差$\tilde{S} = S - s_v s_k^\top / n_e$时发生灾难性抵消。
- 溢出:对于FP16/BF16,其表示范围足以处理实际上下文长度下的累加和,因为单个键值向量的元素通常是O(1)量级。
- 灾难性抵消:当中心化协方差$\tilde{S}$相对于$S$很小时可能发生。通过矩感知驱逐(避免极端集中的驱逐集)和数值保护(对计算出的$\tilde{S}$中小于阈值的条目钳位为零)来缓解。
- 实验验证:用FP16和FP32进行矩统计的实验对比显示,在LongBench上的平均分差异小于0.05,证实半精度引入的数值误差可以忽略不计。
- 对数域计算:为避免计算$\exp(q^\top\bar{k}/\sqrt{d})$时溢出,混合权重在对数域使用logsumexp进行稳定计算。
D. 扩展实验结果
D.1 扩展行为与差距分析
绝对优势与恢复百分比。如表7所示,MOMENTKV相对于Ada-KV的绝对优势($\Delta_{Ada}$)随着预算L的增加而减小(从L=128的+1.35降至L=1024的+0.59),因为可供恢复的方向信息变少。然而,MOMENTKV弥补的、相对于Ada-KV到全缓存性能之间差距的恢复百分比,却随着L的增加而单调上升(从31.7%到59.6%)。这是因为在更大的预算下,被驱逐的集合更小、更同质,使得一阶近似更准确。
表 7: 在 LongBench 平均分上使用 LLaMA-3.1-8B 的差距分析。$\Delta_{Ada}$:比 Ada-KV 的优势。Recovery:MOMENTKV 相对于 Ada-KV 弥补的到全缓存(49.29)剩余差距的比例。
D.2 任务级别分解
任务敏感性。如表8所示,MOMENTKV的增益在不同任务类型上分布不均。
- 单文档QA和少样本学习:增益最大。这些任务需要从长文档中检索稀疏的事实,这些信息很容易被驱逐。矩校正恢复了这些丢失的方向信息。
- 摘要和代码任务:增益最小。这些任务主要依赖局部上下文,大部分相关信息已存在于保留缓存中,留给校正的空间较小。
- 合成任务和RULER:增益较大。这些任务需要精确定位和检索,最大程度地暴露了信息丢失问题,从而凸显了矩校正的价值。
表 8: 在 LongBench 上,L=128 时 MOMENTKV 相对于 Ada-KV 的各类别改进。
D.3 组件交互分析
超加性效应。组件消融实验揭示了矩感知驱逐(MI)和归一化校正推理(NC)之间的超加性互动。在L=128时,两者结合的增益(+3.17)超过了它们独立增益之和(+0.57 + +2.41 = +2.98),多出的+0.19量化了增强回路的强度。这种超加性效应在最紧凑的预算下最强。
NC的主导作用。在所有预算下,NC的贡献始终是MI的3到4倍。这与第3节的分析一致:现有选择方法已将$w_E$降得很低,主要误差来自$\gamma(\theta)$,而只有NC能直接解决这个问题。
组合的鲁棒性。单独使用MI有时可能在某些任务(如摘要)上损害性能,但与NC结合后,校正可以补偿任何次优的驱逐决策,使组合系统在所有类别上都表现出色,强调了将两个机制一同部署的重要性。
💬 评论讨论
欢迎在这里分享您的想法和见解!