You Only Cache Once: Decoder-Decoder Architectures for Language Models

  • 作者: Yutao Sun, Li Dong, Yi Zhu, Shaohan Huang, Wenhui Wang, Shuming Ma, Quanlu Zhang, Jianyong Wang, Furu Wei
  • 机构: Microsoft Research, Tsinghua University

A1 主要贡献

本文针对大型语言模型(LLM)中随着输入序列增长而导致的KV缓存(KV Cache)占用大量GPU内存和预填充(prefilling)延迟过高的问题,提出了一种名为YOCO(You Only Cache Once)的新型解码器-解码器(decoder-decoder)架构。

核心问题与研究目标
标准的仅解码器(decoder-only)Transformer架构在自回归生成时,需要缓存所有先前计算过的键/值(KV)向量,以避免重复计算。当处理长序列时,这部分KV缓存会变得非常庞大,成为推理过程中的内存瓶颈,并导致预填充阶段耗时极长,这限制了长上下文语言模型的实际部署。例如,一个65B模型处理512K个token,KV缓存就需要约86GB的GPU内存。

创新点与主要贡献
1. 提出YOCO架构:这是一种解码器-解码器架构,由一个自解码器(self-decoder)和一个跨解码器(cross-decoder)堆叠而成。自解码器使用高效的自注意力机制(如滑动窗口注意力或门控保持机制)来编码输入并生成一个全局的、共享的KV缓存。随后,跨解码器通过交叉注意力(cross-attention)机制来复用这个KV缓存进行后续的计算。
2. “只缓存一次”实现内存优化:由于只有自解码器的输出被用于生成全局KV缓存,且这个缓存被所有跨解码器层共享,因此整个模型的KV缓存大小与模型层数(L)无关,其内存复杂度从Transformer的O(N * L * D)降低到约O(N * D)。这极大地减少了GPU内存消耗,使得在有限的硬件上部署长上下文模型成为可能。
3. 预填充提前退出(Early Exit)机制:YOCO的计算流程允许在预填充阶段,计算完自解码器后即可提前退出,无需经过跨解码器。这一特性显著加快了预填充速度,例如,对于512K上下文,预填充延迟从180秒缩短到不足6秒,极大改善了用户体验。
4. 可扩展性与高性能:实验证明YOCO在模型大小、训练token数量和上下文长度方面都具有良好的可扩展性。
* 训练规模:将3B规模的YOCO模型扩展到万亿级训练token,其性能与StableLM等主流Transformer模型相当。
* 模型规模:从160M到13B的模型尺寸扩展曲线表明,YOCO与Transformer相比具有竞争力。
* 上下文长度:成功将YOCO扩展到100万(1M)上下文长度,并在“大海捞针”测试中实现了近乎完美的检索准确率。

  1. 门控保持机制(Gated Retention):为自解码器提出了一种名为门控保持(gated retention)的高效自注意力模块,它通过数据控制的门控机制增强了retention网络,实现了训练并行、高性能和低推理成本的统一。

总而言之,YOCO架构通过其独特的“只缓存一次”设计,在保持强大全局注意力能力的同时,显著降低了长上下文语言模型的推理内存占用和预填充延迟,提升了吞吐量,为未来原生支持长序列的大型语言模型提供了一个强有力的候选架构。

图1:我们为大型语言模型提出了一种解码器-解码器架构YOCO,它只缓存键/值一次。YOCO显著减少了KV缓存内存和预填充时间,同时在训练token数量、模型大小和上下文长度方面具有可扩展性。推理成本报告为512K上下文长度,图7-10展示了不同长度的更多结果。
图1:我们为大型语言模型提出了一种解码器-解码器架构YOCO,它只缓存键/值一次。YOCO显著减少了KV缓存内存和预填充时间,同时在训练token数量、模型大小和上下文长度方面具有可扩展性。推理成本报告为512K上下文长度,图7-10展示了不同长度的更多结果。

A2 方法细节

2. You Only Cache Once (YOCO)

YOCO架构概述。本文提出的YOCO架构专为自回归建模(如大型语言模型LLM)设计。如图2所示,该解码器-解码器架构包含自解码器(self-decoder)和跨解码器(cross-decoder)两个部分。具体来说,YOCO由L个块堆叠而成,其中前L/2层为自解码器,其余为跨解码器。对于输入序列x = x1 · · · x|x|,其输入嵌入被打包成$X^0 = [x_1, · · · , x_{|x|}] \in \mathbb{R}^{|x|\times d_{\text{model}}}$,其中$d_{\text{model}}$是隐藏维度。模型首先通过自解码器计算上下文向量表示$X^l = \text{Self-Decoder}(X^{l-1})$, 其中$l \in [1, L/2]$,并利用$X^{L/2}$生成供跨解码器使用的KV缓存$\hat{K}, \hat{V}$。然后,模型计算$X^l = \text{Cross-Decoder}(X^{l-1}, \hat{K}, \hat{V})$, 其中$l \in [ L/2 + 1, L]$,以获得最终的输出向量$X^L$。

