Fast-weight Product Key Memory

作者: Tianyu Zhao, Llion Jones
机构: Sakana AI

A1 主要贡献

现代语言模型中的序列建模层(或称令牌混合器)是基础组件,可被理解为一种关联记忆。现有架构在存储容量和计算效率之间存在权衡:标准的Softmax注意力机制提供无限的存储容量,但计算成本随序列长度呈二次方增长,变得难以承受;而线性注意力变体虽然效率高(亚二次方复杂度),但其固定的存储容量有限。

本文旨在解决大规模存储与低计算开销之间的矛盾。研究人员认为,一个理想的关联记忆应具备四个关键特性:
1. 键值关联:能够将键与值联系起来。
2. 大容量存储:存储空间巨大,甚至是无限的。
3. 低成本:计算复杂度相对于输入长度是亚二次方的。
4. 检索与记忆:不仅能检索信息,关键是能在任何时候从输入中记忆新的键值对。

乘积键记忆(Product Key Memory, PKM)满足了前三个特性,但它被设计为“慢权重”模块,参数仅在训练期间更新,在推理时保持不变,因此无法快速适应新输入,不满足第四个特性。

本文的主要贡献是提出了快速权重乘积键记忆(Fast-weight Product Key Memory, FwPKM)。通过将静态的PKM改造为动态的“快权重”模块,FwPKM能够在训练和推理时都动态更新其参数。这种设计使其能够作为一个高保真的情景记忆,直接从输入序列中存储“情景”,并将这些记忆跨不同上下文进行传递。

实验表明,FwPKM作为一种有效的情景记忆,能够补充标准模块的语义记忆,在长上下文数据集上显著降低了困惑度。特别是在“大海捞针”(Needle in a Haystack)评估中,尽管FwPKM仅在4K令牌长度的序列上进行训练,但它能成功泛化到128K令牌长度的上下文。

A3 背景知识

2. 乘积键记忆

