Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models

标题: Lightning Attention-2:处理大语言模型中无限序列长度的免费午餐
作者: Zhen Qin, Weigao Sun, Dong Li, Xuyang Shen, Weixuan Sun, Yiran Zhong

A1 主要贡献

核心问题
Transformer架构的计算复杂度随着输入序列的长度呈二次方增长,这使得处理极长序列变得具有挑战性。尽管线性注意力理论上通过核技巧可以将计算复杂度从O(n²)降低到O(n),但在实际应用中,尤其是在因果(causal)设定下,由于需要进行累积求和(cumsum)操作,其在硬件上的实际计算效率远未达到理论上的线性加速效果。现有方法,如Lightning Attention-1,虽然通过IO感知的分块技术提升了效率,但其根本复杂度仍为O(n²d),未能利用线性注意力的内在计算特性。

研究目标
本文旨在解决线性注意力在因果设置中因累积求和操作而无法发挥其理论计算优势的核心问题。研究目标是提出一种新的线性注意力实现,使其能够在处理无限长序列时,保持恒定的训练和推理速度,且不牺牲性能,从而真正实现线性注意力的理论潜力。

创新点
本文的主要贡献是提出了 Lightning Attention-2,这是首个能够让线性注意力在因果设定下实现其理论计算优势的实现。其核心创新点如下:
1. 分治策略:采用“分而治之”的思想,将注意力计算分解为块内(intra-block)块间(inter-block)两个独立的部分。
2. 混合计算模式
* 对于块内部分,使用传统的注意力计算机制(即Q、K、V的左乘积),这在块尺寸较小时是高效的。
* 对于块间部分,则利用线性注意力的核技巧(即先计算K和V的右乘积),这能有效利用其线性的计算和内存效率。
3. IO感知的硬件优化:通过分块(tiling)技术,将数据从慢速的GPU高带宽内存(HBM)加载到快速的片上SRAM中进行计算,最小化了内存读写开销。整个算法在Triton中实现,使其对硬件友好且IO感知。
4. 恒定的计算速度:通过上述设计,Lightning Attention-2 实现了在固定内存消耗下,其计算速度与输入序列长度无关。如下图1所示,随着序列长度的增加,使用FlashAttention-2和Lightning Attention-1的模型训练速度(TGS,每秒每GPU处理的token数)显著下降,而Lightning Attention-2则保持稳定。

图1. FlashAttention与Lightning Attention在不同序列长度和模型大小下的速度对决。上图比较了使用FlashAttention-2的LLaMA、使用Lightning Attention-1的TransNormerLLM以及使用Lightning Attention-2的TransNormerLLM在400M、1B和3B三种模型大小下的训练速度(TGS)。显而易见,Lightning Attention-2无论序列长度如何增加,都表现出一致的训练速度。相反,其他方法的训练速度随着序列长度的增加而显著下降。
图1. FlashAttention与Lightning Attention在不同序列长度和模型大小下的速度对决。上图比较了使用FlashAttention-2的LLaMA、使用Lightning Attention-1的TransNormerLLM以及使用Lightning Attention-2的TransNormerLLM在400M、1B和3B三种模型大小下的训练速度(TGS)。显而易见,Lightning Attention-2无论序列长度如何增加,都表现出一致的训练速度。相反,其他方法的训练速度随着序列长度的增加而显著下降。

A3 相关工作

2.1. 线性注意力

基本思想与挑战。线性Transformer架构用不同的近似方法取代了Softmax注意力机制【18, Katharopoulos et al. Transformers are rnns: Fast autoregressive transformers with linear attention. 2020a; 7, Choromanski et al. Rethinking attention with performers. 2020; 30, Peng et al. Random feature attention. 2021; 33, Qin et al. cosformer: Rethinking softmax in attention. 2022b; 32, Qin et al. The devil in linear transformer. 2022a】。其核心思想是利用“核技巧”来加速注意力矩阵的计算,即先计算键(keys)和值(values)的乘积,从而避免进行n × n的矩阵乘法。已有多种方法被提出来替代softmax操作,例如,Katharopoulos等人【18, Katharopoulos et al. 2020a】使用1 + elu激活函数,Qin等人【33, Qin et al. 2022b】利用余弦函数来近似softmax的性质,而Ke等人【20, Ke et al. Rethinking positional encoding in language pre-training. 2021】和Zheng等人【63, Zheng et al. Linear complexity randomized self-attention mechanism. 2022; 64, Zheng et al. Efficient attention via control variates. 2023】则利用采样策略直接模仿softmax操作。尽管线性注意力的理论复杂度为O(nd²),但在因果注意力场景中,由于需要进行累积求和(cumsum)操作,其实际计算效率会显著下降【16, Hua et al. Transformer quality in linear time. 2022】。