图2:解码器-解码器架构概览。自解码器生成全局KV缓存。然后跨解码器使用交叉注意力来复用共享的KV缓存。自解码器和跨解码器都使用因果掩码。整体架构的行为类似于一个仅解码器的Transformer,以自回归方式生成token。
图2:解码器-解码器架构概览。自解码器生成全局KV缓存。然后跨解码器使用交叉注意力来复用共享的KV缓存。自解码器和跨解码器都使用因果掩码。整体架构的行为类似于一个仅解码器的Transformer,以自回归方式生成token。

模块布局与差异。自解码器和跨解码器都遵循与Transformer【索引52,Attention is all you need,2017,NIPS】相似的块布局,即交错的注意力和前馈网络。本文还引入了pre-RMSNorm【索引59,Root mean square layer normalization,2019,NeurIPS】、SwiGLU【索引46,Glu variants improve transformer,2020,arXiv】和分组查询注意力(grouped-query attention)【索引2,Training generalized multi-query transformer models from multi-head checkpoints,2023,arXiv】作为改进。这两个部分的主要区别在于注意力模块:自解码器(2.1节)使用高效的自注意力机制(如滑动窗口注意力),而跨解码器(2.2节)则使用全局交叉注意力来关注由自解码器输出产生的共享KV缓存。

2.1 Self-Decoder

自解码器的计算过程。自解码器接收token嵌入$X^0$作为输入,并计算出中间向量表示$M = X^{L/2}$。其计算过程如下列公式所示:

$$\begin{aligned} \begin{aligned} Y^{l} & =\operatorname{ESA}\left(\operatorname{LN}\left(X^{l}\right)\right)+X^{l} \\ X^{l+1} & =\operatorname{SwiGLU}\left(\operatorname{LN}\left(Y^{l}\right)\right)+Y^{l} \end{aligned} \end{aligned}$$

其中,$ESA(\cdot)$代表高效自注意力(efficient self-attention),$SwiGLU(X) = (\text{swish}(XW_G) \odot XW_1)W_2$,而$LN(\cdot)$使用RMSNorm【索引59,Root mean square layer normalization,2019,NeurIPS】。高效自注意力机制中也使用了因果掩码(Causal masking)。

高效自注意力的关键特性。高效自注意力模块的关键特性是其推理内存为$O(1)$,即KV缓存的数量是恒定的。例如,滑动窗口注意力【索引5,Generating long sequences with sparse Transformers,2019,OpenAI Blog】的缓存大小取决于窗口大小,而不是输入长度。关于高效自注意力模块的更多设计选择(例如,门控保持机制)将在第3节中详细介绍。

2.2 Cross-Decoder

全局KV缓存的生成。首先,自解码器的输出$X^{L/2}$被用来为跨解码器生成全局KV缓存$\hat{K}$和$\hat{V}$:

$$\hat{K}=\operatorname{LN}(X^{L/2})W_K, \quad \hat{V}=\operatorname{LN}(X^{L/2})W_V$$

其中$W_K, W_V \in \mathbb{R}^{d\times d}$是可学习的权重矩阵。然后,跨解码器层堆叠在自解码器之后,以获得最终的输出向量$X^L$。这些KV缓存$\hat{K}$和$\hat{V}$被所有$L/2$个跨解码器层复用。

跨解码器的计算过程。跨解码器的计算过程如下列公式所示:

$$\begin{aligned} \begin{aligned} \hat{Q}^l &= \text{LN}(X^l)W_Q^l \\ Y^l &= \text{Attention}(\hat{Q}^l, \hat{K}, \hat{V}) + X^l \\ X^{l+1} &= \text{SwiGLU}(\text{LN}(Y^l)) + Y^l \end{aligned} \end{aligned}$$

其中Attention(·)是标准的多头注意力【索引52,Attention is all you need,2017,NIPS】,而$W_Q^l \in \mathbb{R}^{d\times d}$是一个可学习的矩阵。交叉注意力中也使用了因果掩码。由于交叉注意力与分组查询注意力【索引2,Training generalized multi-query transformer models from multi-head checkpoints,2023,arXiv】兼容,我们可以进一步节省KV缓存的内存消耗。在获得$X^L$后,一个softmax分类器执行下一个token的预测。

2.3 Inference Advantages

推理优势概述。除了具有竞争力的语言建模结果外,YOCO还显著降低了服务成本并提升了推理性能。详细的推理比较将在第4.4节中报告。

节省GPU内存并服务更多Token。表1比较了Transformer和YOCO的内存复杂度。具体来说,由于全局KV缓存被复用,并且高效自注意力只需要恒定的缓存,YOCO的缓存数量是$O(N + CL)$,其中N是输入长度,C是一个常数(例如,滑动窗口大小),L是层数。对于长序列,CL远小于N,因此大约需要$O(N)$的缓存,即“只缓存一次”。相比之下,Transformer解码器在推理过程中必须存储$N \times L$个键和值。因此,与Transformer解码器相比,YOCO为缓存大约节省了L倍的GPU内存。由于推理能力的瓶颈变成了KV缓存(图7b),我们的方法使我们能够在不耗尽GPU内存的情况下服务更多的token。增加的批处理大小也有利于提高推理吞吐量。