Top-k键值记忆。标准的键值记忆由一个键矩阵 $K \in \mathbb{R}^{N \times d_k}$ 和一个值矩阵 $V \in \mathbb{R}^{N \times d_v}$ 组成,其中 $N$ 是记忆槽的数量,$d_{\{k,v\}}$ 是隐藏维度。为了在不牺牲计算效率的情况下学习一个大型记忆,一种常见方法是通过Top-k操作【32, Scaling memory-augmented neural networks with sparse reads and writes, 2016, NeurIPS】、【44, Memory networks, 2015, ICLR】来利用稀疏性。对于一个输入查询向量 q,模型为每个记忆槽计算一个得分 $s_i$,即查询与键的内积。然后,Top-k操作 $\mathcal{T}_k$ 选出得分最高的 k 个槽的索引。这些选定的得分通过softmax进行归一化,产生权重 $\{s'_i\}$,最终的检索输出 v̂ 是对应值槽的加权和。

表1 | PKM 和 FwPKM 的比较
表1 | PKM 和 FwPKM 的比较

$$s_{i}=\mathbf{q}^{\top}K_{i}$$

$\mathcal{I}=\mathcal{T}_{k}(\mathbf{s})$

$$\{s_i^\prime\} = \mathrm{softmax}(\{s_j\}_{j \in \mathcal{I}})$$ $$\hat{\mathbf{v}} = \sum_{i \in \mathcal{I}} V_i s_i'$$

乘积键记忆。尽管Top-k操作限制了访问的记忆槽数量,但仍然需要为所有 $N$ 个键行计算得分来找到最佳候选项。这种 $O(N)$ 的复杂度使其无法扩展到巨大的记忆规模(例如 $N = 10^6$)。Lample等人【22, Large memory layers with product keys, 2019, NeurIPS】提出了乘积键记忆(PKM)来解决这个问题。在PKM中,查询向量 q 被分解为两个子查询 q1 和 q2,如下面的公式5所示,其中[·; ·]表示拼接。PKM不使用单个大型键矩阵,而是维护两个较小的子键矩阵 $K^{\{1,2\}}$,每个大小为 $\sqrt{N} \times d_k$。记忆槽排列在一个大小为 $\sqrt{N} \times \sqrt{N}$ 的笛卡尔网格中,索引为(i, j)的槽对应于 $K^1$ 的第 i 个子键和 $K^2$ 的第 j 个子键的交互。该槽的得分定义为其对应子查询得分的总和:$s(i, j) = s_i^1 + s_j^2$。重要的是,在这种形式下,可以在不计算所有 $N$ 个得分的情况下获得该笛卡尔积的Top-k元素。具体做法是,首先独立地为每个子查询找到Top-k个索引,然后在这些选定集合的长度为$k^2$的笛卡尔积中进行搜索。

$$[\mathbf{q}^1 ; \mathbf{q}^2] = \mathbf{q}$$ $$s_i^{\{1,2\}} = \mathbf{q}^{\{1,2\}\top} K_i^{\{1,2\}}$$ $$\mathcal{I}^{\{1,2\}} = \mathcal{T}_k(\mathbf{s}^{\{1,2\}})$$ $$\mathcal{I}=\mathcal{T}_{k}(\{s_{i}^{1}+s_{j}^{2} | i \in \mathcal{I}^{1}, j \in \mathcal{I}^{2}\})$$

PKM检索过程。一旦选定了最终的索引 $\mathcal{I}$,剩余的检索过程(即得分归一化和加权求和)保持不变。

$$\hat{\mathbf{v}}=\operatorname{PKM}(\mathbf{q} ; K, V)=\sum_{i \in \mathcal{I}} V_{i} s_{i}^{\prime}$$

PKM架构的优势。PKM架构提供了一种以低成本构建大规模记忆模块的优雅方式。一个拥有 $10^6$ 个槽的记忆可以使用两个大小为 $10^3$ 的小键矩阵进行索引,仅需要进行 $2 \times 10^3$ 次得分计算。此外,PKM自然地维护了键与值之间的映射关系,使其成为我们想要构建的高容量关联记忆的理想基础。

A2 方法细节

3. 快速权重乘积键记忆

标准的PKM被设计为一个慢权重通道混合器。换言之,其参数通过一个全局目标(下一个词元预测)在整个训练语料库上进行更新,但在推理过程中保持冻结。因此,虽然它可以存储一般的语义知识,但无法适应新的输入或记住即时上下文。

我们提议将PKM转化为快速权重PKM(FwPKM)。我们重新设计该模块,使其在训练和推理时都能通过优化一个局部目标来动态更新其参数。这使得FwPKM能够作为一个高保真的情景记忆,从输入流本身捕获键值关联。

3.1. 快速权重

快权重与慢权重。在标准的神经网络中,知识存储在“慢权重”中——这些参数 $\theta$ 在大规模训练数据集上进行优化,但在训练后被冻结。虽然慢权重能有效地存储数据集中的通用知识,但它们缺乏快速适应新上下文的能力。“快权重”【15, Using fast weights to deblur old memories, 1987】、【36, Learning to control fast-weight memories: An alternative to dynamic recurrent networks, 1992, Neural Comput.】的概念通过引入一组根据每个新输入动态变化的参数来解决这个问题。快权重参数 $\theta$ 可以被看作是情景记忆的存储。

快权重的实现方式。近期的工作如测试时训练(TTT)【39, Learning to (learn at test time): Rnns with expressive hidden states, 2025, ICML】表明,快权重模块可以实现为一个神经模型 $f(\cdot; \theta)$,其参数通过最小化输入序列 $h_1, \dots, h_L$ 的均方误差(MSE)目标来更新。具体来说,对于序列中的一个输入 $h_t$,模型执行一个梯度下降步骤,以最小化从其损坏版本 $\tilde{h_t}$(例如通过低秩投影)重构 $h_t$ 的MSE:

$$\begin{aligned} \begin{aligned} \theta^{\prime} & =\theta-\eta \nabla_\theta \mathcal{L}_{\mathrm{MSE}}(f(\tilde{\mathbf{h}}_t ; \theta), \mathbf{h}_t) \\ & =\theta-\eta \nabla_\theta\|\mathbf{h}_t-f(\tilde{\mathbf{h}}_t ; \theta)\|_2^2, \end{aligned} \end{aligned}$$

其中 $\eta$ 是学习率。通过这个优化过程,快权重 $\theta$ 学会了在训练和推理时对输入序列的信息进行编码。

FwPKM中的快权重。我们遵循TTT的方法,在正向传播过程中通过对一个局部MSE目标执行梯度下降更新来实现快权重。具体来说,我们将键矩阵和值矩阵视为快权重参数 $\theta = \{K, V\}$。这些参数通过块级梯度下降进行更新,以最小化一批查询-值对的重构误差。

查询-值对的生成。对于一个大小为 $C$ 的数据块,我们使用慢权重投影从隐藏状态生成查询和目标值输入 $\{(q_t, v_t)\}_{t=1}^C$。

$$\mathbf{q}_t = \text{Linear}_{\phi}^{q}(\text{RMSNorm}_{\phi}^{q}(\mathbf{h}_t))$$ $$\mathbf{v}_t = \text{Linear}_\phi^v(\text{RMSNorm}_\phi^v(\mathbf{h}_t)).$$

参数更新。模型接着使用当前的快权重 $\theta$ 计算一个预测值 $\hat{v}_t$,并更新权重以最小化预测值与目标值之间的均方误差(MSE)。我们将学习率 $\eta$ 设置为1.0,并将损失乘以一个0.5的常数因子。

$$\begin{aligned} \begin{aligned} \mathbf{\hat{v}}_t &= \text{PKM}(\mathbf{q}_t; \theta) \\ \theta' &= \theta - \sum\nolimits_{t=1}^C \nabla_\theta \tfrac{1}{2} \mathcal{L}_{\text{MSE}}(\mathbf{\hat{v}}_t, \mathbf{v}_t). \end{aligned} \end{aligned}$$

3.2. 记忆优化

3.2.1. MSE损失

优化目标。优化的目标是“重写”记忆,使得用查询 $q_t$ 进行检索时能够得到目标值 $v_t$。我们采用均方误差(MSE)作为目标函数,因为它对于显式记忆重写具有良好的梯度特性。

一步重写。单个样本的损失为 $L_{\text{MSE}}(\hat{v}, v) = \frac{1}{2} \|\mathbf{v} - \hat{\mathbf{v}}\|_2^2$。我们使用0.5的常数因子和1.0的学习率(如公式15所示)并非随意选择,而是为了实现“一步重写”的效果。相对于预测值 $\hat{v}$ 的梯度是 $\nabla_{\hat{\mathbf{v}}} \frac{1}{2} L_{\text{MSE}}(\hat{\mathbf{v}}, \mathbf{v}) = -(\mathbf{v} - \hat{\mathbf{v}})$。因此,一个梯度步骤就直接将预测值更新为目标值。

$$\hat{\mathbf{v}}' = \hat{\mathbf{v}} - 1.0 \cdot (-(\mathbf{v} - \hat{\mathbf{v}})) = \mathbf{v}.$$

这种机制使得FwPKM能够即时记忆新的键值关联。

值矩阵梯度。在实践中,我们优化的是产生预测的值矩阵,相对于第 i 个值矩阵行的梯度由下式给出:

$$\nabla_{V_{i}} \frac{1}{2} \mathcal{L}_{\mathrm{MSE}}(\hat{\mathbf{v}}, \mathbf{v}) = -(\mathbf{v}-\hat{\mathbf{v}}) s_{i}^{\prime} .$$
3.2.2. 损失聚合与梯度整形

块处理中的挑战。上一节展示了单个样本的梯度,但我们是按块处理数据。这导致两个问题:a) 每个预测值 $\hat{v}$ 的每个特征上都会计算一个MSE损失;b) 多个词元可能同时尝试写入同一个值矩阵行。下面我们展示如何将一个块中的MSE损失适当地规约为一个标量损失以进行反向传播,以及如何聚合一个值矩阵行的梯度。