2.2. IO感知的注意力

系统级优化。FlashAttention系列【11, Dao et al. FlashAttention: Fast and memory-efficient exact attention with IO-awareness. 2022; 10, Dao. Flashattention-2: Faster attention with better parallelism and work partitioning. 2023】专注于对标准注意力算子在GPU平台上的高效实现进行系统级优化,其有效性已得到广泛验证。该方法采用分块(tiling)策略,以最小化GPU高带宽内存(HBM)和片上SRAM之间的内存读写量。

现有线性注意力优化的局限性。为了解决线性注意力在因果设置下计算缓慢的问题,Lightning Attention 1【35, Qin et al. Scaling transnormer to 175 billion parameters. 2023b】采用了类似FlashAttention-1/2的方法,即将输入Q、K、V分割成块,从慢速的HBM传输到快速的SRAM,然后计算这些块的注意力输出,并最终累加结果。尽管这种方法比PyTorch的实现效率高得多,但它没有利用线性注意力固有的计算特性,其理论复杂度仍然是O(n²d)。

2.3. LLM中的长序列处理

相对位置编码(RPE)。一种广泛采用的应对长度外推挑战的策略是集成相对位置编码(RPE)技术【48, Su et al. Roformer: Enhanced transformer with rotary position embedding. 2021; 36, Qin et al. Linearized relative positional encoding. 2023c】,这些技术策略性地将注意力引向邻近的token。ALiBi【31, Press et al. Train short, test long: Attention with linear biases enables input length extrapolation. 2022】在注意力机制中使用线性衰减偏置,以减轻远处token的影响。Roformer【48, Su et al. 2021】引入了一种新颖的旋转位置嵌入(RoPE)方法,该方法在社区中被广泛采用,有效利用位置信息进行基于Transformer的语言模型学习。Kerple【5, Chi et al. Kerple: Kernelized relative positional embedding for length extrapolation. 2022】在RPE中探索了移位不变的条件正定核,并引入了一系列旨在增强长度外推属性的核,其中ALiBi被认为是其一个实例。此外,Sandwich【6, Chi et al. Dissecting transformer length extrapolation via the lens of receptive field analysis. 2023】提出了一个解释ALiBi机制的假设,并通过将该假设融入正弦位置嵌入中进行了经验验证。Qin等人【38, Qin et al. Exploring transformer extrapolation. 2024】探索了加性相对位置编码具备外推能力的充分条件。

直接扩展上下文窗口。一些工作也尝试直接增加上下文窗口大小,而不是研究Transformer的长度外推能力。Chen等人【4, Chen et al. Extending context window of large language models via positional interpolation. 2023】引入了位置插值(PI),将基于RoPE的预训练大语言模型(如LLaMA模型)的上下文窗口大小扩展至32768,且只需极少的微调(1000步以内)。StreamingLLM【58, Xiao et al. Efficient streaming language models with attention sinks. 2023】提出利用注意力池(attention sink)现象,通过保留初始token的键和值信息来显著恢复窗口注意力的性能。随着序列变长,性能会下降。这些方法只能在微调或测试阶段扩展序列长度,而我们的方法允许从头开始以无额外成本的方式训练长序列模型。

A2 方法细节

3.1. 预备知识

线性注意力公式回顾。我们首先回顾线性注意力的公式,然后介绍我们提出的Lightning Attention-2。在TransNormer【32, Qin et al. 2022a】的NormAttention中,注意力计算不同于传统的Transformer结构【56, Vaswani et al. Attention is all you need. 2017】,它避免了昂贵的softmax和缩放操作。NormAttention机制可以表示如下:

其中Q, K, V ∈ R^(n×d) 分别是查询、键和值矩阵,n表示序列长度,d表示特征维度。为了利用右矩阵乘法固有的计算效率,根据矩阵乘法的性质,上述方程可以无缝且数学上等价地转换为其线性变体:

线性注意力的优势与挑战。这种线性形式使得训练过程具有O(nd²)的复杂度,相对于序列长度而言是高效的。此外,采用线性注意力确保了在推理过程中计算复杂度恒为O(d²),与序列长度无关,从而能够对无限长的序列进行推理。这是通过递归地更新KᵀV而无需重复计算整个注意力矩阵实现的。相比之下,标准softmax注意力的推理过程计算复杂度为O(md²),其中m表示token的索引。然而,在处理因果预测任务时,右乘积的有效性会受到影响,导致需要计算累积和(cumsum)【16, Hua et al. 2022】。这个障碍阻碍了高效并行计算的潜力。因此,我们在Lightning Attention-1中仍然坚持使用传统的左矩阵乘法。这也是我们引入Lightning Attention-2的动机,它专门为解决右乘积在此类情境下面临的挑战而设计。

3.2. Lightning Attention-2

分块与内存利用策略。Lightning Attention-2在其整个计算过程中采用了分块(tiling)方法。鉴于GPU内高带宽内存(HBM)和静态随机存取存储器(SRAM)之间内存带宽的巨大差异,Lightning Attention-2对它们的利用采取了独特的策略。在每次迭代i中,矩阵$Q_i, K_i, V_i$被分割成块,随后传输到SRAM进行计算。块内(intra-block)和块间(inter-block)的操作是分开的,块内操作采用左乘积,而块间操作则利用右乘积。这种方法优化地利用了与右乘积相关的计算和内存效率,从而提高了整体执行速度。中间的激活值KV被迭代地保存和累积在SRAM中。随后,块内和块间的输出在SRAM内相加,结果被写回HBM。此方法旨在利用每种内存组件的独特优势,优化计算工作流。Lightning Attention-2的结构框架在图2中有详细说明。

图2. Lightning Attention-2的结构框架在其算法示意图中得以详细展示。在第i次迭代期间,矩阵Qi, Ki, Vi的分块从高带宽内存(HBM)传输到静态随机存取存储器(SRAM)。在SRAM内部,块内输出Ointra和块间输出Ointer被独立计算,然后更新KV矩阵。随后,最终输出Oi(即Ointra和Ointer之和)从SRAM写回HBM。
图2. Lightning Attention-2的结构框架在其算法示意图中得以详细展示。在第i次迭代期间,矩阵Qi, Ki, Vi的分块从高带宽内存(HBM)传输到静态随机存取存储器(SRAM)。在SRAM内部,块内输出Ointra和块间输出Ointer被独立计算,然后更新KV矩阵。随后,最终输出Oi(即Ointra和Ointer之和)从SRAM写回HBM。

算法与推导。Lightning Attention-2实现的复杂细节通过算法1(前向传播)和算法2(反向传播)进行了阐述。这些算法旨在封装Lightning Attention-2中核心的精细计算过程。此外,我们提供了一个全面的推导,以帮助更深入地理解Lightning Attention-2。这些推导针对前向传播和反向传播分别系统地呈现,有助于对底层机制的透彻理解。

3.2.1. 前向传播

公式推导。为简化推导,我们忽略公式(2)中的Norm(·)算子。在Lightning Attention-2的前向传播过程中,第t个输出可以表示为:

递归形式。以上方程可以递归地重写为:
我们首先定义:

分块形式。给定$KV_t$,第(t + 1)个块的输出,即$tB + r$(其中$1 \leq r \leq B$),为:

以矩阵形式重写,我们得到:


其中:

分块实现。为了进行分块操作,我们将方程写成块的形式。给定总序列长度n和块大小B,X被划分为$T = n/B$个块$\{X_1, X_2, \dots, X_T\}$,每个块的大小为B × d,其中X ∈ {Q, K, V, O}。
其中:


第(t + 1)个块的KV可以写为:

Lightning Attention-2前向传播的完整表达式见算法1。

算法 1 Lightning Attention-2 前向传播

输入: Q, K, V ∈ R^(n×d), 衰减率 λ ∈ R+, 块大小 B.
将 X 分割成 T = n/B 个块 X1, X2, ..., XT,每个块大小为 B × d,其中 X ∈ {Q, K, V, O}。
初始化掩码 M ∈ R^(B×B),其中 Mij = λ^(i−j),如果 i ≥ j,否则为 0。
初始化 Λ = diag{λ, λ², ..., λ^B} ∈ R^(B×B)。
初始化 KV = 0 ∈ R^(d×d)。
for 1 ≤ i ≤ T do
    从HBM加载 Qi, Ki, Vi ∈ R^(B×d) 到片上SRAM。
    在片上计算 O_intra = [(Qi * K_i^T) ⊙ M] * Vi。
    在片上计算 O_inter = Λ * Qi * (KV)。
    在片上计算 KV = λ^B * KV + (λ^B * Λ^(-1) * Ki)^T * Vi。
    将 Oi = O_intra + O_inter 作为 O 的第i个块写回HBM。