表1:KV缓存的推理内存复杂度。N,L,D分别代表序列长度、层数和隐藏维度。
表1:KV缓存的推理内存复杂度。N,L,D分别代表序列长度、层数和隐藏维度。

减少预填充时间并提高吞吐量。如图3所示,由于跨解码器复用自解码器的输出,我们可以在预填充阶段进入跨解码器之前提前退出。这种计算依赖的有趣特性极大地加快了预填充速度。首先,只需要一半的层进行前向计算,这意味着至少可以减少一半的预填充延迟。其次,自解码器的高效注意力模块通常速度很快。以512K上下文长度为例,我们可以将预填充延迟从180秒(使用Flash-Decoding和内核融合等优化推理的Transformer)减少到不到6秒(图9)。即使对于32K长度,YOCO在预填充时间上也有大约三倍的加速。表2比较了Transformer和YOCO注意力模块的预填充时间复杂度。

图3:YOCO推理过程。预填充:并行编码输入token。生成:逐个解码输出token。该计算流程使得预填充可以提前退出而不改变最终输出,从而显著加快预填充阶段。
图3:YOCO推理过程。预填充:并行编码输入token。生成:逐个解码输出token。该计算流程使得预填充可以提前退出而不改变最终输出,从而显著加快预填充阶段。
表2:注意力模块的预填充时间复杂度。N,L,D与上文相同。
表2:注意力模块的预填充时间复杂度。N,L,D与上文相同。

3. Design Choices of Self-Decoder

自解码器的设计选择。我们可以为自解码器选择各种高效的自注意力方法。只要该模块仅需要恒定的推理内存,自解码器的缓存内存复杂度就取决于层数。此外,一个好的模块选择可以改善训练和部署成本。在这项工作中,我们使用门控保持机制(gated retention)(第3.1节)或滑动窗口注意力(sliding-window attention)(第3.2节)。

3.1 Gated Retention

门控保持机制(gRet)概述。门控保持(gRet,也称为gRetNet或RetNet-3)通过一种数据依赖的门控机制增强了保持网络(retention)【索引45,Retentive network: A successor to transformer for large language models,2023,arXiv】,它同时实现了序列建模的训练并行性、良好性能和低推理成本。在实验中,我们使用gRet作为默认的高效自注意力模块。该方法统一了并行、循环和分块循环三种计算范式。这三种表示是等价的,可以获得相同的计算结果。训练过程通常使用并行或分块循环范式,而推理阶段可以采用循环范式以实现恒定的KV内存。下面我们描述这三种表示。

并行表示。门控保持机制的定义如下:

$$Q = (XW_Q) \odot \Theta, \quad K = (XW_K) \odot \bar{\Theta}, \quad V = XW_V, \quad \Theta_n = e^{in\theta}$$

$$\begin{aligned} \gamma = \text{sigmoid}(XW_{\gamma})^{1/\tau}, \quad D_{nm} = \begin{cases} \displaystyle\prod_{i=m+1}^{n} \gamma_i, & n \ge m \\ 0, & n < m \end{cases} \end{aligned}$$
$$\text{gRet}(X) = (QK^{\intercal} \odot D)V$$

其中$W_Q, W_K, W_V \in \mathbb{R}^{d\times d}$和$W_\gamma \in \mathbb{R}^{d\times 1}$是可学习的权重,温度项$\tau$鼓励$\gamma$趋近于1以获得更好的记忆能力【索引56,Gated linear attention transformers with hardware-efficient training,2023,arXiv】。数据控制的衰减是逐头的(head-wise)【索引20,Gateloop: Fully data-controlled linear recurrence for sequence modeling,2023,arXiv】,而不是逐元素的(element-wise),这样计算可以充分利用NVIDIA的张量核心(tensor cores)。关于其他设计的更多细节,请参考【索引45,Retentive network: A successor to transformer for large language models,2023,arXiv】。

循环表示。与公式(4)等价,门控保持的输出可以通过循环方式计算。对于第n个时间步,输出通过以下方式获得:

$$\begin{aligned} \begin{aligned} S_{n} &= \gamma_{n} S_{n-1} + K_{n}^{\intercal} V_{n} \\ \operatorname{gRet}(X_{n}) &= Q_{n} S_{n}, \quad n=1, \cdots, |x| \end{aligned} \end{aligned}$$

其中$Q, K, V, \gamma$与公式(4)中相同。在自回归推理期间,自解码器维护$S_n$作为中间状态,以实现高效生成。

分块循环表示。分块表示是循环和并行表示的统一形式。给定块大小B,输出逐块计算。计算分为块内和跨块两部分。用[i]表示第i个块,即$x_{[i]} = x_{(i-1)B+1}, \cdots, x_{iB}$,我们按如下方式计算第i个块:

$$\beta_{(i-1)B+j} = \prod_{k=(i-1)B+1}^{(i-1)B+j} \gamma_k, \quad D_{[i]}(j,k) = \frac{\beta_{(i-1)B+k}}{\beta_{(i-1)B+j}} \text{ if } j \leq k \text{ else } 0$$

$$R_i = K_{[i]}^\intercal (V_{[i]} \odot \frac{\beta_{iB}}{\beta_{[i]}}) + \beta_{iB} R_{i-1}, \quad \beta_{[i]}(j,k) = \beta_{(i-1)B+j}$$
$$\text{gRet}(X) = \underbrace{(Q_{[i]} K_{[i]}^\intercal \odot D_{[i]}) V_{[i]}}_{\text{Inner-Chunk}} + \underbrace{(Q_{[i]} R_{i-1}) \odot \beta_{[i]}}_{\text{Cross-Chunk}}$$

其中$R_i$是第i个块的中间状态,$\beta$总结了数据控制的衰减$\gamma$。附录B中的证明显示了这些计算范式之间的等价性。分块范式结合了并行和循环的最佳特性,即与完全并行计算相比节省了FLOPs,并与循环计算相比减少了迭代次数。在训练和预填充阶段,分块表示提高了吞吐量并减少了GPU内存消耗。

多头门控保持。与多头注意力【索引52,Attention is all you need,2017,NIPS】和多尺度保持【索引45,Retentive network: A successor to transformer for large language models,2023,arXiv】类似,我们将门控保持应用于每个头,并将输出组合在一起:

$$\begin{aligned} \begin{aligned} \text{head}_i &= \text{gRet}(X) \\ Y &= \text{GroupNorm}_h(\text{Concat}(\text{head}_1, \cdots, \text{head}_n)) \\ \text{MHGR}(X) &= (\text{swish}(XW_G) \odot Y)W_O \end{aligned} \end{aligned}$$

其中$W_G, W_O \in \mathbb{R}^{d\times d}$是可学习的矩阵,GroupNorm【索引54,Group normalization,2018,ECCV】对每个头进行归一化【索引55,Magneto: A foundation Transformer,2023,ICML】。我们还应用swish门来增加非线性【索引45,Retentive network: A successor to transformer for large language models,2023,arXiv】。

3.2 Sliding-Window Attention

滑动窗口注意力。滑动窗口注意力【索引5,Generating long sequences with sparse Transformers,2019,OpenAI Blog】将注意力范围限制在一个固定的窗口大小C内。相比之下,普通的Transformer解码器会关注所有之前的token。在推理过程中,KV缓存的内存复杂度可以从$O(N)$降低到$O(C)$,即内存使用量是恒定的,而不是随序列长度增加而增加。与多头自注意力【索引52,Attention is all you need,2017,NIPS】类似,我们通过以下方式计算滑动窗口注意力的输出:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

$$\text{head}_i = \text{softmax}(Q_{[i]} K_{[i]}^\intercal + B)V$$
$$\begin{aligned} B_{ij} = \begin{cases} 0, & i - C < j \le i \\ -\infty, & \text{otherwise} \end{cases} \end{aligned}$$
$$Y = \text{Concat}(\text{head}_1, \dots, \text{head}_h)$$
$$\text{SWA}(X) = YW_O$$

其中$W_Q, W_K, W_V, W_O \in \mathbb{R}^{d\times d}$是可学习的矩阵,窗口因果掩码$B$控制每个查询只关注距离小于C的先前键。该模块也应用了预归一化(pre-normalization)和残差连接(residual connection)。

A4 实验环境

  • 模型架构

    • YOCO-3B:用于语言建模评估,隐藏层大小3072,26层,24个查询头,8个键/值头(使用GQA),非嵌入参数量2.8B。自解码器使用门控保持(gated retention)。
    • 扩展模型:为验证可扩展性,训练了160M, 400M, 830M, 1.4B, 2.7B, 6.8B, 13B等多种尺寸的模型。对比模型包括Llama架构的Transformer、带滑动窗口注意力的YOCO(YOCOSWA,窗口大小1024)和带门控保持的YOCO(YOCOgRet)。
    • 长上下文模型:将YOCO-3B模型上下文扩展至1M。
  • 数据集与任务

    • 预训练:使用与StableLM-3B-4E1T【索引49,StableLM 3B 4E1T】类似的精选语料库,总计训练了1.6T token。
    • 下游任务评估:使用LM Eval Harness【索引16,A framework for few-shot language model evaluation,2023】框架在ARC-C, ARC-E, BoolQ, Hellaswag, OBQA, PIQA, Winogrande, SciQ等任务上进行零样本性能评估。
    • 长上下文评估
      • Needle-In-A-Haystack:单针和多针(Multi-Needle)检索测试,评估在长达1M的文本中检索信息的能力。
      • 语言建模困惑度:在超过1M token的书籍和代码仓库级别的数据上评估长序列的负对数似然(NLL)。
      • ZeroSCROLLS:在16K长度的上下文中评估模型在长文本理解基准测试上的性能。
  • 硬件配置

    • GPU: H100-80GB GPU卡。
  • 软件配置

    • 分词器: tiktoken-cl100k_base。
    • 优化器: AdamW【索引22,Decoupled weight decay regularization,2019,ICLR】,不同实验设置了不同的beta值和学习率。
    • 核心库与技术
      • 使用了分组查询注意力(GQA)【索引2,Training generalized multi-query transformer models from multi-head checkpoints,2023,arXiv】。
      • 推理优化采用了Flash-Decoding【索引10,Flash-Decoding for long-context inference,2023,Stanford CRFM Blog】和内核融合。
      • 门控保持机制的实现基于Triton【索引50,Triton: an intermediate language and compiler for tiled neural network computations,2019,MLPL】内核。
      • 长序列训练利用了CUBE,一个内部版本的SuperScaler【索引24,SuperScaler: Supporting flexible DNN parallelization via a unified abstraction,2023】。

