本文针对大型语言模型(LLM)中随着输入序列增长而导致的KV缓存(KV Cache)占用大量GPU内存和预填充(prefilling)延迟过高的问题,提出了一种名为YOCO(You Only Cache Once)的新型解码器-解码器(decoder-decoder)架构。
核心问题与研究目标:
标准的仅解码器(decoder-only)Transformer架构在自回归生成时,需要缓存所有先前计算过的键/值(KV)向量,以避免重复计算。当处理长序列时,这部分KV缓存会变得非常庞大,成为推理过程中的内存瓶颈,并导致预填充阶段耗时极长,这限制了长上下文语言模型的实际部署。例如,一个65B模型处理512K个token,KV缓存就需要约86GB的GPU内存。
创新点与主要贡献:
1. 提出YOCO架构:这是一种解码器-解码器架构,由一个自解码器(self-decoder)和一个跨解码器(cross-decoder)堆叠而成。自解码器使用高效的自注意力机制(如滑动窗口注意力或门控保持机制)来编码输入并生成一个全局的、共享的KV缓存。随后,跨解码器通过交叉注意力(cross-attention)机制来复用这个KV缓存进行后续的计算。
2. “只缓存一次”实现内存优化:由于只有自解码器的输出被用于生成全局KV缓存,且这个缓存被所有跨解码器层共享,因此整个模型的KV缓存大小与模型层数(L)无关,其内存复杂度从Transformer的O(N * L * D)降低到约O(N * D)。这极大地减少了GPU内存消耗,使得在有限的硬件上部署长上下文模型成为可能。
3. 预填充提前退出(Early Exit)机制:YOCO的计算流程允许在预填充阶段,计算完自解码器后即可提前退出,无需经过跨解码器。这一特性显著加快了预填充速度,例如,对于512K上下文,预填充延迟从180秒缩短到不足6秒,极大改善了用户体验。
4. 可扩展性与高性能:实验证明YOCO在模型大小、训练token数量和上下文长度方面都具有良好的可扩展性。
* 训练规模:将3B规模的YOCO模型扩展到万亿级训练token,其性能与StableLM等主流Transformer模型相当。
* 模型规模:从160M到13B的模型尺寸扩展曲线表明,YOCO与Transformer相比具有竞争力。
* 上下文长度:成功将YOCO扩展到100万(1M)上下文长度,并在“大海捞针”测试中实现了近乎完美的检索准确率。
总而言之,YOCO架构通过其独特的“只缓存一次”设计,在保持强大全局注意力能力的同时,显著降低了长上下文语言模型的推理内存占用和预填充延迟,提升了吞吐量,为未来原生支持长序列的大型语言模型提供了一个强有力的候选架构。
YOCO架构概述。本文提出的YOCO架构专为自回归建模(如大型语言模型LLM)设计。如图2所示,该解码器-解码器架构包含自解码器(self-decoder)和跨解码器(cross-decoder)两个部分。具体来说,YOCO由L个块堆叠而成,其中前L/2层为自解码器,其余为跨解码器。对于输入序列x = x1 · · · x|x|,其输入嵌入被打包成$X^0 = [x_1, · · · , x_{|x|}] \in \mathbb{R}^{|x|\times d_{\text{model}}}$,其中$d_{\text{model}}$是隐藏维度。模型首先通过自解码器计算上下文向量表示$X^l = \text{Self-Decoder}(X^{l-1})$, 其中$l \in [1, L/2]$,并利用$X^{L/2}$生成供跨解码器使用的KV缓存$\hat{K}, \hat{V}$。然后,模型计算$X^l = \text{Cross-Decoder}(X^{l-1}, \hat{K}, \hat{V})$, 其中$l \in [ L/2 + 1, L]$,以获得最终的输出向量$X^L$。
模块布局与差异。自解码器和跨解码器都遵循与Transformer【索引52,Attention is all you need,2017,NIPS】相似的块布局,即交错的注意力和前馈网络。本文还引入了pre-RMSNorm【索引59,Root mean square layer normalization,2019,NeurIPS】、SwiGLU【索引46,Glu variants improve transformer,2020,arXiv】和分组查询注意力(grouped-query attention)【索引2,Training generalized multi-query transformer models from multi-head checkpoints,2023,arXiv】作为改进。这两个部分的主要区别在于注意力模块:自解码器(2.1节)使用高效的自注意力机制(如滑动窗口注意力),而跨解码器(2.2节)则使用全局交叉注意力来关注由自解码器输出产生的共享KV缓存。
自解码器的计算过程。自解码器接收token嵌入$X^0$作为输入,并计算出中间向量表示$M = X^{L/2}$。其计算过程如下列公式所示:
$$\begin{aligned} \begin{aligned} Y^{l} & =\operatorname{ESA}\left(\operatorname{LN}\left(X^{l}\right)\right)+X^{l} \\ X^{l+1} & =\operatorname{SwiGLU}\left(\operatorname{LN}\left(Y^{l}\right)\right)+Y^{l} \end{aligned} \end{aligned}$$其中,$ESA(\cdot)$代表高效自注意力(efficient self-attention),$SwiGLU(X) = (\text{swish}(XW_G) \odot XW_1)W_2$,而$LN(\cdot)$使用RMSNorm【索引59,Root mean square layer normalization,2019,NeurIPS】。高效自注意力机制中也使用了因果掩码(Causal masking)。
高效自注意力的关键特性。高效自注意力模块的关键特性是其推理内存为$O(1)$,即KV缓存的数量是恒定的。例如,滑动窗口注意力【索引5,Generating long sequences with sparse Transformers,2019,OpenAI Blog】的缓存大小取决于窗口大小,而不是输入长度。关于高效自注意力模块的更多设计选择(例如,门控保持机制)将在第3节中详细介绍。
全局KV缓存的生成。首先,自解码器的输出$X^{L/2}$被用来为跨解码器生成全局KV缓存$\hat{K}$和$\hat{V}$:
$$\hat{K}=\operatorname{LN}(X^{L/2})W_K, \quad \hat{V}=\operatorname{LN}(X^{L/2})W_V$$其中$W_K, W_V \in \mathbb{R}^{d\times d}$是可学习的权重矩阵。然后,跨解码器层堆叠在自解码器之后,以获得最终的输出向量$X^L$。这些KV缓存$\hat{K}$和$\hat{V}$被所有$L/2$个跨解码器层复用。
跨解码器的计算过程。跨解码器的计算过程如下列公式所示:
$$\begin{aligned} \begin{aligned} \hat{Q}^l &= \text{LN}(X^l)W_Q^l \\ Y^l &= \text{Attention}(\hat{Q}^l, \hat{K}, \hat{V}) + X^l \\ X^{l+1} &= \text{SwiGLU}(\text{LN}(Y^l)) + Y^l \end{aligned} \end{aligned}$$其中Attention(·)是标准的多头注意力【索引52,Attention is all you need,2017,NIPS】,而$W_Q^l \in \mathbb{R}^{d\times d}$是一个可学习的矩阵。交叉注意力中也使用了因果掩码。由于交叉注意力与分组查询注意力【索引2,Training generalized multi-query transformer models from multi-head checkpoints,2023,arXiv】兼容,我们可以进一步节省KV缓存的内存消耗。在获得$X^L$后,一个softmax分类器执行下一个token的预测。
推理优势概述。除了具有竞争力的语言建模结果外,YOCO还显著降低了服务成本并提升了推理性能。详细的推理比较将在第4.4节中报告。
节省GPU内存并服务更多Token。表1比较了Transformer和YOCO的内存复杂度。具体来说,由于全局KV缓存被复用,并且高效自注意力只需要恒定的缓存,YOCO的缓存数量是$O(N + CL)$,其中N是输入长度,C是一个常数(例如,滑动窗口大小),L是层数。对于长序列,CL远小于N,因此大约需要$O(N)$的缓存,即“只缓存一次”。相比之下,Transformer解码器在推理过程中必须存储$N \times L$个键和值。因此,与Transformer解码器相比,YOCO为缓存大约节省了L倍的GPU内存。由于推理能力的瓶颈变成了KV缓存(图7b),我们的方法使我们能够在不耗尽GPU内存的情况下服务更多的token。增加的批处理大小也有利于提高推理吞吐量。
减少预填充时间并提高吞吐量。如图3所示,由于跨解码器复用自解码器的输出,我们可以在预填充阶段进入跨解码器之前提前退出。这种计算依赖的有趣特性极大地加快了预填充速度。首先,只需要一半的层进行前向计算,这意味着至少可以减少一半的预填充延迟。其次,自解码器的高效注意力模块通常速度很快。以512K上下文长度为例,我们可以将预填充延迟从180秒(使用Flash-Decoding和内核融合等优化推理的Transformer)减少到不到6秒(图9)。即使对于32K长度,YOCO在预填充时间上也有大约三倍的加速。表2比较了Transformer和YOCO注意力模块的预填充时间复杂度。
自解码器的设计选择。我们可以为自解码器选择各种高效的自注意力方法。只要该模块仅需要恒定的推理内存,自解码器的缓存内存复杂度就取决于层数。此外,一个好的模块选择可以改善训练和部署成本。在这项工作中,我们使用门控保持机制(gated retention)(第3.1节)或滑动窗口注意力(sliding-window attention)(第3.2节)。
门控保持机制(gRet)概述。门控保持(gRet,也称为gRetNet或RetNet-3)通过一种数据依赖的门控机制增强了保持网络(retention)【索引45,Retentive network: A successor to transformer for large language models,2023,arXiv】,它同时实现了序列建模的训练并行性、良好性能和低推理成本。在实验中,我们使用gRet作为默认的高效自注意力模块。该方法统一了并行、循环和分块循环三种计算范式。这三种表示是等价的,可以获得相同的计算结果。训练过程通常使用并行或分块循环范式,而推理阶段可以采用循环范式以实现恒定的KV内存。下面我们描述这三种表示。
并行表示。门控保持机制的定义如下:
$$Q = (XW_Q) \odot \Theta, \quad K = (XW_K) \odot \bar{\Theta}, \quad V = XW_V, \quad \Theta_n = e^{in\theta}$$其中$W_Q, W_K, W_V \in \mathbb{R}^{d\times d}$和$W_\gamma \in \mathbb{R}^{d\times 1}$是可学习的权重,温度项$\tau$鼓励$\gamma$趋近于1以获得更好的记忆能力【索引56,Gated linear attention transformers with hardware-efficient training,2023,arXiv】。数据控制的衰减是逐头的(head-wise)【索引20,Gateloop: Fully data-controlled linear recurrence for sequence modeling,2023,arXiv】,而不是逐元素的(element-wise),这样计算可以充分利用NVIDIA的张量核心(tensor cores)。关于其他设计的更多细节,请参考【索引45,Retentive network: A successor to transformer for large language models,2023,arXiv】。
循环表示。与公式(4)等价,门控保持的输出可以通过循环方式计算。对于第n个时间步,输出通过以下方式获得:
$$\begin{aligned} \begin{aligned} S_{n} &= \gamma_{n} S_{n-1} + K_{n}^{\intercal} V_{n} \\ \operatorname{gRet}(X_{n}) &= Q_{n} S_{n}, \quad n=1, \cdots, |x| \end{aligned} \end{aligned}$$其中$Q, K, V, \gamma$与公式(4)中相同。在自回归推理期间,自解码器维护$S_n$作为中间状态,以实现高效生成。
分块循环表示。分块表示是循环和并行表示的统一形式。给定块大小B,输出逐块计算。计算分为块内和跨块两部分。用[i]表示第i个块,即$x_{[i]} = x_{(i-1)B+1}, \cdots, x_{iB}$,我们按如下方式计算第i个块:
$$\beta_{(i-1)B+j} = \prod_{k=(i-1)B+1}^{(i-1)B+j} \gamma_k, \quad D_{[i]}(j,k) = \frac{\beta_{(i-1)B+k}}{\beta_{(i-1)B+j}} \text{ if } j \leq k \text{ else } 0$$其中$R_i$是第i个块的中间状态,$\beta$总结了数据控制的衰减$\gamma$。附录B中的证明显示了这些计算范式之间的等价性。分块范式结合了并行和循环的最佳特性,即与完全并行计算相比节省了FLOPs,并与循环计算相比减少了迭代次数。在训练和预填充阶段,分块表示提高了吞吐量并减少了GPU内存消耗。
多头门控保持。与多头注意力【索引52,Attention is all you need,2017,NIPS】和多尺度保持【索引45,Retentive network: A successor to transformer for large language models,2023,arXiv】类似,我们将门控保持应用于每个头,并将输出组合在一起:
$$\begin{aligned} \begin{aligned} \text{head}_i &= \text{gRet}(X) \\ Y &= \text{GroupNorm}_h(\text{Concat}(\text{head}_1, \cdots, \text{head}_n)) \\ \text{MHGR}(X) &= (\text{swish}(XW_G) \odot Y)W_O \end{aligned} \end{aligned}$$其中$W_G, W_O \in \mathbb{R}^{d\times d}$是可学习的矩阵,GroupNorm【索引54,Group normalization,2018,ECCV】对每个头进行归一化【索引55,Magneto: A foundation Transformer,2023,ICML】。我们还应用swish门来增加非线性【索引45,Retentive network: A successor to transformer for large language models,2023,arXiv】。
滑动窗口注意力。滑动窗口注意力【索引5,Generating long sequences with sparse Transformers,2019,OpenAI Blog】将注意力范围限制在一个固定的窗口大小C内。相比之下,普通的Transformer解码器会关注所有之前的token。在推理过程中,KV缓存的内存复杂度可以从$O(N)$降低到$O(C)$,即内存使用量是恒定的,而不是随序列长度增加而增加。与多头自注意力【索引52,Attention is all you need,2017,NIPS】类似,我们通过以下方式计算滑动窗口注意力的输出:
$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$其中$W_Q, W_K, W_V, W_O \in \mathbb{R}^{d\times d}$是可学习的矩阵,窗口因果掩码$B$控制每个查询只关注距离小于C的先前键。该模块也应用了预归一化(pre-normalization)和残差连接(residual connection)。
模型架构:
数据集与任务:
硬件配置:
软件配置:
实验结果:
分析结论:YOCO架构能够有效扩展到超长上下文(1M tokens),并在需要长距离信息检索和理解的任务上表现出色。
实验结果:
分析结论:YOCO在推理效率上具有数量级的优势,极大地降低了长上下文模型的部署成本,并显著改善了用户体验。
本文提出了一种用于大型语言建模的解码器-解码器架构(YOCO)。与Transformer相比,YOCO在实现显著更优的推理效率的同时,保持了具有竞争力的性能。实验结果表明,YOCO在各种设置下都取得了良好的效果,包括扩展训练token数量、扩展模型大小以及将上下文长度扩展到1M token。性能剖析结果也显示,YOCO将推理效率提升了几个数量级,尤其是在长序列建模方面。
未来工作展望:
1. YOCO + BitNet + Groq的结合:Groq通过将所有计算置于SRAM中实现极高吞吐量,但受限于内存容量。YOCO可以减少KV缓存内存,而BitNet可以减少模型权重内存。三者结合有望将LLM的部署成本降低几个数量级。
2. YOCO应用于多模态大模型:YOCO的布局天然支持使用多个自解码器,其交叉注意力层非常适合多模态融合。自解码器的因果依赖性也完美契合流式视频处理,可用于构建异步多模态大模型,避免不同数据流相互阻塞,这对于机器人等实时应用至关重要。
3. 优化KV缓存模块机制:YOCO架构明确了KV缓存模块,为开发原生内存机制开辟了新机会。
* 可以集成缓存压缩机制以获得更紧凑的内存。
* 可以构建索引以实现高效的键值检索,由于YOCO复用缓存,只需维护一个索引。
* 解耦的建模方式支持预缓存上下文,这对于原生RAG和LLM原生搜索引擎具有潜在价值。
块并行策略。为了减少通信频率并加速长序列训练,我们为YOCO引入了块并行。当训练序列极长时,将长序列划分到不同设备上至关重要【索引25,Sequence parallelism: Making 4d parallelism possible,2021,arXiv】、【索引11,Longnet: Scaling transformers to 1,000,000,000 tokens,2023,arXiv】。然而,总吞吐量往往受限于GPU通信【索引27,Ring attention with blockwise transformers for near-infinite context,2023,arXiv】。YOCO的跨解码器设计解耦了自注意力依赖,同时保留了建模能力,为分布式长序列训练带来了独特的优势。在自解码器中,依赖关系仅存在于相邻设备之间,因此通信量相对较小。在跨解码器中,all-gather操作仅对KV缓存触发一次,而不是在每一层都进行通信。这种硬件友好的架构为分布式长序列训练提供了更大的灵活性。
循环与分块表示的等价性证明。本节阐述了门控保持机制的循环表示和分块循环表示之间的等价性。对于输出$O_n$,n可以被拆分为$n = kB + r$,其中B是块大小。通过一系列推导,证明了分块计算(包含块内计算和跨块计算)的结果与原始的循环计算或并行计算是完全等价的。
$$O_n = \sum_{m=1}^{n} \prod_{i=m+1}^{n} \gamma_i Q_n K_m^\intercal V_m = \sum_{m=kB+1}^{n} \prod_{i=m+1}^{n} \gamma_i Q_n K_m^\intercal V_m + \sum_{m=1}^{kB} \prod_{i=m+1}^{n} \gamma_i Q_n K_m^\intercal V_m$$ $$\sum_{m=kB+1}^{n} \prod_{i=m+1}^{n} \gamma_i Q_n K_m^\intercal V_m = (Q_n K_{kB+1:n}^\intercal \odot \Gamma_{kB+1:n}) V_{kB+1:n}$$ $$\begin{aligned} \begin{aligned} \sum_{m=1}^{kB} \prod_{i=m+1}^{n} \gamma_i Q_n K_m^\intercal V_m &= (Q_n \prod_{i=kB+1}^{n} \gamma_i) \sum_{c=0}^{k-1} \sum_{m=1}^{B} (K_{m+cB}^\intercal V_{m+cB} \prod_{i=m+cB+1}^{(c+1)B} \gamma_i) \prod_{i=(c+1)B+1}^{kB} \gamma_i \\ &= (Q_n \prod_{i=kB+1}^{n-1} \gamma_i) \sum_{c=1}^{k} (K_{[c]}^\intercal (V_{[c]} \odot \zeta_{[c]})) \prod_{i=c+1}^{k} \alpha_i \\ &= (Q_n \prod_{i=kB+1}^{n-1} \gamma_i) R_{i-1} \end{aligned} \end{aligned}$$ $$R_{i}=K_{[i]}^{\intercal}(V_{[i]} \odot \zeta_{[i]})+\alpha_{i} R_{i-1}$$ $$O_{[n]}=\sum_{m=k B+1}^{[n]} \beta_{[n]} Q_{[n]} K_m^{\top} V_m+\sum_{m=1}^{k B} \beta_{[n]} Q_{[n]} \prod_{i=m+1}^n \gamma_i K_m^{\top} V_m$$YOCO-3B模型配置。本节描述了4.1节中使用的超参数。隐藏维度为3072,层数为26,查询头数量为24,键/值头数量为8(使用GQA)。不含嵌入层的总参数量为2.83B。训练批次大小为4M tokens,训练长度为4096。优化器为AdamW【索引22,Decoupled weight decay regularization,2019,ICLR】,$\beta = (0.9, 0.95)$。最大学习率为$3.2 \times 10^{-4}$,有1000个warmup步骤,并线性衰减至$1.28 \times 10^{-5}$,总调度设置为5T tokens。
扩展性实验配置。本节描述了4.2节中使用的超参数。表6报告了不同模型大小所使用的隐藏维度、层数和头数。门控保持的头维度设为256。为了对齐参数量,Transformer的FFN大小为$8/3 d$,而YOCO的FFN大小为$3d$。训练长度为2048,批次大小为0.25M tokens。优化器为AdamW,学习率和warmup步数根据模型大小有所不同。模型训练了40k步,即10B tokens。
长上下文训练配置。在4.3节中,我们逐步将上下文长度扩展到1M tokens。长度调度为64K, 256K, 和1M。我们对长度超过训练长度的文档进行上采样。表7显示了我们在每个阶段使用的不同RoPE $\theta$和学习率。
门控保持的三种计算范式伪代码。本节提供了门控保持(3.1节)三种计算范式的伪代码。并行实现(Parallel implementation)能够充分利用GPU进行并行训练。循环范式(Recurrent paradigm)可实现低成本推理。分块保持(Chunkwise retention)结合了上述两者的优点(即在块内并行,在块间循环),对于长序列具有线性内存复杂度。
# def RecurrentRetention(
# q, k, v, # bsz * num_head * dim
# past_kv, # bsz * num_head * dim * dim
# gt # bsz * num_head * 1 * 1
# ):
# gt = F.logsigmoid(gt) / gate_logit_normalizer
# current_kv = gt.exp() * past_kv + k.unsqueeze(-1) * v.unsqueeze(-2)
# output = torch.sum(q.unsqueeze(-1) * current_kv, dim=-2)
# output = group_norm(output)
# return output, current_kv
# def ChunkwiseRetention(
# q, k, v, # bsz * num_head * chunk_size * dim
# past_kv, # bsz * num_head * dim * dim
# gt): # bsz * num_head * chunk_size
# gt = F.logsigmoid(gt).cumsum(-1) / gate_logit_normalizer
# cross_retention = (q @ past_kv) * gt[..., None].exp()
# inner_retention = ParallelRetention(q, k, v, gt)
# retention = inner_retention + cross_retention
# output = group_norm(retention)
# value_decay = (-gt + gt[:, :, :, -1, None]).exp()[..., None]
# chunk_decay = gt[..., -1].exp()
# current_kv = chunk_decay * past_kv + k.transpose(-1, -2) @ (v * value_decay)
# return output, current_kv
模型对比设置。我们将YOCOgRet和YOCOSWA与Transformer及其变体(包括H3【索引9,Hungry hungry hippos: Towards language modeling with state space models,2022,arXiv】、RetNet【索引45,Retentive network: A successor to transformer for large language models,2023,arXiv】、Mamba【索引14,Mamba: Linear-time sequence modeling with selective state spaces,2023,arXiv】和gRetNet(3.1节))进行比较。所有模型均为160M参数,12层,隐藏维度768。词嵌入和softmax投影的权重共享。
细粒度困惑度。表8报告了语言建模的验证困惑度。遵循Zoology【索引1,Zoology: Measuring and improving recall in efficient language models,2023,arXiv】的设置,我们将困惑度分为Ar-Hit(预测的token是上下文中已见过的二元组)和First-Occur(预测的token无法从上下文中召回)。结果显示,YOCOgRet在整体验证集、关联召回(AR-Hit)和常规语言建模(First-Occur)方面均取得了最低的困惑度,表现最佳。
长上下文任务评估。我们在ZeroSCROLLS【索引47,Zeroscrolls: A zero-shot benchmark for long text understanding,2023,arXiv】基准的四个任务上评估了上述架构的长上下文建模能力。我们将表8中的160M模型继续以16384的长度训练2B tokens。图12报告了答案在不同输入长度下的困惑度。在所有架构中,YOCO和Transformer在不同任务和长度上始终表现优于其他模型。