end for
返回 O。
3.2.2. 反向传播

梯度推导。对于反向传播,我们考虑逆过程。首先给定$do_t$,我们有:

递归形式。通过将$dkv_t$写成递归形式,我们得到:

分块形式。为了便于理解分块,我们以块的形式来考虑上述方程。给定总序列长度n和块大小B,X被划分为$T = n/B$个块$\{X_1, X_2, \dots, X_T\}$,每个块大小为B × d,其中X ∈ {Q, K, V, O, dO}。
我们首先定义:


那么对于第(t + 1)个块,即$tB + r, 0 \leq r < B$,我们有:

以矩阵形式表示,我们有:

梯度计算。由于$dK_t$的递归从t + 1步进到t,给定$KV_{t+1}$,第t个块的$dK_t$,即在位置$(t-1)B + r, 0 < r \leq B$处,为:


以矩阵形式,我们得到:

dV梯度。考虑第t个块的$dV_t$,即在位置$(t-1)B + r, 0 < r \leq B$处,我们有:


以矩阵形式,我们得到:

dKV递归关系。最后,$dKV_t$的递归关系为:


算法2更详细地描述了Lightning Attention-2的反向传播过程。

算法 2 Lightning Attention-2 反向传播

输入: Q, K, V, dO ∈ R^(n×d), 衰减率 λ ∈ R+, 块大小 B。
将 X 分割成 T = n/B 个块 X1, X2, ..., XT,每个块大小为 B × d,其中 X ∈ {Q, K, V}。
将 dX 分割成 T = n/B 个块 dX1, dX2, ..., dXT,每个块大小为 B × d,其中 X ∈ {Q, K, V, O}。
初始化掩码 M ∈ R^(B×B),其中 Mij = λ^(i−j),如果 i ≥ j,否则为 0。
初始化 Λ = diag{λ, λ², ..., λ^B} ∈ R^(B×B)。
初始化 KV = 0, dKV = 0 ∈ R^(d×d)。
for i = 1, ..., T do
    从HBM加载 Ki, Vi, Oi, dOi ∈ R^(B×d) 到片上SRAM。
    在片上计算 dQ_intra = [(dOi * V_i^T) ⊙ M] * Ki。
    在片上计算 dQ_inter = Λ * dOi * (KV)^T。
    在片上计算 KV = λ^B * KV + (λ^B * Λ^(-1) * Ki)^T * Vi。
    将 dQi = dQ_intra + dQ_inter 作为 dQ 的第i个块写回HBM。
end for
for i = T, ..., 1 do
    从HBM加载 Qi, Ki, Vi, Oi, dOi ∈ R^(B×d) 到片上SRAM。
    在片上计算 dK_intra = [(dOi * V_i^T) ⊙ M]^T * Qi。
    在片上计算 dK_inter = (λ^B * Λ^(-1) * Vi) * (dKV)^T。
    在片上计算 dV_intra = [(Qi * K_i^T) ⊙ M]^T * dOi。
    在片上计算 dV_inter = (λ^B * Λ^(-1) * Ki) * dKV。
    在片上计算 dKV = λ^B * dKV + (Λ * Qi)^T * dOi。
    将 dKi = dK_intra + dK_inter, dVi = dV_intra + dV_inter 分别作为 dK, dV 的第i个块写回HBM。
end for
返回 dQ, dK, dV。

讨论。最近的一种方法,GLA【59, Yang et al. Gated linear attention transformers with hardware-efficient training. 2023】使用带有数据依赖衰减的线性注意力来建模序列。其分块并行算法(chunk-wise Block-Parallel Algorithm)采用了分块和IO感知的概念。然而,与Lightning Attention-2不同,它对每个块使用并行计算,这导致了更高的内存使用量。RetNet【53, Sun et al. Retentive network: A successor to transformer for large language models. 2023b】在结构上与TransNormerLLM【35, Qin et al. 2023b】非常相似,并使用了分块保持算法(chunk-wise retention algorithm)。该算法与Lightning Attention-2的前向传播相当,但没有考虑IO感知或反向传播。

A4 实验环境与结果

实验环境

实验结果