A4 实验结果

4.1 语言建模评估

  • 实验内容:训练了一个3B参数的YOCO模型,训练token数量达到1.6万亿,并与OpenLLaMA-v2-3B、StableLM-base-alpha-3B-v2和StableLM-3B-4E1T等强大的Transformer模型在多个下游任务上进行零样本性能比较。
  • 实验结果:如表3所示,YOCO-3B在1T和1.6T token的训练检查点上,均取得了与这些经过良好调优的Transformer模型相当甚至更好的性能。例如,在1.6T token训练后,YOCO-3B的平均分(Avg)为0.636,优于StableLM-3B-4E1T。将上下文扩展到1M后,性能进一步提升至0.645。
  • 分析结论:结果表明YOCO架构在标准语言模型基准上具有很强的竞争力,并且能够随着训练token数量的增加而有效扩展其性能。
表3:与之前训练良好的Transformer语言模型 [TBMR, Tow, GL23] 在Eval Harness [GTA+23] 上的结果比较。我们将3B模型扩展到1.6万亿训练token。StableLM-3B-4E1T的1T和1.6T结果来自其技术报告 [TBMR]。YOCO-3B-1M被扩展到1M token的上下文长度。
表3:与之前训练良好的Transformer语言模型 [TBMR, Tow, GL23] 在Eval Harness [GTA+23] 上的结果比较。我们将3B模型扩展到1.6万亿训练token。StableLM-3B-4E1T的1T和1.6T结果来自其技术报告 [TBMR]。YOCO-3B-1M被扩展到1M token的上下文长度。

4.2 与Transformer的可扩展性比较

  • 实验内容:比较了Llama Transformer、带门控保持的YOCO (YOCOgRet) 和带滑动窗口注意力的YOCO (YOCOSWA) 在160M到13B参数范围内的模型大小扩展性。使用验证集损失作为评估指标。
  • 实验结果:如图4所示,YOCO架构(包括YOCOgRet和YOCOSWA)的验证损失随着模型参数量的增加而稳步下降,其扩展曲线与Llama优化的Transformer架构相当。其中,YOCOgRet的表现优于Transformer和YOCOSWA。
  • 分析结论:YOCO架构在模型大小方面表现出良好的可扩展性。YOCOgRet的优势可能来自于其混合了注意力和保持机制的架构,这两种机制的归纳偏置是互补的。
图4:语言模型损失随着模型大小(从160M到13B)的扩展而降低。
图4:语言模型损失随着模型大小(从160M到13B)的扩展而降低。

4.3 长上下文评估

  • 实验内容:将YOCO-3B模型的上下文长度逐步扩展到1M,并在“大海捞针”检索任务、多针检索任务以及长序列语言建模困惑度上进行评估。
  • 实验结果

    • 大海捞针 (Needle In A Haystack):如图5所示,YOCO-3B-1M在1M上下文长度的测试中取得了近乎完美的检索准确率,表明其强大的长上下文建模能力。
    • 多针检索:如表4所示,在128K上下文长度的多针检索测试中,3B的YOCO-3B-1M性能与参数量两倍多的7B LWM-1M-text相当,并优于YaRN-Mistral-128K、MiniCPM-128K和ChatGLM3-128K等其他长上下文模型。
    • 长序列困惑度:如图6所示,在书籍和代码数据上,模型的负对数似然(NLL)随着上下文长度的增加而持续下降,证明YOCO能有效利用长距离依赖信息。
  • 分析结论:YOCO架构能够有效扩展到超长上下文(1M tokens),并在需要长距离信息检索和理解的任务上表现出色。