通过求和将MSE损失规约为标量。我们通过对样本和特征维度进行求和来规约一个块中的MSE损失,而不是使用平均。否则,相对于值矩阵元素 $V_{i,j}$ 的梯度将与实际更新信号成比例,比例常数为 $1/(\text{num\_samples} \times d_v)$。

按行贡献加权梯度。我们用“行贡献” $N_i^{\text{read}}$(即值矩阵行 $V_i$ 在当前块中被访问的次数)来缩放行 i 的梯度,具体为乘以 $1/N_i^{\text{read}}$。

$$\nabla_{V_{i}}^{\text{agg}} = \frac{1}{N_{i}^{\text{read}}} \sum\nolimits_{t=1}^{C} \nabla_{V_{i}} \frac{1}{2} \mathcal{L}_{\text{MSE}}(\hat{\mathbf{v}}_{t}, \mathbf{v}_{t}).$$

这种平均策略起到了一种共识机制的作用,确保来自同一块中不同词元的竞争性记忆写入是平衡的,而不是累加的。

按词元重要性加权MSE损失。虽然平均策略缓解了过度的记忆写入,但它平等地对待对同一槽的所有写入。我们通过一个“门控值” $g_t$(在3.4.4节中定义)来加权每个MSE损失,以便在竞争性写入中优先处理重要的更新。这个标量值代表了FwPKM对语言模型影响的强度以及其对应词元的有用性。

无梯度裁剪。与标准训练不同,我们明确避免对快权重更新进行梯度裁剪。由于目标值 $v$ 是无界的,未裁剪的梯度有助于记忆体完全适应目标值的尺度。

3.3. 寻址优化

3.3.1. 边际熵损失

