作者/机构: Ranjith Chodavarapu, Lei Xu (Department of Computer Science, Kent State University, USA)
核心问题: 在自回归Transformer推理中,KV缓存是一项普遍采用的优化技术,长期以来被认为与无缓存计算在数值上是等价的。然而,本文指出在标准的FP16精度下,这一假设不成立。
研究目标: 本文旨在系统性地分析和揭示在FP16精度下,使用KV缓存(cache-ON)与不使用KV缓存(cache-OFF)的自回归推理之间存在的确定性数值差异,并探究其根本原因、传播机制、行为影响以及在模型架构中的具体位置。
创新与贡献:
* 行为特征描述: 研究表明,在所有测试模型(LLaMA-2-7B, Mistral-7B-v0.3, Gemma-2-2B)和采样策略(贪心、top-k、top-p)下,KV缓存都导致了100%的token序列差异。在评估协议下,有8/9的情况下,启用缓存(cache-ON)获得了更高的准确率(经Bonferroni校正的McNemar检验,p < 0.05),这表明执行路径存在系统性偏差,而非随机发散。
* 机制性根本原因: 通过受控的证伪实验,确定了FP16的累加顺序是导致差异的根本原因。切换到FP32精度后,KV差异降低了超过八个数量级,降至2.5 × 10−8的噪声水平,token翻转率降至0.0%,证实了FP16的非结合性是唯一的因果驱动因素。平均相对累加误差不随序列长度变化(约0.036%),表明误差幅度由架构决定且与输入无关。
* 决策边界解析: 解决了翻转步数与KL散度之间的明显矛盾:早期的token翻转通常伴随着更高的KL散度。研究证实,这是因为模型的不确定性(即接近决策边界)同时导致了早期翻转和高KL散度,而非差异幅度本身。这解释了数值差异在何时会产生行为上的影响。
* 因果定位: 通过对残差流(residual stream)的激活补丁(activation patching)实验发现,即使同时修补所有层的隐藏状态,也无法恢复无缓存的计算轨迹,中位恢复率低于1%。这一定位了差异的根源在于有状态的KV缓存本身,而非残差流中的注意力头或MLP层现象,并表明正确的干预需要直接修补KV缓存。
综上所述,这些发现确立了FP16 KV缓存推理与重计算在根本上不等价,并为理解现代LLM推理系统中的数值不稳定性提供了一个机制性的框架。
KV缓存机制。自回归Transformer使用键值(Key-Value, KV)缓存来加速解码过程。先前计算过的输入前缀中每个token的注意力键和值被存储起来,并在后续步骤中重用。无缓存(cache-OFF)模式下,每个token的生成都需要对整个前缀重新计算注意力;而有缓存(cache-ON)模式下,只需关注缓存中已有的前缀表示。尽管两种行为在数学上是等价的,但它们的实现顺序、核函数布局和内存布局都不同,因此会导致不同的浮点数累加结果。这种等价性通常只是被假设,而未经过跨模型架构和精度格式的系统性测试。
注意力变体引入的差异。注意力机制的变体,如多查询注意力(MQA)【索引18,Fast transformer decoding: One write-head is all you need,2019】和分组查询注意力(GQA)【索引2,Gqa: Training generalized multi-query transformer models from multi-head checkpoints,2023】,进一步引入了差异。MQA使用一个所有查询头共享的KV头,而GQA则将查询头分组,每组共享一个KV头。这些变体改变了归约(reduction)结构,从而影响了FP16舍入误差的传播方式。因此,KV缓存建立了一个有状态的计算路径:缓存张量内微小的数值差异会被保留下来,影响后续解码的每一步。
浮点数计算的非结合性。Transformer推理作为一种计算机程序,依赖于浮点数运算,而浮点数运算不满足结合律【索引10,What every computer scientist should know about floating-point arithmetic,1991】:
$$\begin{aligned} \begin{cases} (a+b)+c \neq a+(b+c) \\ (a \times b) \times c \neq a \times (b \times c) \end{cases} \end{aligned}$$低精度对计算顺序的敏感性。与FP32等高精度格式相比,FP16等低精度浮点数对计算顺序更为敏感。因此,执行顺序的变化(如cache-ON与cache-OFF)会产生系统性的影响。以往的研究虽然探讨了降低精度和量化(如【索引16,Mixed precision training,2017】、【索引24,Efficient streaming language models with attention sinks,2023】),但主要关注效率而非推理的正确性。在深度Transformer中,微小的初始偏差会逐层累积,并通过KV缓存跨解码步骤持续存在,从而产生结构化的差异,而非随机噪声。
决策边界与行为差异。隐藏状态层面的差异不一定会导致输出层面的差异。模型的决策是基于对数概率的差异,一个微小的变化只有在多个竞争的输出token概率非常接近时(即模型处于决策边界上)才会影响最终的argmax选择。
不确定性与扰动影响。实际上,行为差异不仅取决于数值漂移的大小,还取决于logit裕度(margin)的结构。一个早期的扰动之所以可能改变下游的决策,更可能是因为模型本身不确定,而不是因为扰动本身很大。
GQA与MQA机制。GQA和MQA通过在多个查询头之间共享同一组键(key)和值(value)头来减少内存带宽需求。MQA是GQA的一个特例,其中所有查询头(H个)共享一个KV头(共享比例R=H)。GQA是MQA的泛化,可以选择一个共享比例1 ≤ R < H。当R=1时,即为多头注意力(MHA),每个查询头拥有自己独立的KV头。
误差放大机制。在cache-ON模式下,每个共享KV头中的FP16舍入误差会被广播到所有R个查询头。这意味着一个舍入误差会导致R个相关的错误,而不是R个独立的错误,从而相对于MHA放大了差异。R值越大,这种放大效应越强。因此,Mistral(GQA比例4:1, R=4)比LLaMA-2(MHA, R=1)表现出系统性更高的每层漂移,而MQA模型预计会受到最强的影响。
基于残差流的干预。机制可解释性方法通过干预内部激活来研究Transformer的因果架构。诸如因果追踪、ROME【索引15,Locating and editing factual associations in gpt,2022】和路径修补【索引22,Interpretability in the wild: a circuit for indirect object identification in gpt-2 small,2022】等方法通常在残差流(residual stream)上进行操作,该流在各层之间累积计算结果。
KV缓存作为残差流之外的因果变量。如果修补残差状态未能恢复模型行为,那么因果变量就不在残差流中,而很可能在KV缓存中,因为它包含了通过残差流干预无法访问的信息。这提供了一种理论上合理的方法,将差异定位到KV缓存状态,而不是某个特定的层或注意力头。
浮点数不确定性的影响。即使采用贪心采样等确定性解码方法,Transformer的输出也对浮点数不确定性、并行归约顺序、硬件特定核函数等因素敏感。先前已有工作展示了大型语言模型输出的不稳定性和复现性问题,例如关于幻觉和不一致性的综述【索引26,Siren’s song in the ai ocean: A survey on hallucination in large language models,2023】以及跨评估设置的模型复现性和变异性报告(例如,【索引1,Gpt-4 technical report,2023】)。
系统性差异与随机噪声的区别。与之前将这种变异视为随机噪声的研究不同,本文证明了cache-ON和cache-OFF之间的差异是系统性的、可复现的,并且在机制上可归因于FP16的累加顺序——这是执行路径差异的确定性后果,而非随机现象。
构建因果链的五个实验步骤。为了理解KV缓存的差异,仅仅知道差异存在是不够的。我们必须理解它为什么发生、如何传播、位于何处以及何时在行为上变得重要。这五个实验构成了一个因果链,每个实验提出的问题都为下一个实验提供了动机。
1. 步骤1:差异是否存在且是否系统性? 我们首先验证cache-ON和cache-OFF是否会产生不同的输出,并确认这种差异是否在贪心解码中也存在,以排除采样随机性作为原因。如果差异不是真实且结构性的,那就没什么可解释的了。此步骤的度量指标是token差异率和每步的KL散度。
2. 步骤2:差异如何传播? 确认差异后,我们探究第一层的误差如何通过后续层传播。这揭示了效应是局部的还是均匀分布的,以及架构的GQA比例、头维度和注意力类型如何影响传播。
3. 步骤3:原因是什么? 传播特征描述并不能揭示其根源。我们直接检验FP16非结合性假说:如果累加顺序是原因,那么转换为FP32应该能消除差异。这一证伪实验可以排除其他解释。
4. 步骤4:何时对行为产生影响? 确认原因后引出了一个新问题:翻转索引(flip index)和KL散度大小之间存在明显矛盾。解决这个矛盾揭示了数值差异在何种条件下(如接近决策边界)会实际改变模型的行为。
5. 步骤5:差异存在于何处? 最后,我们定位因果变量。如果差异由残差流携带,那么修补隐藏状态应该能恢复无缓存的轨迹。如果修补失败,则差异必然存在于KV缓存状态本身。
实验设计与步骤对应关系。这五个步骤分别对应五个实验:行为特征描述(步骤1)、层漂移分析(步骤2)、根本原因证伪(步骤3)、决策边界分析(步骤4)和因果定位(步骤5)。
自定义生成循环的重要性。解码策略的选择并非偶然,它决定了哪些替代解释可以被排除。我们构建了自己的逐token生成循环,而不是使用model.generate() API,原因有二:(1) 为了在每个解码步骤获取cache-ON和cache-OFF的完整词表logits,以便计算每一步的KL散度;(2) 为了确保两条路径在每一步都使用相同的随机种子,从而使得任何差异都源于分布散度,而非随机采样方差。这两种功能在原生API中均未提供。
argmax操作,cache-ON和cache-OFF之间的所有差异都完全来自上游的FP16 KV缓存计算,而非采样过程。这可以用来分离纯粹的执行路径效应。策略组合的意义。综合来看,这三种策略表明差异并非源于特定的解码机制,而是FP16缓存推理的一个基本属性。所有三种策略都以max_new_tokens=128运行,并记录了每一步的完整KL散度轨迹。
不同指标捕捉不同维度的差异。每个度量指标都捕捉了差异的不同方面,并回答了论证中的一个特定问题。
Token 差异 (Token divergence): 如果生成的token序列在任何点上不同,则运行对被视为有差异(二元结果)。这直接确定了cache-ON和cache-OFF是否生成了不同的文本。这是步骤1的主要行为结果:如果token差异是普遍的,则该效应是结构性的而非偶然的。
KL 散度 (KL divergence): 在每个解码步骤,计算cache-ON分布pt和cache-OFF分布qt在整个词汇表上的KL散度。
其中,$p_t$和$q_t$分别表示在步骤t时cache-ON和cache-OFF的softmax分布。跨步骤的平均KL散度是主要的差异幅度度量。KL散度衡量了在选择token之前,两条路径之间概率质量的转移量。持续高位的KL轨迹表明分布是连续发散的,而不仅仅是波动。这使得在步骤2中进行逐层漂移分析和在步骤3中进行FP32幅度比较成为可能。
JS 散度 (JS divergence): JS散度是$p_t$和$q_t$之间的对称Jensen-Shannon散度:
$$\text{JS}(p_{t} \| q_{t})=\frac{1}{2} \text{KL}(p_{t} \| m_{t})+\frac{1}{2} \text{KL}(q_{t} \| m_{t})$$其中$m_t = \frac{1}{2}(p_t + q_t)$,其值有界于[0, 1]。如果JS散度反映了与KL散度相同的趋势,那么差异模式就是有效的,而不是由KL散度的无界范围引起的数学假象。JS散度作为KL散度发现的一致性检验。
翻转索引 (Flip index): 翻转索引是cache-ON和cache-OFF在token上出现分歧的第一个解码步骤。这是行为差异开始的点,即数值漂移首次在生成的文本中变得可观察。它与平均KL散度的关系在步骤4中进行探讨:早期翻转和高KL散度之间的明显矛盾揭示了是决策边界的邻近度,而非差异幅度,决定了差异何时变得重要。
针对性选择统计检验方法。每种检验方法的选择都是为了回答论证中的一个特定推断性问题,而不是作为更一般的显著性检查。
遵循因果逻辑的五个实验。这五个实验遵循了分析框架(第4节)中建立的因果逻辑。下面描述每个实验及其在论证中的作用。
激活补丁 (Activation Patching)。最后一个问题是差异存在于何处。对于每个模型中KL散度最高的n=600个样例子(从全部700个中按平均KL散度降序选择),在解码步骤0将cache-ON的隐藏状态注入到cache-OFF的前向传递中。选择步骤0是因为这是隐藏状态开始分歧的第一步,预填充(prefill)阶段两条路径的状态是相同的(在所有层$∥h_{on} − h_{off}∥ = 0$)。这个操作逐层进行,并累积地跨越所有0到L层。恢复率的计算公式为:
$$ \text{pct\_recovered} = \frac{\text{KL}_{\text{base}} - \text{KL}_{\text{patched}}}{\text{KL}_{\text{base}}} \times 100 $$这里的负恢复率证实了修补反而增加了差异。所有修补实验每个样例子运行MAX_STEPS=32个解码步骤。
模型:
AutoModelForCausalLM.from_pretrained以FP16精度加载。数据集:
SAMPLE_SEED=42)进行评估。选择该数据集是因为它包含足够长的多步数学问题,涉及多次KV缓存组合,并提供明确的二元正确性标准。每个样本在5个不同的随机种子下运行,总计每个模型-策略组合有3,500次运行。硬件配置:
软件配置:
CUBLAS_WORKSPACE_CONFIG=:4096:8并启用torch.use_deterministic_algorithms(True)以确保在同一GPU上的可复现性。Token差异是普遍现象。在所有31,500对运行中,有缓存(cache-ON)和无缓存(cache-OFF)生成的序列无一例外地都不同(差异率100%)。这一现象在贪心解码下同样存在,排除了采样随机性的影响,证明了这种差异是FP16缓存推理的结构性属性。
执行路径偏向是系统性的,而非随机。在9个实验条件中的8个里,cache-ON模式比cache-OFF模式产生了更多正确的答案,并且差异在统计上显著(经Bonferroni校正的McNemar检验,p < 0.05)。这表明存在系统性的执行路径偏向,而不是随机波动。
KL散度值大、稳定且与模型相关。平均KL散度在10.68到12.83之间,且95%置信区间很窄,表明结果稳定。Gemma模型在所有策略下始终表现出最大的差异。Mistral模型由于其4:1的GQA比例,比LLaMA-2的MHA结构更积极地放大了每步的累积误差。所有条件下平均KL散度均大于10,表明每一步解码都发生了显著的概率质量重新分布。JS散度的趋势与KL散度一致,证实了该发现的有效性。
图1显示,KL散度在全部128个解码步骤中持续保持高位,并未趋向于零,这符合FP16舍入误差在每个解码步骤不断累积的预期。
LLaMA-2和Mistral在第一层出现急剧的差异放大。如图2所示,LLaMA-2的隐藏状态漂移在第1层从3.45跃升至361.96(放大105倍)。Mistral则从25.76跃升至902.36(放大35倍)。这种现象与Mistral的4:1 GQA比例一致,即共享的KV头将FP16舍入误差同时广播到四个查询头,导致了相关联的误差,从而比LLaMA-2的MHA结构更强烈地放大了初始扰动。
Gemma表现出结构上独特的平坦分布。与前两者不同,Gemma没有突然的放大事件,其漂移在所有26层中保持在1.0到1.4之间。尽管其每层漂移值最低,但其总KL散度最大,表明Gemma的差异是均匀分布在所有层中累积的。这种平坦的分布模式与其更大的头维度(256 vs 128)和滑动窗口注意力机制相符,这些机制限制了误差的全局传播,导致差异缓慢地跨层累积。
FP32精度完全消除了差异。当推理切换到FP32精度时,KL散度在所有三个模型中都骤降至约2.5 × 10−8的噪声水平(见图3,表2)。所有样本的token翻转率精确为0.0%。这清晰地证伪了其他可能的原因,证明差异并非源于KV缓存机制本身,而是实现它所用的FP16数据格式。
差异幅度下降超过八个数量级。FP16下的平均KL散度约为10^1,而FP32下的噪声水平约为2.5 × 10−8,两者相差约4 × 10^8倍。这无可辩驳地证明FP16的非结合性是观察到差异的主要且充分的原因。
累积误差由架构决定。平均相对FP16累积误差为0.036%,并且在不同输入序列长度(16-128 tokens)和模型间保持稳定。这表明误差范数由注意力头的维度决定,与输入数据或长度无关,因此差异可以仅从架构预测。
“早期翻转伴随高KL散度”的矛盾现象。数据显示,翻转索引(flip index)与平均KL散度呈负相关(见图4,表3),即翻转得越早的样本,其KL散度越高。这似乎与累积假说(累积更多差异应导致更早翻转)相矛盾。
决策边界邻近度是根本原因。该矛盾可通过决策边界邻近度来解释。接近决策边界的样本,其logit裕度很小,任何微小的扰动都足以跨越阈值导致翻转,因此翻转发生得早。同时,这些样本由于模型本身的不确定性,会将概率质量分散到多个token上,从而导致高KL散度。因此,早期翻转和高KL散度是模型不确定性的共同结果,而非因果关系。在8/9的实验条件下,早翻转样本的KL散度显著高于晚翻转样本。
残差流激活补丁无法恢复KV缓存差异。对残差流的隐藏状态进行激活补丁(activation patching)实验显示,单层修补的总体中位恢复率仅为-0.019%(平均-1.21%),表明修补并未使输出收敛,反而可能在扰乱轨迹。即使同时修补所有层的隐藏状态,所有三个模型的平均恢复率也为负(LLaMA-2: -5.34%, Mistral: -2.56%, Gemma: -2.88%)。
Gemma的部分恢复与其平坦漂移分布一致。Gemma的单层修补中位恢复率为+5.75%,显著高于其他模型。这源于其平坦的层漂移分布,差异均匀累积,使得单层修补能部分重定向局部差异轨迹。
最终定位:差异存在于KV缓存状态。然而,对Gemma进行所有层的累积修补,平均恢复率仍为负。这个最终的无效结果表明,即使修补整个残差流也不足以修复差异。这一定位了因果变量在KV缓存状态本身:FP16舍入误差被写入并保存在KV张量中,影响所有下游注意力计算,而残差流的干预对此无能为力。正确的干预需要直接修补KV缓存张量。
五个实验构成的完整论证链。这五个实验共同构成了一个统一的论证。首先,行为特征描述显示,在FP16下,cache-ON和cache-OFF的推理结果在所有情况下(模型、策略、种子)总是存在差异。其次,根本原因证伪实验确定了原因是FP16的非结合性,因为切换到FP32后差异完全消失。接着,层漂移分析揭示了初始误差的传播方式:在LLaMA-2和Mistral中,误差在第一层急剧放大;而在Gemma中,误差广泛分布于所有层。然后,决策边界分析解释了这种差异何时在行为上变得重要:并非因为差异幅度大,而是因为模型处于决策边界附近。最后,因果定位确定了误差存在的位置:即使修补所有层的残差流也无法恢复,表明误差存在于KV缓存状态本身。
与先前工作的区别。这个机制是确定性的、可复现的、可从架构预测的,并且无法通过后处理纠正。这与先前研究中展示的依赖于采样温度和提示敏感性的随机输出波动【索引4, How is chatgpt’s behavior changing over time?, 2023】、【索引27, Judging llm-as-a-judge with mt-bench and chatbot arena, 2023】,以及最近由【索引25, Understanding and mitigating numerical sources of nondeterminism in LLM inference, 2025】描述的依赖于批量大小和硬件的非确定性有本质区别。本文描述的差异在相同硬件和批量大小下也会出现,甚至在固定种子的贪心解码下也存在,并且可以通过改变算术精度来消除。
差异由架构决定而非输入。差异并非随机现象,而是可以从注意力架构中预测的。约0.036%的平均相对累积误差在不同序列长度和模型中保持平稳,表明每步的误差幅度由注意力头的维度和结构决定,而非输入上下文。
GQA的放大效应。当与GQA结合时,情况变得更糟,因为一个FP16误差会传播到R个查询头,形成相关误差而非独立误差。这导致了LLaMA-2(MHA)和Mistral(GQA)在第一层放大倍数上的差异。Gemma的256维头维度导致了平坦的传播模式和均匀的累积,使其尽管每层漂移最低,但总体KL散度最高。这些都仅依赖于架构,并且在推理前就应可预测。
通过排除法进行定位。因果定位实验表明,差异不存在于残差流中,因为修补残差流无法恢复它。最合理的推断是差异存在于KV缓存状态本身,因为FP16舍入误差在每个解码步骤被写入并携带。这是一个排除法论证:通过直接修补KV缓存张量来证明将是更好的方法,这也是未来最重要的直接扩展工作。当前结果建立了一个必要条件:任何残差流干预都无法恢复这种差异。
核心结论: 本文证明,在标准的FP16推理中,KV缓存与无缓存计算在数值上并不等价。这一普遍存在的假设是错误的。由于FP16浮点数运算的非结合性,cache-ON和cache-OFF两条执行路径具有不同的累加顺序,导致了可复现的、确定性的解码token序列差异。这种差异在所有测试的模型、采样策略和输入中都100%出现。
机制性解释: 通过一系列因果实验,本文揭示了这一差异的完整机制:
1. 根本原因: FP16的非结合性是唯一原因,切换到FP32可完全消除差异。
2. 传播方式: 差异以架构可预测的方式传播,GQA模型在第一层急剧放大,而Gemma则在所有层中广泛分布。
3. 行为影响: 差异的行为影响取决于模型是否接近决策边界,而非差异的绝对大小。
4. 因果位置: 差异存在于KV缓存状态本身,无法通过干预残差流来纠正。
未来展望与启示:
* 直接验证: 最直接的未来工作是直接修补KV缓存张量,将因果定位的排除法论证转变为定量的因果证据。
* 误差追踪: 另一个扩展是直接追踪K和V张量中FP16舍入误差随时间和层次的累积情况。
* 实用解决方案: 探索除昂贵的FP32之外的解决方案,如周期性缓存刷新、平滑KV更新或使用高精度缓存。
* 鲁棒性分析: 通过向KV张量注入噪声来研究模型对数值差异的敏感性阈值。
* 通用性扩展: 量化GQA放大因子,在指令调优模型上进行测试,并研究KV缓存差异与缓存压缩方案的相互作用。
随着GQA和MQA成为主流,本文描述的放大效应将更加普遍,使得缓存推理的正确性成为系统设计者需要更加关注的重要属性。