图5:在1M长度下的“大海捞针”测试结果。
图5:在1M长度下的“大海捞针”测试结果。
表4:多针检索准确率。N表示针的数量。N=1是作为参考的单针检索,N>1表示多针测试。评估在128K长度下进行,因为大多数之前的长上下文模型都是用这个长度进行微调的。
表4:多针检索准确率。N表示针的数量。N=1是作为参考的单针检索,N>1表示多针测试。评估在128K长度下进行,因为大多数之前的长上下文模型都是用这个长度进行微调的。
图6:在书籍和仓库级别代码上的累积平均负对数似然。我们过滤了长度超过1M token的验证样本。YOCO在更长的上下文中取得了改进的性能,即利用长距离信息进行语言建模。
图6:在书籍和仓库级别代码上的累积平均负对数似然。我们过滤了长度超过1M token的验证样本。YOCO在更长的上下文中取得了改进的性能,即利用长距离信息进行语言建模。

4.4 推理优势

  • 实验内容:对YOCOgRet和优化的Transformer(使用GQA、Flash-Decoding等)在GPU内存占用、预填充延迟和吞吐量方面进行性能剖析。
  • 实验结果

    • GPU内存:如图7a和7b所示,YOCO显著降低了推理内存消耗,特别是KV缓存部分。在1M上下文长度下,YOCO的总内存占用仅为12.4GB,是Transformer的1/9.4。图8显示,对于65B模型,YOCO的KV缓存内存可节省约80倍。
    • 预填充延迟:如图9所示,YOCO的预填充时间呈线性增长,而Transformer呈二次方增长。在512K上下文下,YOCO将延迟从180秒降至不到6秒(加速30.3倍);在1M上下文下,加速比达到71.8倍;即使在32K的短上下文,也有2.87倍的加速。
    • 吞吐量:如图10所示,由于预填充时间的大幅缩短和内存节省带来的更大批处理量,YOCO在所有上下文长度上都实现了更高的吞吐量。在512K上下文时,YOCO的吞吐量是Transformer的9.6倍(43.1 token/s vs 4.5 token/s)。
  • 分析结论:YOCO在推理效率上具有数量级的优势,极大地降低了长上下文模型的部署成本,并显著改善了用户体验。

图7:推理过程中的GPU内存消耗。 (a) Transformer和YOCO在不同长度下的推理内存。(b) 1M上下文长度下的内存消耗分解。
图7:推理过程中的GPU内存消耗。 (a) Transformer和YOCO在不同长度下的推理内存。(b) 1M上下文长度下的内存消耗分解。
图8:不同模型大小下每个token的KV缓存GPU内存消耗。YOCO在模型越大时节省得越多。
图8:不同模型大小下每个token的KV缓存GPU内存消耗。YOCO在模型越大时节省得越多。
图9:不同长度的预填充延迟,即在生成第一个token之前编码给定输入提示的时间。Transformer的时间呈二次增长,而YOCO呈线性增长。即使对于如32K的短输入长度,YOCO仍可加速2.87倍。
图9:不同长度的预填充延迟,即在生成第一个token之前编码给定输入提示的时间。Transformer的时间呈二次增长,而YOCO呈线性增长。即使对于如32K的短输入长度,YOCO仍可加速2.87倍。
图10:Transformer和YOCO在不同上下文长度下的推理吞吐量。
图10:Transformer和YOCO在不同上下文长度下的推理吞吐量。

A5 结论

本文提出了一种用于大型语言建模的解码器-解码器架构(YOCO)。与Transformer相比,YOCO在实现显著更优的推理效率的同时,保持了具有竞争力的性能。实验结果表明,YOCO在各种设置下都取得了良好的效果,包括扩展训练token数量、扩展模型大小以及将上下文长度扩展到1M token。性能剖析结果也显示,YOCO将推理效率提升了几个数量级,尤其是在长序列建模方面。

未来工作展望
1. YOCO + BitNet + Groq的结合:Groq通过将所有计算置于SRAM中实现极高吞吐量,但受限于内存容量。YOCO可以减少KV缓存内存,而BitNet可以减少模型权重内存。三者结合有望将LLM的部署成本降低几个数量级。
2. YOCO应用于多模态大模型:YOCO的布局天然支持使用多个自解码器,其交叉注意力层非常适合多模态融合。自解码器的因果依赖性也完美契合流式视频处理,可用于构建异步多模态大模型,避免不同数据流相互阻塞,这对于机器人等实时应用至关重要。
3. 优化KV缓存模块机制:YOCO架构明确了KV缓存模块,为开发原生内存机制开辟了新机会。
* 可以集成缓存压缩机制以获得更紧凑的内存。
* 可以构建索引以实现高效的键值检索,由于YOCO复用缓存,只需维护一个索引。
* 解耦的建模方式支持预缓存上下文,这对于原生RAG和LLM原生搜索引擎具有潜在价值。

A6 附录

A. 用于YOCO长序列训练的块并行(Chunk Parallelism)