1. 注意力模块评估
* 实验内容: 在单张A100 80G GPU上,比较了Lightning Attention-1、Lightning Attention-2和FlashAttention-2在不同序列长度下的前向和反向传播速度及内存使用情况。
* 实验结果 (图3):
* 速度: FlashAttention-2的运行时间随序列长度呈二次方增长。相比之下,Lightning Attention-2表现出线性增长,且随着序列长度增加,其速度优势愈发显著。
* 内存: 随着序列长度增加,Lightning Attention-2在内存使用方面也保持了显著优势。

图3. 速度和内存使用的比较分析:FlashAttention vs. Lightning Attention。上部分:不同序列长度下前向和反向传播的运行时间(毫秒)。下部分:不同序列长度下前向和反向传播的内存利用率。
图3. 速度和内存使用的比较分析:FlashAttention vs. Lightning Attention。上部分:不同序列长度下前向和反向传播的运行时间(毫秒)。下部分:不同序列长度下前向和反向传播的内存利用率。

2. 在大语言模型中的评估
* 性能评估 (表2, 图4):
* 在2K上下文的TransNormerLLM-0.4B模型上,使用Lightning Attention-2的版本与使用Lightning Attention-1的版本相比,性能仅有0.001的微小下降(见表2),表明其准确性得以保持。
* 在1B和3B参数规模上,与其他高效LLM(LLaMA-FA2, HGRN, TNN)相比,使用Lightning Attention-2的TransNormerLLM(TNL-LA2)在训练后获得了略低的损失,显示出有竞争力的性能(见图4)。

<center>表2. TransNormerLLM使用Lightning Attention-1和Lightning Attention-2的语言建模性能比较。</center>

图4. HGRN, TNN, 使用FlashAttention2的LLaMA以及使用Lightning Attention-2的TransNormerLLM的性能比较。对于1B模型,我们使用16×A800 80G GPU,每GPU批次大小为12;对于3B模型,我们扩展到32×A800 80G GPU,每GPU批次大小为30。训练上下文长度设置为2K。
图4. HGRN, TNN, 使用FlashAttention2的LLaMA以及使用Lightning Attention-2的TransNormerLLM的性能比较。对于1B模型,我们使用16×A800 80G GPU,每GPU批次大小为12;对于3B模型,我们扩展到32×A800 80G GPU,每GPU批次大小为30。训练上下文长度设置为2K。

<center>表1. LLaMA (FlashAttention2), TransNormerLLM (Lightning Attention-1) 和 TransNormerLLM (Lightning Attention-2) 的效率比较。统计分析在2×A100 80G GPU上进行。该表报告了三种不同模型大小下,上下文范围从1K到92K的每GPU每秒token数(TGS)。OOM表示GPU内存不足。</center>

3. 15B模型基准测试
* 实验内容: 评估了包含150亿参数的TransNormerLLM-15B模型在常识推理和聚合基准上的性能,并与Pythia-12B进行比较。
* 实验结果 (表3):
* 常识推理: 在所有常识推理任务中,TransNormerLLM-15B的性能比Pythia-12B高出约2%。
* 聚合基准: 在C-Eval任务中,TNL-15B比Pythia-12B高出约2%。在MMLU和C-Eval的0-shot和5-shot测试中,其性能均超过了随机猜测的基线(25%)。

<center>表3. 在常识推理和聚合基准上的性能比较。TNL-LA2: 使用Lightning Attention-2的TransNormerLLM。PS: 参数规模(十亿)。T: tokens(十亿)。HS: HellaSwag。WG: WinoGrande。</center>

A5 结论

本文介绍了Lightning Attention-2,一种开创性的线性注意力实现,它在因果设置下成功地利用了其理论上的计算优势。我们的方法采用了“分而治之”和分块技术,通过将计算分离为块内和块间组件,有效解决了当前线性注意力算法的局限性,特别是与累积求和相关的挑战。通过这种方式,我们充分利用了GPU硬件的潜力,确保了效率。我们跨越不同模型大小和序列长度的广泛实验表明,Lightning Attention-2不仅能保持恒定的训练速度,不受输入序列长度的影响,而且在速度和准确性方面也优于现有的SOTA注意力机制。这一突破对大语言模型的未来具有深远影响,特别是那些需要处理长序列的模型。展望未来,我们计划将序列并行性与Lightning Attention-2相结合,旨在促进超长序列的训练,从而有效克服现有的硬件限制。