记忆坍塌问题。稀疏记忆存在“记忆坍塌”的问题,即模型只学会利用少数记忆槽。Lample等人【22, Large memory layers with product keys, 2019, NeurIPS】的研究表明,一个512×512=262K个槽的普通PKM,在每个词元有效选择Top-128个槽(4个头×每个头Top-32)的情况下,槽使用率只能达到64.4%。作者提出对查询应用批归一化,将使用率提高到97.9%。然而,我们发现当有效Top-k较小(例如1个头,每个头Top-8)时,查询归一化技术效率不高,而小的有效Top-k对于构建高性能的FwPKM至关重要(见附录B的消融研究)。

边际熵最大化。我们通过优化一个基于边际熵最大化的辅助寻址目标来对抗FwPKM的记忆坍塌。目标是鼓励模型在块的平均水平上均匀地访问所有记忆槽,而不强迫任何单个查询的分布是均匀的。具体来说,对于FwPKM的每个子键集,令 $s'_t \in \mathbb{R}^{\sqrt{N}}$ 表示块中词元 t 在Top-k选择后的归一化查询-键得分,其中未被选择的索引得分为0。我们计算表示块上平均槽使用率的边际分布 $\bar{p}$,并将寻址损失定义为 $\bar{p} \in \mathbb{R}^{\sqrt{N}}$ 的边际熵:

$$ \bar{\mathbf{p}} = \frac{1}{C} \sum_{t=1}^{C} \mathbf{s}'_{t} $$

$$ \mathcal{L}_{\text{addr}} = -H(\bar{\mathbf{p}}) = -\sum_{i=1}^{\sqrt{N}} \bar{p}_{i} \log \bar{p}_{i} . $$

寻址损失的作用。最小化MSE损失优化了FwPKM的值矩阵 $V$ 以存储输入值,而最小化边际熵损失则训练键矩阵 $K$ 以适应输入查询的分布,使得键向量能更有效、更均匀地覆盖查询表示空间。

3.3.2. IDW得分

点积得分的局限性。PKM中的查询-键得分 $s_i$ 是查询与键行 $q^T K_i$ 的点积。然而,键行可以通过改变其范数来为目标查询产生更大的得分,而无需在表示空间中与查询“接近”。

逆距离加权(IDW)得分。逆距离加权(IDW)【25, Inverse distance weighting attention, 2023】得分是点积得分的一种替代方案,它能产生不同的键布局。

$$s_i^{\mathrm{IDW}} = -\log(\epsilon + \|\mathbf{q} - K_i\|_2^2),$$

其中 $\epsilon = 10^{-3}$,遵循原论文。由于使用了欧几里得距离,IDW得分产生的梯度会推动键成为“原型”——即查询簇的质心。我们发现IDW得分比点积得分能带来更好的性能。

3.4. 目标值构建

3.4.1. 值残差

慢权重参数的学习。目标值由一个慢权重网络 $v_t = \text{Linear}_\phi^v(\text{RMSNorm}_\phi^v(h_t))$ 产生。然而,这些慢权重参数仅直接参与快权重MSE损失的计算,其梯度无法到达慢权重。我们从值投影层的输出到FwPKM的输出添加了一个残差连接,这样值投影参数就位于语言模型损失的正向/反向路径上,从而为它们提供学习信号,使其产生对下一个词元预测有用的目标值。

3.4.2. 前瞻值

关联当前键与未来值。与在线性注意力变体中使用短卷积【13, Mamba: Linear-time sequence modeling with selective state spaces, 2024, COLM】、【30, RWKV: reinventing rnns for the transformer era, 2023, Findings of EMNLP】、【46, Gated linear attention transformers with hardware-efficient training, 2024a, ICML】的精神类似,我们在应用块级更新时将查询与前瞻值配对。具体来说,我们稍微修改了公式15中的更新规则,将目标值的时间步下标从 $t$ 改为 $t+1$。这样,FwPKM将每个词元的键与下一个词元的值关联起来,为下一个词元预测提供更有用的信息。

$$\hat{\mathbf{v}}_{t+1}=\operatorname{PKM}\left(\mathbf{q}_{t} ; \theta\right)$$
3.4.3. 目标值归一化

提升训练稳定性。我们还发现在特征维度上对目标值进行z-score归一化是有益的。尽管没有梯度裁剪已经确保了值矩阵 $V$ 能够适应任意尺度的输入,但将目标值约束为均值为0、标准差为1可以提高训练的稳定性。

3.4.4. 门控机制

动态调节记忆使用。下一个词元的预测并不总是依赖于情景记忆,因此我们设计了一个门控机制,让模型可以自由决定从FwPKM输出中提取多少信息。与计算查询和值向量类似,我们将隐藏状态 $h_t$ 输入到一个RMS归一化层和一个线性层来计算一个标量值:

$$g_{t}=\mathrm{Linear}_{\phi}^{g}(\mathrm{RMSNorm}_{\phi}^{g}(\mathbf{h}_{t})),$$