块并行策略。为了减少通信频率并加速长序列训练,我们为YOCO引入了块并行。当训练序列极长时,将长序列划分到不同设备上至关重要【索引25,Sequence parallelism: Making 4d parallelism possible,2021,arXiv】、【索引11,Longnet: Scaling transformers to 1,000,000,000 tokens,2023,arXiv】。然而,总吞吐量往往受限于GPU通信【索引27,Ring attention with blockwise transformers for near-infinite context,2023,arXiv】。YOCO的跨解码器设计解耦了自注意力依赖,同时保留了建模能力,为分布式长序列训练带来了独特的优势。在自解码器中,依赖关系仅存在于相邻设备之间,因此通信量相对较小。在跨解码器中,all-gather操作仅对KV缓存触发一次,而不是在每一层都进行通信。这种硬件友好的架构为分布式长序列训练提供了更大的灵活性。

图11:在两个GPU设备上进行YOCO训练的块并行。训练策略是将序列划分为不同的块。M表示中间表示XL/2,即自解码器的输出。跨解码器中的键和值只收集一次。
图11:在两个GPU设备上进行YOCO训练的块并行。训练策略是将序列划分为不同的块。M表示中间表示XL/2,即自解码器的输出。跨解码器中的键和值只收集一次。

B. 门控保持机制的分块表示

循环与分块表示的等价性证明。本节阐述了门控保持机制的循环表示和分块循环表示之间的等价性。对于输出$O_n$,n可以被拆分为$n = kB + r$,其中B是块大小。通过一系列推导,证明了分块计算(包含块内计算和跨块计算)的结果与原始的循环计算或并行计算是完全等价的。

$$O_n = \sum_{m=1}^{n} \prod_{i=m+1}^{n} \gamma_i Q_n K_m^\intercal V_m = \sum_{m=kB+1}^{n} \prod_{i=m+1}^{n} \gamma_i Q_n K_m^\intercal V_m + \sum_{m=1}^{kB} \prod_{i=m+1}^{n} \gamma_i Q_n K_m^\intercal V_m$$ $$\sum_{m=kB+1}^{n} \prod_{i=m+1}^{n} \gamma_i Q_n K_m^\intercal V_m = (Q_n K_{kB+1:n}^\intercal \odot \Gamma_{kB+1:n}) V_{kB+1:n}$$ $$\begin{aligned} \begin{aligned} \sum_{m=1}^{kB} \prod_{i=m+1}^{n} \gamma_i Q_n K_m^\intercal V_m &= (Q_n \prod_{i=kB+1}^{n} \gamma_i) \sum_{c=0}^{k-1} \sum_{m=1}^{B} (K_{m+cB}^\intercal V_{m+cB} \prod_{i=m+cB+1}^{(c+1)B} \gamma_i) \prod_{i=(c+1)B+1}^{kB} \gamma_i \\ &= (Q_n \prod_{i=kB+1}^{n-1} \gamma_i) \sum_{c=1}^{k} (K_{[c]}^\intercal (V_{[c]} \odot \zeta_{[c]})) \prod_{i=c+1}^{k} \alpha_i \\ &= (Q_n \prod_{i=kB+1}^{n-1} \gamma_i) R_{i-1} \end{aligned} \end{aligned}$$ $$R_{i}=K_{[i]}^{\intercal}(V_{[i]} \odot \zeta_{[i]})+\alpha_{i} R_{i-1}$$ $$O_{[n]}=\sum_{m=k B+1}^{[n]} \beta_{[n]} Q_{[n]} K_m^{\top} V_m+\sum_{m=1}^{k B} \beta_{[n]} Q_{[n]} \prod_{i=m+1}^n \gamma_i K_m^{\top} V_m$$

$$\sum_{m=k B+1}^{[n]} \beta_{[n]} Q_{[n]} K_m^{\top} V_m=(Q_{[n]} K_{[n]}^{\top} \odot D_{[n]}) V_{[n]}, \quad D_{[n]}(j, k)=\frac{\beta_{(n-1) B+k}}{\beta_{(n-1) B+j}} \text { if } j \leq k \text { else } 0$$
$$\sum_{m=1}^{k B} \beta_{[n]} Q_{[n]} \prod_{i=m+1}^n \gamma_i K_m^{\top} V_m=\beta_{[n]} Q_{[n]} R_{i-1}, \quad R_i=K_{[i]}^{\top}(V_{[i]} \odot \frac{\beta_{i B}}{\beta_{[i]}})+\beta_{i B} R_{i-1},$$
$$O_{[n]}=\underbrace{(Q_{[n]} K_{[n]}^{\top} \odot D_{[n]}) V_{[n]}}_{\text {Inner-Chunk }}+\underbrace{(Q_{[n]} R_{n-1}) \odot \beta_{[n]}}_{\text {Cross-Chunk }}$$

C. YOCO-3B 的超参数

YOCO-3B模型配置。本节描述了4.1节中使用的超参数。隐藏维度为3072,层数为26,查询头数量为24,键/值头数量为8(使用GQA)。不含嵌入层的总参数量为2.83B。训练批次大小为4M tokens,训练长度为4096。优化器为AdamW【索引22,Decoupled weight decay regularization,2019,ICLR】,$\beta = (0.9, 0.95)$。最大学习率为$3.2 \times 10^{-4}$,有1000个warmup步骤,并线性衰减至$1.28 \times 10^{-5}$,总调度设置为5T tokens。

