Mixture-of-Depths Attention

作者/机构: Lianghui Zhu (华中科技大学电子信息与通信学院, 字节跳动Seed), Yuxin Fang (字节跳动Seed)†, Bencheng Liao (华中科技大学电子信息与通信学院, 字节跳动Seed), Shijie Wang (字节跳动Seed), Tianheng Cheng (字节跳动Seed), Zilong Huang (字节跳动Seed), Chen Chen (字节跳动Seed), Lai Wei (字节跳动Seed), Yutao Zeng (字节跳动Seed), Ya Wang (字节跳动Seed), Yi Lin (字节跳动Seed), Yu Li (字节跳动Seed), Xinggang Wang (华中科技大学电子信息与通信学院)#
(†项目负责人, #通讯作者)


图1 我们提出混合深度注意力(MoDA)以一种动态且硬件高效的方式解决现代大语言模型的信息稀释问题。与普通的因果序列注意力相比,MoDA额外允许查询(query)关注来自先前层的同一查询位置的深度记忆,即深度KV对 {Ki, Vi}l−1i=0。


图2 在15亿参数设置下,比较MoDA和强大的开源基线模型OLMo2 [27]的验证损失和下游任务性能。使用MoDA的模型在C4 [30]验证集上实现了更低的损失,并在下游任务(如HellaSwag [48]、WinoGrande [32]和ARC-Challenge [10])上表现优于OLMo2。

A1 主要贡献

核心问题

大型语言模型(LLM)的最新进展主要由上下文长度、训练数据、模型宽度和模型深度四个维度的扩展驱动。然而,在实践中,扩展往往更侧重于数据、上下文和宽度,因为它们的优化行为和系统效率更容易实现。相比之下,深度虽然具有强大的表示潜力,但仍未被充分利用。原则上,更深的网络可以支持更丰富的层次化计算,但现代Transformer由于优化问题和信息稀释(information dilution),常常无法将增加的层数转化为相应的性能提升。信息稀释指的是在浅层形成的有用特征在经过重复的残差更新后会逐渐被淡化,导致在深层难以恢复。

研究目标

本文的核心问题是:如何扩展模型深度,同时保持优化稳定并防止信息稀释? 现有的方法,如标准的残差连接(ResNet-style),虽然改善了深度网络的优化稳定性,但仍将深度历史压缩到单一的隐藏状态轨迹中,未能解决信息稀释问题。而密集跨层连接(DenseNet-style)虽然能保留更丰富的逐层历史,但其参数增长在LLM规模下是巨大的,限制了其应用。本文的目标是设计一种新的机制,能够在表达能力、效率和硬件友好性之间取得更好的平衡,使模型能够自适应地从早期层读取有用的状态,从而有效利用深度信息。

创新点

为解决上述问题,本文提出了混合深度注意力(Mixture-of-Depths Attention, MoDA),这是一种统一的注意力机制,其核心创新点如下:

  • 提出MoDA,一种用于动态混合序列和深度的统一注意力公式:MoDA允许每个注意力头同时关注当前层的序列KV对和来自所有先前层的深度KV对。这种数据依赖的方式改善了深度信息的聚合,并解决了现代LLM的信息稀释问题。通过将序列和深度信息的聚合置于一个统一的softmax操作中,MoDA提供了一个统一的表示空间。

  • 提出一种硬件高效的融合算法,使MoDA在长上下文LLM训练中变得实用:为了使MoDA在实践中高效,本文开发了一个硬件感知的实现。该实现通过重新组织深度流张量以实现连续内存访问,并将序列和深度注意力融合到一个前向传递中,共享在线softmax状态。该融合核在64K序列长度下达到了FlashAttention-2效率的97.3%,同时保持了可接受的数值精度。

  • 提供了广泛的实证证据,证明MoDA在多个模型规模上持续且显著优于强大的开源基线:在7亿和15亿参数规模上,使用4000亿词元的OLMo2配方进行训练,MoDA在10个验证基准上的平均困惑度降低了0.2,在10个下游任务上的平均性能提高了2.11%,而计算开销(FLOPs)仅增加了可忽略的3.7%。实验还发现MoDA与post-norm结合比与pre-norm结合性能更好。

A2 方法细节

2.1 预备知识

自注意力机制基础。大多数现代大型语言模型都基于Transformer架构【39,Attention is all you need, 2017, NeurIPS】,其中自注意力是主要的词元混合算子。给定一个由T个词元组成的序列$X = (x_1, x_2, \ldots, x_T) \in R^{T \times D}$(其中D为隐藏维度),自注意力首先通过可训练的矩阵$W_Q \in R^{D \times (H_q d)}$和$W_K, W_V \in R^{D \times (H_k d)}$将词元投影为查询(Q)、键(K)和值(V)。在分组查询注意力(GQA)【2,Gqa: Training generalized multi-query transformer models from multi-head checkpoints, 2023, EMNLP】下,$H_q = G H_k$,$H_k = H_v$,且$D = H_q d$:

$$Q=X W_{Q}, \quad K=X W_{K}, \quad V=X W_{V},$$

其中$Q \in R^{T \times (H_q d)}$以及$K, V \in R^{T \times (H_k d)}$。注意力算子计算查询和键之间的成对相似度,应用softmax获得每个头的注意力权重$A_h \in R^{T \times T}$,并返回值的加权和:

$$\text{Attention}(Q, K, V) = \text{Concat}_{h=1}^{H_q} \left( \text{softmax} \left( \frac{Q_h K_{\phi(h)}^T}{\sqrt{d}} + \mathcal{M} \right) V_{\phi(h)} \right)$$

其中$Q_h \in R^{T \times d}$,$K_j, V_j \in R^{T \times d}$,$\phi(h) = \lceil h/G \rceil$将每个查询头映射到其共享的键值头。这里,$M \in R^{T \times T}$是一个加性注意力掩码。对于因果注意力,若$j \le i$,则$M_{ij} = 0$,否则$M_{ij} = -\infty$。对于全注意力,$M$是全零矩阵。

2.2 沿深度流堆叠Transformer

深度流信息传播机制的探讨。深度神经网络在引入残差连接【16,Deep residual learning for image recognition, 2016, CVPR】后在各个领域取得了突破。扩展性研究【18,19,21】进一步表明,增加深度可以显著提高性能【33,36】。这引出了一个自然的问题:残差连接是沿深度流传播信息的最佳机制吗? 我们可以从“读取、操作、写入”这三个步骤的角度来看待一个Transformer块,并以此来描述堆叠Transformer块的不同机制。为了清晰起见,前两种机制(深度残差【16】和深度密集【20,28】)是用于定义深度流设计空间的参考设计。我们引入深度注意力作为一种中间形式和概念桥梁。本节的主要技术贡献始于混合深度注意力(MoDA),它在一个统一的softmax算子中融合了序列和深度信息的检索。

深度残差连接 (Depth Residual)。在深度残差连接【16,35】中,“读取”步骤是恒等映射,“写入”步骤是加法操作。“操作”步骤是词元混合算子,即注意力或前馈网络(FFN),表示为$F(\cdot)$。如图3(a)所示,深度残差的结构可以公式化为:

$$X_{l}=X_{0}+\sum_{i=1}^{l-1} \mathcal{F}\left(X_{i}, \mathcal{W}_{i}\right),$$

其中$W_i$是第$i$层的可训练权重矩阵集合。


图3 利用深度流的机制的概念性比较。(a) 深度残差(Depth Residual)【16】是沿深度的标准残差连接:它读取当前表示并通过加法写回。(b) 深度密集(Depth Dense)【20, 28】读取一组历史表示并将它们线性投影回宽度D;它通过沿深度进行拼接来写回,保留所有中间状态。(c) 我们引入深度注意力(Depth Attention)作为一种中间形式,它使用注意力以数据依赖的方式读取历史深度KV对。它通过沿深度拼接当前层的键和值来写回。(d) 我们提出深度注意力的升级版,即混合深度注意力(MoDA),它将深度注意力与标准序列注意力相结合。它将当前层的输出及其KV对都写入深度流,供后续层使用。

深度残差连接的局限性。该公式缓解了梯度消失问题,使得训练深度网络成为可能。然而,深度流通过重复的叠加被持续压缩到一个固定大小的张量$X_l \in R^{T \times D}$中,这会稀释显著特征并导致信号退化。

深度密集连接 (Depth Dense)。为了减轻信号退化,深度密集方法【20,Densely connected convolutional networks, 2017, CVPR】【28,Denseformer: Enhancing information flow in transformers via depth weighted averaging, 2024, NeurIPS】沿深度流连接所有层。在“读取”步骤,它们通过将先前表示的集合$\{X_i \in R^{T \times D}\}_{i=0}^{l-1}$线性投影回形状$T \times D$来形成第$l$层的输入。在“写入”步骤,该层的输出与历史集合沿深度进行拼接。如图3(b)所示,深度密集的结构可以公式化为:

$$\{X_i\}_{i=0}^l = \{X_0, \mathcal{F}(\{X_0\}, \mathcal{W}_1), \mathcal{F}(\{X_0, X_1\}, \mathcal{W}_2), \cdots, \mathcal{F}(\{X_i\}_{i=0}^{l-1}, \mathcal{W}_l)\}$$

其中$W_i$是第$i$层的可训练权重矩阵集合。

深度密集连接的优缺点。深度密集连接无损地通过深度传播信息,因为拼接不会压缩历史集合。然而,它们带来了高昂的成本并强制执行固定的连接模式:计算量在主要项上以$O(T L^2 D^2)$增长,这对于大型模型是 prohibitive 的。

深度注意力 (Depth Attention)。为了在保留自适应连接性的同时降低成本,我们提出了深度注意力,它使用注意力以数据依赖的方式读取历史深度信息,如图3(c)所示。在“读取”步骤,从GQA组的角度($H_k d = D/G$),我们将一个查询组的表示记为$Q_{l-1} \in R^{T \times DG}$,相应的历史键值集合记为$\{K_i \in R^{T \times DG}\}_{i=0}^{l-1}$和$\{V_i \in R^{T \times DG}\}_{i=0}^{l-1}$。得到的输入$X_{in}^l$随后被送入“操作”步骤:

$$X_{l}^{\text {in }}=\operatorname{Attention}\left(Q_{l-1},\left\{K_{i}\right\}_{i=0}^{l-1},\left\{V_{i}\right\}_{i=0}^{l-1}\right),$$

其中注意力是沿深度维度执行的:对于词元$t$,查询$Q_{l-1, t}$只关注来自跨层同一词元位置的深度键和值$\{K_{i,t}, V_{i,t}\}_{i=0}^{l-1}$。在“操作”步骤之后,当前层的输出$X_{out}^l$被送入“写入”步骤,该步骤产生新的查询/键/值投影:

$$Q_l = X_l^{\text{out}} W_{Q,l}^{\text{W}}, \qquad K_l = X_l^{\text{out}} W_{K,l}^{\text{W}}, \qquad V_l = X_l^{\text{out}} W_{V,l}^{\text{W}},$$

表1 深度流机制的渐进复杂度。这里,T是序列长度,D是模型宽度,G是分组查询注意力(GQA)【2】的组大小,Hk是键头的数量(等于值头Hv),Hq是查询头的数量(=GHk),d是头维度,L是层数。我们报告主要项并省略常数因子。

深度注意力的写操作与状态更新。其中,$W_{WQ,l}^W, W_{WK,l}^W, W_{WV,l}^W \in R^{D \times DG}$是第$l$层“写入”操作的可训练矩阵,$Q_l, K_l, V_l \in R^{T \times DG}$表示每个组的投影。我们将$K_l$和$V_l$沿深度拼接以供未来的读取,而$Q_l$则向前传递到下一层。

深度注意力的成本优势。与深度密集连接相比,深度注意力以低得多的成本自适应地读取历史信息。其计算量按$O(T L^2 D)$扩展,比深度密集小一个因子$1/D$。

混合深度注意力 (MoDA) 的提出。在深度注意力的基础上,我们现在提出混合深度注意力(MoDA)。MoDA将深度级别的信息添加到标准的序列级别注意力中,并将这些操作融合成一个单一的算子。如图1和图3(d)所示,MoDA读取当前隐藏状态$X_{l-1}$和历史深度KV流$\{(K_i, V_i)\}_{i=0}^{l-1}$。在“操作”步骤中,我们应用MoDA,使每个词元能够同时关注序列级别的键和值,以及其自身的历史深度级别的键和值,所有注意力分数在一个单一的softmax函数下联合归一化。MoDA的实现细节在算法1中呈现。在“写入”步骤,对于注意力层,我们将当前层的键值对附加到深度流中,以便后续层可以访问它们。对于FFN层,我们通过一个轻量级的KV投影获得其相应的键值对。

MoDA的整体机制与优势。总的来说,MoDA提供了一种高效、数据依赖的机制来利用深度历史,其开销远低于密集的跨层连接。此外,在一个softmax操作中聚合序列和深度信息提供了一个统一的表示空间。

深度流机制的复杂度分析。复杂度分析对于现代LLM设计至关重要,我们也对深度感知设计(如深度密集、深度注意力和MoDA)进行了详细的复杂度分析。表1报告了完整的复杂度和主要的渐进项,其中T是序列长度,D是模型宽度,L是层数,d是头维度,G是GQA组大小。值得注意的是,$H_q = G H_k$。

复杂度分析结论。从表1中可以看出,深度密集的复杂度主要受二次深度增长的影响。其参数项为$O(L^2 D^2)$,解码缓存为$O(LD)$,解码和预填充的FLOPs都包含二次深度和二次宽度的项,即$O(L^2 D^2)$和$O(T L^2 D^2)$。所提出的深度注意力是一种数据依赖的方法,它消除了跨深度主要的二次宽度投影累积,将参数减少到$O(LD^2)$。它还将缓存降低到$O(LD/G)$,并将解码和预填充的计算量分别降低到$O(L^2 D)$和$O(T L^2 D)$。与深度注意力相比,MoDA保持了相同的有利FLOPs阶数和缓存阶数,但进一步将参数复杂度从$O(LD^2)$降低到$O(LD^2/G)$。关键原因是MoDA重用了序列注意力的查询投影,因此没有引入额外的深度查询投影。特别是在GQA设置中,只需要分组的深度键/值投影。这使得MoDA成为表1中最具参数效率的选项,同时保留了宽度线性的计算行为和低缓存扩展性。

MoDA统一Softmax的优势。总的来说,表1显示MoDA在保持注意力的数据依赖行为的同时,避免了密集跨层连接的主要的二次深度参数增长开销。MoDA用一个统一的softmax算子聚合序列和深度信息,这在实践中提供了更好的表示和效率,尤其是在L大和T长的场景下。

A3 背景知识与设计原则

3. 硬件感知的MoDA高效实现

硬件感知实现的需求。使用PyTorch【29,Pytorch: An imperative style, high-performance deep learning library, 2019, NeurIPS】天真地实现MoDA需要对历史深度状态进行非连续读取,这会降低GPU的利用率。我们开发了一种硬件感知的实现,通过重组深度流张量来实现连续的内存访问和融合计算。


图4 MoDA深度缓存访问的硬件视图。左:与Flash兼容的硬件高效MoDA为每个序列保留一个长度为 T × L 的深度KV缓存,因此每个查询可能扫描一个很长的拼接深度KV。右:块感知的MoDA按块大小C对查询进行分组,并按块重组深度KV,将每个块的有效深度跨度从 T × L 减少到 (C × L)/G,其中G是GQA组数。这种布局提高了深度KV计算效率并减少了内存访问开销。

3.1 预备知识

现代GPU的并行计算特性。现代GPU针对吞吐量导向的大规模数据并行工作负载进行了优化,其中相同的操作并行应用于许多元素【12,13,44-46】。因此,高效的注意力核应该被组织成暴露规则的、大规模并行的计算,而不是不规则的逐元素控制流。

流式多处理器 (SMs)。NVIDIA GPU由许多SM组成,它们是并行执行和资源管理的片上基本单元。高利用率需要足够多的独立块来保持许多SM处于活动状态。在具有长上下文序列和相对较小批量大小的大型语言模型(LLM)训练中,沿时间维度的并行化尤其重要。

计算单元 (CUDA Cores vs. Tensor Cores)。在每个SM内部,指令被分派到不同的执行单元。CUDA核心支持通用算术指令,而Tensor核心为结构化矩阵乘法累加操作提供高得多的吞吐量。因此,实用的高性能核应最大化规则的矩阵乘法式计算,以更好地利用Tensor核心。

内存层级 (HBM vs. SRAM)。端到端性能由计算吞吐量和数据移动共同决定。HBM提供大容量但访问延迟较高,而片上SRAM结构(即寄存器、共享内存和缓存)速度快得多但大小有限。因此,一个关键的设计原则是改进分块(tiling)和数据重用,使热数据保留在片上并最小化HBM流量。

硬件原理对MoDA设计的启发。这些原则直接激发了我们硬件感知的MoDA设计。我们重组了深度KV布局并融合了计算,以减少非连续内存访问并提高有效计算利用率。

3.2 MoDA的硬件感知考量

Flash兼容的深度KV布局。用显式的PyTorch for循环在历史深度KV上天真地实现深度注意力在GPU上通常很慢,因为它会引发不规则的类似gather的内存访问,并且未能充分利用对张量核心友好的块计算。我们的第一步是一个与Flash兼容的深度KV布局,它将深度缓存沿单个长度为$T \times L$的轴展平。因此,对于每个序列位置$t$,其$L$个深度状态被连续存储。这样,每个查询只需要映射到其对应的深度范围$[tL, (t+1)L)$即可访问正确的深度KV切片。这将深度查找转变为连续的块读取,并使深度阶段与FlashAttention风格的核兼容。尽管这种展平的公式比显式的PyTorch for循环快得多,但它在深度阶段仍然引入了计算效率问题。在深度得分矩阵$S_{depth} \in R^{T \times (TL)}$中,只有一个块对角区域是有效的。具体来说,对于查询行$i_q$,只需要深度列索引$j_d \in [i_q L, (i_q+1)L)$,而其余条目被掩码。我们将此比率定义为深度利用率,即,如果在整个$T \times (TL)$矩阵上密集计算,深度利用率为$\eta_{depth} = \frac{T \cdot L}{T \cdot (T \cdot L)} = \frac{1}{T}$。

块感知的深度KV布局。如图4所示,与Flash兼容的深度KV布局迫使每个查询块遍历一个长度为$T \times L$的长向量化拼接深度轴,这不利于深度利用率。因此,我们以块感知的方式重组深度KV,即将查询分为块,每个块只访问其覆盖范围对应的深度KV跨度。从块感知的角度来看,一个长度为$C$的查询块与一个大小为$C \times L$的局部深度KV区域配对,该区域通过拼接覆盖的$C$个序列位置的$L$个深度状态构建。因此,核在这个打包的$C \times L$区域上计算分块的深度注意力,而不是为每个块扫描全局的$T \times L$深度轴。这种局部布局大大减少了来自被掩码的、超出范围的深度条目的不必要的HBM流量,并将深度利用率提高到$\eta_{depth} = \frac{T \cdot L}{T \cdot (C \cdot L)} = \frac{1}{C}$。

组感知的深度KV计算。我们的关键观察是,在映射$T_q = G T_{kv}$下,G个相邻的查询行共享相同的基础时间索引$\lfloor i_q/G \rfloor$,因此可以重用相同的深度KV块。基于此,我们设计了一个组感知的深度KV计算,即对于一个长度为C的查询块,只有$C/G$个基础时间行是唯一的,所以所需的深度跨度是$(C/G) \times L$而不是$C \times L$。在融合的块矩阵乘法和掩码执行下,这将有效深度利用率提高到$\frac{G \times L}{C \times L} = \frac{G}{C}$。相同的基础时间映射在两个掩码中一致使用,即$\lfloor i_q/G \rfloor \ge i_k$用于序列因果关系,$\lfloor i_q/G \rfloor = \lfloor j_d/L \rfloor$用于深度匹配。值得注意的是,$i_k$是序列键索引,而$j_d$是展平的深度列索引。在实践中,我们还将查询块边界与G对齐,即使块大小可被G整除,以避免在一个tile内处理跨组边界并简化向量化执行。

3.3 硬件高效的MoDA实现

算法准备阶段。算法1遵循组感知映射$T_q = G T_{kv}$。输入是查询$Q \in R^{T_q \times (H_k d)}$,序列键/值$K, V \in R^{T_{kv} \times (H_k d)}$,以及深度键/值$K_{depth}, V_{depth} \in R^{(T_{kv} L) \times (H_k d)}$,输出为$O \in R^{T_q \times (H_k d)}$,其中$H_k d = D/G$。为清晰起见,$b_q, b_s, b_d$表示块索引,而$i_q, i_k, j_d$表示块内的元素索引。在进入主循环之前,所有张量都被分块成硬件友好的块,并且每个查询块都与G对齐。对于每个查询块$b_q$,我们从HBM加载$Q[b_q]$到SRAM,并初始化片上在线softmax状态$(m, acc, o)$,其中$m$是运行中的最大logit,acc是运行中的softmax归一化因子,o是运行中的未归一化输出累加器。对于$b_q$中的每个查询行索引$i_q$,我们计算其基础时间索引$t_{base}(i_q) = \lfloor i_q/G \rfloor$,并定义$t_{start}^{base} = \min_{i_q \in b_q} t_{base}(i_q)$和$t_{end}^{base} = \max_{i_q \in b_q} t_{base}(i_q) + 1$。半开区间$[t_{start}^{base}, t_{end}^{base})$随后被序列和深度循环重用,确保索引一致性。直观地说,如果G=4且一个查询块包含行$i_q = 8, \ldots, 15$,则$t_{base}(i_q) \in \{2, 3\}$,因此$t_{start}^{base} = 2$,$t_{end}^{base} = 4$。

序列注意力循环。序列阶段包含两个循环,它们都重用相同的累加器状态$(m, acc, o)$。对于完全可见的块($b_s < t_{start}^{base}$),我们从HBM加载$(K[b_s], V[b_s])$到SRAM,计算$S = Q[b_q] K^T[b_s] / \sqrt{d}$,并调用OnlineSoftmaxUpdate。在这个区域,所有的键都早于当前的查询基础时间,因此不需要因果掩码。对于边界块($t_{start}^{base} \le b_s < t_{end}^{base}$),使用相同的流水线,并应用分组因果掩码$\lfloor i_q/G \rfloor \ge i_k$。因此,来自多个序列块的logits被累加到一个在线softmax状态中,而没有中间的HBM物化。这等效于处理一个更长的拼接键序列,同时保持计算是块状的。

深度注意力循环。在序列累加之后,核进入深度循环,其展平的深度索引为$b_d \in [t_{start}^{base} L, t_{end}^{base} L)$。因子L将一个基础时间索引映射到其长度为L的连续深度跨度。对于每个深度块,从HBM加载$(K_{depth}[b_d], V_{depth}[b_d])$到SRAM,并计算深度logits $S_d = Q[b_q](K_{depth}[b_d])^T/\sqrt{d}$。然后应用一个掩码:

$$\begin{aligned} mask(i_q, j_d) = \mathbf{1} \left[ \left\lfloor \frac{i_q}{G} \right\rfloor = \left\lfloor \frac{j_d}{L} \right\rfloor \right] := \begin{cases} 1, & j_d \in \left[ L \left\lfloor \frac{i_q}{G} \right\rfloor, L \left( \left\lfloor \frac{i_q}{G} \right\rfloor + 1 \right) \right), \\ 0, & \text{otherwise.} \end{cases} \end{aligned}$$

该掩码只保留与查询行具有相同基础时间索引的深度条目。掩码后的logits随后被传递给OnlineSoftmaxUpdate,重用与序列阶段相同的$(m, acc, o)$状态。最后,我们通过$o \leftarrow o/acc$在片上进行一次归一化,将$O[b_q]$写回HBM,并在所有查询块处理完毕后返回O。

3.3.1 效率比较

表2 硬件高效MoDA与FlashAttention-2 Triton核在“前向和后向”设置下的效率比较。我们报告了在三种扩展设置下的运行时间(ms)、深度利用率(ηdepth)和相对额外时间。这里,B表示批量大小,d表示头维度,C表示块大小。所有实验均在A100 GPU上以bfloat16数据类型运行。

效率对比实验设置。表2报告了在受控设置下,硬件高效MoDA与FlashAttention-2 Triton的端到端“前向和后向”运行时间。我们在每个块中固定其余因素(B=1, d=64, C=64),对序列长度T、GQA组大小G和模型深度L进行扫描。除了原始运行时间(ms),我们还报告了深度利用率和MoDA的相对额外时间百分比。

序列长度扩展下的效率分析。当扩展序列长度时,即让T从4096增加到65536,在G=8, L=64的条件下,两个核都遵循预期的增长趋势,而MoDA的相对额外时间百分比从25.86%持续下降到2.73%。这表明随着序列计算变得占主导地位,额外的深度路径开销被越来越多地分摊。

GQA组大小和模型深度扩展下的效率分析。当在固定的T=16384下,将组大小G从2扩展到32时,深度利用率从3.12%上升到50.00%,额外时间百分比从27.07%下降到2.84%。相反,当在固定的T=16384和G=8下扩展模型深度时,FlashAttention-2的运行时间保持在116.700毫秒,而MoDA的运行时间从127.661毫秒增加到167.958毫秒。相应地,额外时间百分比从8.59%上升到30.52%,这与更深的深度流引入更多深度-KV处理的事实相符。

效率对比总结。总的来说,结果表明,所提出的实现具有可预测的线性扩展行为,并在长序列、高利用率的场景中保持高效。

A4 实验环境与结果

实验环境

  • 模型架构: 本文主要在7亿和15亿参数规模的解码器专用语言模型上进行实验。对于这两个规模的模型,均采用了分组查询注意力(GQA)【2】。
  • 数据集: 模型在OLMo2【27】数据集的4000亿词元子集上进行训练。
  • 硬件与软件配置:

    • 硬件: 实验在NVIDIA A100 GPU上运行(从效率分析部分可知)。
    • 软件: 所有模型均使用bfloat16(bf16)精度进行训练。全局批量大小设置为1024,上下文序列长度为4096。学习率调度、AdamW【25】优化器等训练配置遵循OLMo2【27】的实现。
  • 评估基准:

    • 下游任务: PiQA【5】, HellaSwag【48】, WinoGrande【32】, OpenBookQA【26】, BoolQA【9】, SciQA【3】, COPA【31】, MMLU【17】, ARC-easy (ARC-E) 和 ARC-challenge (ARC-C)【10】。
    • 困惑度(PPL): 报告了训练集PPL,C4【30】验证集PPL,以及在C4【30】, ICE【27】, m2d2-s2orc【24】, Pile【14】, Wiki-text【27】和dolma【34】验证集上的分领域验证PPL。

实验结果

MoDA变体分析

表3 不同混合深度注意力(MoDA)变体在训练集、C4验证集和下游基准上的性能。我们用4000亿词元训练了7亿参数的模型。对于MoDA设置:“Sequence KV”表示每个词元只关注序列键/值,可视为普通的注意力机制。“Depth KV”表示每个词元关注其深度键/值。“Extra FFN KV Proj.”表示进一步将FFN的输入X投影到深度键/值,然后用于后续的注意力操作。“Extra Attn KV Proj.”表示设置独立的深度键/值投影,而不是重用序列注意力的原始键/值投影。模型宽度D、GQA组大小G、序列长度T分别设为1024、2和4096。我们还报告了模型的参数量和FLOPs。

实验设置: 在7亿参数模型上比较不同MoDA变体的性能。所有模型使用相同的学习率调度策略。为了公平比较,引入了两个基线:标准的vanilla attention (OLMo2)和一个参数量相当的、增加了两层的OLMo2模型。

结论:

  • (i) 仅增加深度KV即可显著提升性能: 与基线(表3,第1行)相比,仅添加深度KV(直接复用前一层的序列KV,不增加参数)的模型(第3行),在只增加0.12% FLOPs的情况下,训练PPL降低0.41,C4验证PPL降低0.11,下游任务平均分提高1.17。
  • (ii) FFN层的深度KV很重要: 在仅有注意力层深度KV的基础上,为FFN层增加一个轻量级的KV投影(第4行),性能进一步提升。与第3行相比,训练PPL降低0.18,C4验证PPL降低0.27,下游任务平均分提高0.77。与参数量相当的基线(第2行)相比,该模型性能更优,证明了FFN的深度信息对MoDA的贡献。
  • (iii) 额外的注意力KV投影效果饱和: 在第4行的基础上,再为注意力层增加一个专门的深度KV投影(第5行),性能提升微乎其微,但参数和FLOPs开销显著增加,表明这种做法已接近饱和。
  • 设计原则: 实验揭示了MoDA的设计原则:注入深度信息是有效的,但效果对投影引入的位置高度敏感。复用注意力侧的深度KV成本极低且效果好,而增加FFN侧的深度KV则达到了最佳的精度-效率权衡。因此,后续实验采用第4行的配置作为默认的MoDA变体。

MoDA随模型尺寸扩展的性能

表4 提出的MoDA模型在不同模型尺寸下在下游基准测试上的性能。我们在OLMo2数据集的4000亿词元上训练了7亿和15亿参数的模型。宽度D、GQA组大小G、序列长度T分别设置为1024、2和4096。最佳性能用粗体标出。

表5 提出的MoDA模型在不同模型尺寸下的各领域验证困惑度。我们在OLMo2数据集的4000亿词元上训练了7亿和15亿参数的模型。宽度D、GQA组大小G、序列长度T分别设置为1024、2和4096。较低的困惑度表示更好的性能,并用粗体标出。

实验设置: 在相同的4000亿词元训练预算下,将模型规模从7亿扩展到15亿,比较MoDA与基线OLMo2的性能。

结论:
* (i) MoDA在不同模型规模下均带来稳定的下游任务平均性能提升: 如表4所示,在7亿模型上,平均分提升1.76(57.11 -> 58.87);在15亿模型上,平均分提升2.11(62.28 -> 64.39)。
* (ii) 下游任务性能提升广泛覆盖各类任务: 在常识、推理和知识型任务上均观察到显著提升,例如在15亿模型上,ARC-C提升4.35,BoolQ提升3.73,MMLU提升1.86。
* (iii) 验证困惑度在所有领域均有一致的降低: 如表5所示,在7亿和15亿规模下,MoDA在所有10个验证领域都降低了PPL,平均PPL分别从15.61降至15.46和从13.67降至13.47。

MoDA与层数分析

表6 MoDA在更深(48层)和更浅(24层)模型设置下的层数分析。我们比较了普通注意力(OLMo2)和不同MoDA变体,在pre-norm和post-norm配置下的表现。模型使用相同的数据配方训练,我们报告了参数数量、FLOPs和FineWeb-Edu验证损失。在两种深度设置下,引入深度KV都能持续改善验证损失,而增加额外的FFN KV投影则以适度的计算开销带来进一步的增益。

实验设置: 在较小模型上(宽度384)进行层数实验,比较了浅层(24层)和深层(48层)模型,并同时考虑了pre-norm和post-norm配置。

结论:
* (i) 深度KV在不同层数下均能持续改善验证损失: 无论是48层还是24层模型,添加深度KV都降低了验证损失。
* (ii) 在深层模型中,post-norm从深度KV中获益更多: 在48层模型中,post-norm配置下的损失降低幅度(0.0409)远大于pre-norm(0.0041),表明深度KV在post-norm配置下对深层模型的优化影响更强。
* (iii) 额外的FFN KV投影能带来额外增益: 在已添加深度KV的基础上,再增加FFN的KV投影能进一步降低损失。

MoDA注意力可视化分析


图5 混合深度注意力(MoDA)热力图,采用组合式softmax公式。列对应均匀采样的层{0, 11, 23, 35},行对应每层中随机选择的头。第一列显示仅对序列KV的注意力,而其他列显示拼接的序列KV | 深度KV;红色虚线标记了两个KV块之间的边界。在各层和各头之间,大量的注意力权重被持续分配给深度KV块,表明MoDA在标准序列注意力之外有效地利用了深度信息。

实验设置: 可视化7亿参数模型在4000亿词元训练后的注意力热图。

结论:
* 模型主动检索跨层深度信息: 在中后期层,有相当大且持续的注意力权重被分配给了深度KV块,这表明模型不仅仅依赖于序列局部上下文。
* 注意力模式呈现互补性: 具有尖锐对角线序列注意力的头仍然会将部分概率分配给深度槽,而注意力分布较广的头则更倾向于依赖深度KV。
* 可能缓解“注意力沉洞(attention sink)”现象: MoDA的注意力模式与典型的注意力沉洞行为不同。它没有将大量概率集中在少数固定的“沉洞”位置,而是更广泛地分布在序列和深度槽位上,这表明MoD可能改变了长上下文设置下注意力权重的分配方式。

MoDA效率分析

表7 在固定配置下对提出的核实现策略进行的消融实验。每一行递增地启用一个优化组件:(1)朴素的PyTorch基线,(2)Flash兼容的深度KV布局,(3)Flash兼容且块感知的深度KV布局,以及(4)Flash兼容、块感知且组感知的索引。我们报告了端到端的“前向和后向”运行时间(毫秒),越低越好,最佳性能用粗体标出。实验在单个A100 GPU上以bfloat16进行,设置固定为B=1, T=1024, G=8, Hq=64, Hk=8, d=64, L=64, C=64。

实验设置: 在固定配置下,对硬件感知核的三个主要优化(Flash兼容布局、块感知布局、组感知索引)进行增量消融实验。

结论:
* (i) Flash兼容布局带来数量级加速: 与朴素PyTorch实现相比,仅使用Flash兼容布局就将运行时间从2128.9ms降至13.102ms,速度提升约162.5倍。
* (ii) 块感知布局进一步提升效率: 在Flash兼容的基础上,增加块感知布局将运行时间从13.102ms降至6.286ms,减少了52.0%。
* (iii) 组感知索引是充分利用GQA机制的关键: 最后增加组感知索引将运行时间从6.286ms降至1.460ms,带来了额外的4.31倍加速。
* 总体效果: 三项优化结合,相比朴素PyTorch基线实现了约1458倍的端到端加速。

A5 结论

本文提出了MoDA,一种统一的、深度感知的注意力机制,旨在改善LLM的深度信息聚合能力,并缓解由优化困难和信息稀释带来的深度效率差距。为了使其在长上下文场景下高效运行,我们进一步开发了一个硬件感知的融合核,该核利用统一的在线softmax状态、块感知的深度KV布局和组感知的索引。在7亿和15亿参数模型上的实验表明,MoDA在适度的开销下,在困惑度和下游任务性能上均取得了一致的提升。这些结果表明,显式地从历史深度信息中进行检索是扩展Transformer深度的一种实用且有效的原语。我们将发布MoDA的完整实现,并希望它能成为开源社区构建更强大LLM的基础。除了语言建模,MoDA是架构无关的,可以轻松集成到多模态智能、视觉理解和世界模型等领域,因为Transformer在这些领域正被越来越多地采用。我们相信,有原则的深度感知信息聚合将在这些不同领域带来广泛而持久的益处。

A7 补充细节

6.1 通过高级CUDA工程扩展MoDA以适应工业级训练

通过高级CUDA工程扩展MoDA以适应工业级训练。尽管当前的硬件感知MoDA核已经达到了与FlashAttention 2相当的竞争力,但对于工业级规模的训练(例如万亿参数模型),这还不是终点。在大型生产运行中,额外的CUDA工程仍然至关重要,包括改进的内存调度、更深层次的计算流水线化,以及融合注意力核与分布式通信之间更紧密的重叠。这些优化不会改变MoDA的算法行为,但可以进一步减少内存停顿和核启动开销,提高端到端吞吐量,并增加集群级别的训练效率。因此,我们将未来的CUDA优化视为将MoDA从一个高效的研究算子转变为工业级LLM训练的稳健原语的一个重要方向。

6.2 通过有界深度KV槽缓存缓解内存瓶颈

通过有界深度KV槽缓存缓解内存瓶颈。当扩展到非常深的网络时,缓存所有历史层的所有深度KV状态会引入巨大的内存和带宽开销。该成本随深度线性增长,并可能成为长上下文训练和服务中的主要瓶颈。因此,完整的深度KV缓存在工业规模上越来越难以维持。

有界深度KV槽缓存的策略。一个实际可行的方向是使用一个固定大小的深度KV槽缓冲区。每个查询只关注一个有界的槽集合,而不是存储所有的深度KV条目。槽预算固定为S,其中$S \ll L$,系统动态决定保留哪些深度KV条目。有两种自然的策略。一种是动态选择,即根据效用对候选的深度KV条目进行评分,并保留前S个条目。另一种是滑动窗口策略,即保留最近的深度KV条目并驱逐较旧的条目。也可以使用混合设计,其中一部分槽保留给最近的条目,其余的用于高分的全局记忆。

有界缓存设计的优势与挑战。这种设计将有效的深度内存从无界缓存变为有界缓存。内存和带宽项从依赖深度的扩展转变为依赖槽的扩展。它还为融合核的实现提供了稳定的张量形状。在实践中,关键的挑战是槽分配的质量。未来的工作应该研究如何将选择策略与MoDA联合训练,以及如何在固定的槽预算下平衡质量、延迟和硬件效率。