最终输出插值。最终的输出 $o_t$ 是FwPKM输出和值残差之间的插值:

$$\begin{aligned} \begin{aligned} \mathbf{o}_t &= g_t \cdot \mathrm{PKM}(\mathbf{q}_t; \theta) + (1 - g_t) \cdot \mathbf{v}_t \\ &= g_t \cdot \mathbf{\hat{v}}_{t+1} + (1 - g_t) \cdot \mathbf{v}_t. \end{aligned} \end{aligned}$$

3.5. FwPKM 总结

完整流程。我们将所有部分整合在一起,在下文和图1中展示了FwPKM的完整公式。

输入计算。对于块中的每个词元 $t$,慢权重网络计算FwPKM的输入。

$$\mathbf{q}_t, \mathbf{v}_t, g_t = \text{Linear}_{\phi}^{q, v, g}(\text{RMSNorm}_{\phi}^{q, v, g}(\mathbf{h}_t))$$

前向传播。FwPKM的正向传播从时刻 $t$ 的查询预测时刻 $t+1$ 的值,即 $\hat{v}_{t+1} = \text{FWPKM}(q_t; K^{\{1,2\}}, V)$。具体步骤如下(为简洁,除输入 $q_t$ 和输出 $\hat{v}_{t+1}$ 外,省略了时间下标 t):

  1. 查询分解:

    $$[\mathbf{q}^1 ; \mathbf{q}^2] = \mathbf{q}_t$$

  2. 计算子键得分:
    $$s_i^{\{1,2\}} = -\log(\epsilon + \|\mathbf{q}^{\{1,2\}} - K_i^{\{1,2\}}\|_2^2)$$
  3. 选择子键索引:
    $$\mathcal{I}^{\{1,2\}}=\mathcal{T}_{k}\left(\mathbf{s}^{\{1,2\}}\right)$$
  4. 归一化子键得分:
    $${s_{i}^{\{1,2\}'}\} = \text{softmax}(\{s_{j}^{\{1,2\}}\}_{j \in \mathcal{I}})$$
  5. 选择最终索引:
    $$\mathcal{I}=\mathcal{T}_{k}(\{s_{i}^{1}+s_{j}^{2} | i \in \mathcal{I}^{1}, j \in \mathcal{I}^{2}\})$$
  6. 检索值:
    $$\hat{\mathbf{v}}_{t+1} = \sum_{i \in \mathcal{I}} V_i s'_i$$
图1 | FwPKM 架构。图例从上到下依次为:慢权重参数,快权重键参数,快权价值参数,寻址损失,记忆损失。
图1 | FwPKM 架构。图例从上到下依次为:慢权重参数,快权重键参数,快权价值参数,寻址损失,记忆损失。

输出组合。预测的值与值残差使用上述门控值作为权重进行组合。最后,使用一个慢权重的RMS归一化层和线性层来转换输出。

$$\mathbf{o}_{t}=g_{t} \cdot \hat{\mathbf{v}}_{t+1}+\left(1-g_{t}\right) \cdot \mathbf{v}_{t}$$ $$\mathbf{o}_{t}^{\prime}=\text{Linear}_{\phi}^{o}(\text{RMSNorm}_{\phi}^{o}(\mathbf{o}_{t}))$$

参数更新。一旦我们收集了一个块的预测值和目标值,我们就对FwPKM的参数 $\theta = \{V, K^1, K^2\}$ 进行更新。值矩阵使用由门控值加权的MSE损失的整形梯度进行更新。

$$\nabla_{V_{i}}^{\mathrm{agg}}=\frac{1}{N_{i}^{\mathrm{read}}} \sum_{t=1}^{C} \nabla_{V_{i}} \frac{1}{2} g_{t} \mathcal{L}_{\mathrm{MSE}}\left(\hat{\mathbf{v}}_{t}, \mathbf{v}_{t}\right)$$ $$V_i^{\prime}=V_i-\nabla_{V_i}^{\text {agg}}.$$

键矩阵使用基于边际熵的寻址损失的梯度进行更新。

$$\bar{\mathbf{p}}=\frac{1}{C}\sum\nolimits_{t=1}^{C}\mathbf{s}_{t}^{\prime}$$ $$\mathcal{L}_{\mathrm{addr}}=-\sum_{i=1}^{\sqrt{N}} \bar{p}_{i} \log \bar{p}_{i}$$ $$K^{\prime}=K-\nabla_{K} \mathcal{L}_{\mathrm{addr}}, \quad K \in\left\{K^{1}, K^{2}\right\}$$

A4 实验环境

A4 实验结果

困惑度(PPL)评估

图2 | 在 Fineweb-Edu, LC64, 和 LAMBADA 上的困惑度
图2 | 在 Fineweb-Edu, LC64, 和 LAMBADA 上的困惑度
图3 | 在 Fineweb-Edu, LC64, 和 LAMBADA 测试集上,FwPKM的门控值分布。每一行代表一个FwPKM层。
图3 | 在 Fineweb-Edu, LC64, 和 LAMBADA 测试集上,FwPKM的门控值分布。每一行代表一个FwPKM层。

“大海捞针”(NIAH)评估

图4 | 在4K/8K/32K/128K长度测试集上的NIAH准确率结果的堆叠条形图。每个堆叠条显示了{1, 2, 3, 4}-iter NIAH评估的准确率。
图4 | 在4K/8K/32K/128K长度测试集上的NIAH准确率结果的堆叠条形图。每个堆叠条显示了{1, 2, 3, 4}-iter NIAH评估的准确率。

可解释性分析

图5 | 在生成NIAH-4K答案期间,GDN+PKM@6+FwPKM@2,10的FwPKM槽访问示例。模型对草堆进行了3次额外迭代的记忆,即4-iter NIAH。
图5 | 在生成NIAH-4K答案期间,GDN+PKM@6+FwPKM@2,10的FwPKM槽访问示例。模型对草堆进行了3次额外迭代的记忆,即4-iter NIAH。
图6 | GDN+PKM@6+FwPKM@2,10模型在维基百科“Sakana AI”文章词元上的FwPKM门控值。
图6 | GDN+PKM@6+FwPKM@2,10模型在维基百科“Sakana AI”文章词元上的FwPKM门控值。

A7 补充细节

6. 成本分析

模型大小与计算成本。我们在表2中比较了主要模型的参数数量和计算成本。我们报告了FLOPs(浮点运算次数)来衡量所需计算量,以及FLOPS(每秒浮点运算次数),即FLOPs除以运行时间(秒),以额外考虑实现效率。

分析与结论。PKM/FwPKM的稀疏性使其尽管模型尺寸大幅增加,但FLOPs甚至比基线MLP层还要少。然而,FLOPS数据显示PKM/FwPKM组件的运行时间更长。一个主要原因是它们的高效实现存在较大差距。Softmax注意力和线性注意力(如GDN)通过使用FlashAttention【10, FlashAttention: Fast and memory-efficient exact attention with IO-awareness, 2022, NeurIPS】、【8, FlashAttention-2: Faster attention with better parallelism and work partitioning, 2024, ICLR】和FlashLinearAttention【45, FLA: A triton-based library for hardware-efficient implementations of linear attention mechanism, 2024】的内核而速度很快。这项工作的一个重要未来方向是设计更高效的内核,以促进FwPKM模型的更容易扩展和更广泛采用。

表2 | 模型大小与计算成本的比较。
表2 | 模型大小与计算成本的比较。

7. 相关工作

Softmax注意力与线性变体。标准Softmax注意力【42, Attention is all you need, 2017, NeurIPS】是Transformer成功的基石,本质上是一种强大的关联记忆【52, Understanding transformer from the perspective of associative memory, 2025】。然而,其二次复杂度限制了其在极长序列中的应用。为解决此问题,出现了多种高效架构。线性注意力【21, Transformers are rnns: Fast autoregressive transformers with linear attention, 2020, ICML】通过改变关联顺序将复杂度降至线性时间。这一方向包括循环神经网络(RNN)和状态空间模型(SSM)的演进,如Mamba【13, Mamba: Linear-time sequence modeling with selective state spaces, 2024, COLM】和Mamba2【9, Transformers are ssms: Generalized models and efficient algorithms through structured state space duality, 2024, ICML】,以及特定变体如DeltaNet【35, Learning associative inference using fast weight memory, 2021b, ICLR】、【47, Parallelizing linear transformers with the delta rule over sequence length, 2024b, NeurIPS】、Gated DeltaNet【48, Gated delta networks: Improving mamba2 with delta rule, 2025, ICLR】和RWKV7【30, RWKV-7 ”goose” with expressive dynamic state evolution, 2025, COLM】。其他方法如Memory Mosaics【49, Memory mosaics at scale, 2025, NeurIPS】、【50, Memory mosaics, 2025a, ICLR】通过时间维度平滑和分层记忆设计等技术改进Softmax注意力。

快权重与测试时训练。“快权重”概念为统一序列建模提供了一个强有力的视角。该框架源于Schmidhuber【36, Learning to control fast-weight memories: An alternative to dynamic recurrent networks, 1992, Neural Comput.】和Ba等人【1, Using fast weights to attend to the recent past, 2016, NeurIPS】的早期工作,并将线性Transformer视为快权重编程器【34, Linear transformers are secretly fast weight programmers, 2021a, ICML】。最近,测试时训练(TTT)【39, Learning to (learn at test time): Rnns with expressive hidden states, 2025, ICML】、【51, Test-time training done right, 2025b】和Titans【4, Titans: Learning to memorize at test time, 2025c, NeurIPS】复兴了这一范式,它们在推理过程中使用梯度下降显式更新参数,使模型能够记忆当前上下文。理论框架如MIRAS【2, It’s all connected: A journey through test-time memorization, attentional bias, retention, and online optimization, 2025a】和测试时回归【43, Test-time regression: a unifying framework for designing sequence models with associative memory, 2025】将各种序列模型统一在测试时优化和关联记忆的框架下。这些框架为新模型设计指明了几个方向,即记忆架构、记忆规则、记忆保留规则和优化器。我们提出的FwPKM贡献了一种新颖的记忆架构和适应其结构稀疏性的特定记忆规则。

混合架构。认识到不同序列模型的互补优势,近期工作越来越关注结合了二次注意力的高保真检索和线性或循环层效率的混合架构。这包括从头开始训练的混合模型【18, Blending complementary memory systems in hybrid quadratic-linear transformers, 2025, NeurIPS】,以及大规模混合LLM,如Samba【33, Samba: Simple hybrid state space models for efficient unlimited context language modeling, 2025, ICLR】和KimiLinear【40, Kimi linear: An expressive, efficient attention architecture, 2025】。QwenNext也采用了这种交错设计。此外,像人工海马网络(AHN)【11, Artificial hippocampus networks for efficient long-context modeling, 2025】这样的方法探索了微调技术来有效整合这些不同的记忆系统。第4节的实验证明了不同特性记忆之间的相互作用。线性注意力(即GDN)、softmax注意力、慢权重稀疏记忆(即PKM)和快权重稀疏记忆(即FwPKM)的结合实现了一个多功能的记忆系统,在各种任务中表现出色。

记忆模型。除了权重中的隐式知识,显式记忆模块也被探索用于增强存储容量。早期工作包括记忆网络【44, Memory networks, 2015, ICLR】,而最近的研究表明即使是简单的MLP也可以作为记忆【7, Approximating two-layer feedforward networks for efficient transformers, 2023, Findings of EMNLP】。为了在不产生过高成本的情况下扩大容量,稀疏访问机制至关重要。乘积键记忆(PKM)【5, Memory layers at scale, 2025, ICML】、【22, Large memory layers with product keys, 2019, NeurIPS】和PEER【14, Mixture of a million experts, 2024】利用稀疏性来高效访问大规模记忆库。Ultra Sparse Memory【17, Ultrasparse memory network, 2025b, ICLR】、【16, Ultramemv2: Memory networks scaling to 120b parameters with superior long-context learning, 2025a】是一系列扩展PKM架构的工作,具有更具表达力的键和其他改进。我们的工作建立在PKM结构之上,但将其从静态的“慢”记忆转变为动态的“快”记忆。

持续学习与情景学习。更新记忆参数的能力使得持续学习和适应成为可能。Lin等人【24, Continual learning via sparse memory finetuning, 2025】的研究表明,对PKM进行参数高效微调能有效缓解灾难性遗忘。它展示了持续学习的一个维度——通过自监督学习优化稀疏慢权重来更新语义记忆。FwPKM为更新情景记忆开辟了另一个维度,即以在线学习的方式学习快权重。Nested Learning【3, Nested learning: The illusion of deep learning architectures, 2025b, NeurIPS】和TNT【23, Tnt: Improving chunkwise training for test-time memorization, 2025】探索了堆叠多个快权重层(如Titans)并以不同频率应用记忆更新的方向。这种“嵌套”记忆结构使得较快权重中的变化能逐渐被较慢权重消化,并表现出强大的性能。FwPKM维护一个巨大的记忆库,并以较低频率更新以分摊优化成本。设计一个由不同大小的FwPKM和其他记忆组件(如Titans)组成的混合记忆系统,并采用如Nested Learning【3, Nested learning: The illusion of deep learning architectures, 2025b, NeurIPS】中的不同优化策略,是一个很有前景的方向。

A5 结论

本文介绍了快速权重乘积键记忆(FwPKM),一个记忆增强层,它将PKM的大规模稀疏存储与快权重的快速适应性相结合。具体来说,FwPKM将乘积键记忆从一个静态检索模块扩展为一个上下文响应组件,其参数可以在线更新,允许模型将信息写入记忆并在稍后检索。这解决了先前稀疏记忆模块的一个关键限制,即它们无法在推理时有效地整合新证据,并使PKM更适合于相关信息可能被数千个词元分隔开的长上下文设置。

实验上,我们发现FwPKM在远超其训练范围的情况下仍然有效。在4K词元序列上训练的模型可以泛化到128K词元的上下文,同时表现出与强大的情景记忆相一致的行为,这种情景记忆补充了存储在标准层中的语义知识。

同时,仍然存在一些挑战。在线更新引入了额外的计算和选择,如块大小、更新频率和优化超参数。有效地扩展这些更新推动了进一步的系统工作,包括更快的稀疏更新内核和更好的实现策略。在建模方面,未来的工作包括关于其架构、更新规则和保留规则的更强大、更鲁棒的记忆设计。总的来说,通过将稀疏存储与快速、上下文驱动的更新相结合,FwPKM为实现具有多功能且相互补充的记忆组件的语言模型提供了一个有希望的步骤。

A6 附录

A. 详细的训练设置

FwMLP。为了排除与PKM架构无关的实现细节的影响,我们提出了一个基线模型,该模型将FwPKM中的PKM替换为一个SwiGLU-MLP,该MLP维护三个用于上投影、门控和下投影的快权重矩阵(及其偏置)。该基线模型记为FwMLP,通过在块级别上最小化其预测值与目标前瞻值之间的MSE损失来更新其快权重。由于其密集性,我们通过对样本和特征维度进行平均来减少一个块中的MSE损失,并且我们不应用3.2.2节中提到的损失聚合和梯度整形技术。同理,3.3节中的寻址优化也与之无关。

LaCT。我们采用LaCT【51, Test-time training done right, 2025b】的官方实现作为一个强大的TTT【39, Learning to (learn at test time): Rnns with expressive hidden states, 2025, ICML】基线。LaCT架构在每一层都包含一个滑动窗口注意力、一个快权重SwiGLU MLP和一个慢权重SwiGLU MLP。快权重通过带动量的SGD或Muon【19, Muon: An optimizer for hidden layers in neural networks, 2024】进行优化,以最小化点积损失,我们在实验中发现带动量的SGD性能更好。LaCT使用依赖于数据的学习率和L2权重归一化来改善记忆和保留。LaCT(或更广义的TTT)与FwMLP/FwPKM之间一个显著的区别是,LaCT/TTT为小批量中的每个序列维护一组独立的快权重,而FwMLP/FwPKM为所有序列使用一组共享的快权重。我们比较了滑动窗口大小(W)和更新块大小(C)的几种配置。在512W + 512C、512W + 2048C、512W + 4096C和2048W + 2048C中,最佳模型是512W + 2048C。

表3 | 实验中使用的建模和训练超参数。
表3 | 实验中使用的建模和训练超参数。
表3(续) | 实验中使用的建模和训练超参数。
表3(续) | 实验中使用的建模和训练超参数。
表3(续) | 实验中使用的建模和训练超参数。
表3(续) | 实验中使用的建模和训练超参数。

B. 消融研究

实验设置。为了理解第3节中提出的技术的影响,我们基于GDN | FwPKM@2,6,10模型进行了消融实验。以下变体使用相同的流程进行训练和评估:

结果分析。如图7、8和9所示,移除前瞻值对模型性能的损害最大。许多技术带来了轻微的PPL改进,但导致了不太健康的记忆利用率,并随后在不同程度上恶化了NIAH的准确性。

图7 | 消融研究:在Fineweb-Edu, LC64和LAMBADA上的困惑度
图7 | 消融研究:在Fineweb-Edu, LC64和LAMBADA上的困惑度
图8 | 消融研究:在Fineweb-Edu, LC64和LAMBADA测试集上的FwPKM门控值分布。每一行代表一个FwPKM层。
图8 | 消融研究:在Fineweb-Edu, LC64和LAMBADA测试集上的FwPKM门控值分布。每一行代表一个FwPKM层。
图9 | 消融研究:在4K/8K/32K/128K长度测试集上的NIAH准确率结果的堆叠条形图。每个堆叠条显示了{1, 2, 3, 4}-iter NIAH评估的准确率。
图9 | 消融研究:在4K/8K/32K/128K长度测试集上的NIAH准确率结果的堆叠条形图。每个堆叠条显示了{1, 2, 3, 4}-iter NIAH评估的准确率。

C. 更多可视化示例

引言部分的门控值可视化。图10显示了本文引言部分词元级别的FwPKM门控值。

图10 | GDN+PKM@6+FwPKM@2,10模型在本文引言部分词元上的FwPKM门控值。
图10 | GDN+PKM@6+FwPKM@2,10模型在本文引言部分词元上的FwPKM门控值。