表5:4.1节中YOCO-3B模型使用的超参数。
表5:4.1节中YOCO-3B模型使用的超参数。

D. 扩展曲线的超参数

扩展性实验配置。本节描述了4.2节中使用的超参数。表6报告了不同模型大小所使用的隐藏维度、层数和头数。门控保持的头维度设为256。为了对齐参数量,Transformer的FFN大小为$8/3 d$,而YOCO的FFN大小为$3d$。训练长度为2048,批次大小为0.25M tokens。优化器为AdamW,学习率和warmup步数根据模型大小有所不同。模型训练了40k步,即10B tokens。

表6:4.2节中用于扩展曲线的模型大小和超参数。
表6:4.2节中用于扩展曲线的模型大小和超参数。

E. 长度扩展的超参数

长上下文训练配置。在4.3节中,我们逐步将上下文长度扩展到1M tokens。长度调度为64K, 256K, 和1M。我们对长度超过训练长度的文档进行上采样。表7显示了我们在每个阶段使用的不同RoPE $\theta$和学习率。

表7:4.3节中用于长度扩展的超参数。
表7:4.3节中用于长度扩展的超参数。

F. 门控保持的伪代码

门控保持的三种计算范式伪代码。本节提供了门控保持(3.1节)三种计算范式的伪代码。并行实现(Parallel implementation)能够充分利用GPU进行并行训练。循环范式(Recurrent paradigm)可实现低成本推理。分块保持(Chunkwise retention)结合了上述两者的优点(即在块内并行,在块间循环),对于长序列具有线性内存复杂度。

# def RecurrentRetention(
# q, k, v, # bsz * num_head * dim
# past_kv, # bsz * num_head * dim * dim
# gt # bsz * num_head * 1 * 1
# ):
# gt = F.logsigmoid(gt) / gate_logit_normalizer
# current_kv = gt.exp() * past_kv + k.unsqueeze(-1) * v.unsqueeze(-2)
# output = torch.sum(q.unsqueeze(-1) * current_kv, dim=-2)
# output = group_norm(output)
# return output, current_kv

# def ChunkwiseRetention(
# q, k, v, # bsz * num_head * chunk_size * dim
# past_kv, # bsz * num_head * dim * dim
# gt): # bsz * num_head * chunk_size
# gt = F.logsigmoid(gt).cumsum(-1) / gate_logit_normalizer
# cross_retention = (q @ past_kv) * gt[..., None].exp()
# inner_retention = ParallelRetention(q, k, v, gt)
# retention = inner_retention + cross_retention
# output = group_norm(retention)
# value_decay = (-gt + gt[:, :, :, -1, None]).exp()[..., None]
# chunk_decay = gt[..., -1].exp()
# current_kv = chunk_decay * past_kv + k.transpose(-1, -2) @ (v * value_decay)
# return output, current_kv

G. 与Transformer变体的比较

模型对比设置。我们将YOCOgRet和YOCOSWA与Transformer及其变体(包括H3【索引9,Hungry hungry hippos: Towards language modeling with state space models,2022,arXiv】、RetNet【索引45,Retentive network: A successor to transformer for large language models,2023,arXiv】、Mamba【索引14,Mamba: Linear-time sequence modeling with selective state spaces,2023,arXiv】和gRetNet(3.1节))进行比较。所有模型均为160M参数,12层,隐藏维度768。词嵌入和softmax投影的权重共享。

G.1 细粒度语言模型困惑度结果

细粒度困惑度。表8报告了语言建模的验证困惑度。遵循Zoology【索引1,Zoology: Measuring and improving recall in efficient language models,2023,arXiv】的设置,我们将困惑度分为Ar-Hit(预测的token是上下文中已见过的二元组)和First-Occur(预测的token无法从上下文中召回)。结果显示,YOCOgRet在整体验证集、关联召回(AR-Hit)和常规语言建模(First-Occur)方面均取得了最低的困惑度,表现最佳。

表8:语言建模的细粒度困惑度结果。我们报告了整体验证集和细粒度诊断集 [AET+23] 上的困惑度,即“AR-Hit”评估关联召回能力,“First-Occur”指示常规语言建模性能。
表8:语言建模的细粒度困惑度结果。我们报告了整体验证集和细粒度诊断集 [AET+23] 上的困惑度,即“AR-Hit”评估关联召回能力,“First-Occur”指示常规语言建模性能。

G.2 长上下文评估

长上下文任务评估。我们在ZeroSCROLLS【索引47,Zeroscrolls: A zero-shot benchmark for long text understanding,2023,arXiv】基准的四个任务上评估了上述架构的长上下文建模能力。我们将表8中的160M模型继续以16384的长度训练2B tokens。图12报告了答案在不同输入长度下的困惑度。在所有架构中,YOCO和Transformer在不同任务和长度上始终表现优于其他模型。

图12:长序列任务的困惑度随着输入长度的增加而降低。
图12:长序列任务的困惑度随着输入长度的增加而降低。