Jenga: Enhancing LLM Long-Context Fine-tuning with Contextual Token Sparsity

A1 主要贡献

随着对综合文档分析【13, 26】、扩展多轮对话【76】和复杂代码库处理【51, 54】需求的增长,具有更大上下文窗口的大语言模型(LLM)正成为人工智能应用不可或缺的一部分。然而,这些模型通常在固定的上下文窗口(如Llama2的4K token限制【67】)下进行预训练,当输入超过此限制时性能会显著下降。

近期研究表明,通过在更长序列上进行微调,可以扩展预训练LLM的上下文窗口,但这带来了巨大的资源挑战,特别是内存消耗。其中,激活值(Activations)是主要的内存瓶颈,因为它与序列长度成正比,而非模型参数。现有的高效微调技术未能有效解决这一问题:
1. 参数高效微调(PEFT)方法:如LoRA【29】,虽然减少了参数更新的开销,但并未优化激活内存。
2. 稀疏机制:如LongLoRA【12】,虽然提高了计算效率,但由于“影子激活(Shadowy Activation)”现象,并未减少激活内存。影子激活指的是,一旦一个token参与计算,无论其使用程度如何,其激活值都会被保留在内存中,这导致现有在隐藏维度上进行稀疏化的方法无法减少内存占用。

微调方法 模型参数 (GB) 优化器状态 (GB) 激活值 (GB) 总计 (GB)
Vanilla 13.0 52.0 5.3 70.3
LoRA 13.0 0.3 5.4 18.7
LongLoRA 13.0 0.3 5.5 18.8

Table 1: 不同微调方法的内存占用(GB)比较。LoRA和LongLoRA分别是PEFT和基于稀疏性方法的代表(S = 4K)。

图1:影子激活的图示。(a) LoRA执行全注意力计算,未引入稀疏性。(b) LongLoRA采用两种移位的局部注意力模式(红色和绿色)来近似全注意力,目标是隐藏维度级别的稀疏性。(c) JENGA采用token级别的稀疏性,最小化token的参与度,从而实现比其他方法更优的激活内存节省。
图1:影子激活的图示。(a) LoRA执行全注意力计算,未引入稀疏性。(b) LongLoRA采用两种移位的局部注意力模式(红色和绿色)来近似全注意力,目标是隐藏维度级别的稀疏性。(c) JENGA采用token级别的稀疏性,最小化token的参与度,从而实现比其他方法更优的激活内存节省。

核心问题与研究目标
本文旨在解决长上下文微调中激活内存成为瓶颈的问题,提出一个能够同时优化内存和计算效率的新型LLM微调方案。

核心洞察与创新点
本文的核心洞察是利用自然语言中存在的固有冗余性,特别是在长上下文场景中。文章提出了一种新的token级稀疏机制,称为上下文Token稀疏性(Contextual Token Sparsity)。该机制的关键在于识别并只保留信息量最丰富的token,从而以最少的token参与度进行长上下文微调。这种稀疏性具有两个特点:1)它在不同的输入和层之间动态变化;2)它允许直接减少参与计算的token数量,从而根本上解决“影子激活”问题。

为实现这一目标,本文提出了JENGA,一个高效的LLM长上下文微调系统,包含三项关键技术:
1. 信息驱动的Token消除(Information-driven Token Elimination):通过评估token嵌入的信息量,动态识别并排除不同输入和层中的冗余token。
2. 上下文感知的模式预测(Context-aware Pattern Prediction):利用训练好的预测器以极小的开销近似token的稀疏模式。
3. 高性能核优化(High-performance Kernel Optimization):采用无置换(permutation-free)和分段式(segment-based)策略来提升系统性能。

JENGA作为一个端到端的微调系统,与各种LLM架构和其他优化技术兼容。实验表明,JENGA能够将内存消耗降低高达1.93倍,并实现高达1.36倍的加速,性能优于当前最先进的微调系统。

A3 背景知识与关键洞察

2.1 高效LLM微调

段要点:现有高效微调方法的局限性。微调是使预训练LLM适应下游应用的关键步骤,尤其是在扩展上下文窗口以支持长上下文场景时。然而,处理长文本序列会极大地增加计算和内存需求。参数高效微调(PEFT)方法,如低秩适应(LoRA)【29】,通过只更新模型参数的一个子集来减少优化器状态的内存需求,但并未解决激活内存的瓶颈。另一研究方向是利用注意力机制中的内在稀疏性【17】,通过各种稀疏模式来近似标准的密集注意力。例如,LongLoRA【12】将LoRA与一种新的移位稀疏模式相结合,用两组移位的稀疏局部注意力替代全局密集注意力,从而减少了训练时间。然而,这些方法都未能充分解决长上下文微调中的内存瓶颈,因为激活值仍然与序列长度成比例缩放。

2.2 分析:微调内存分解

段要点:Vanilla微调的内存构成。Vanilla微调的内存消耗主要包括模型状态剩余状态。模型状态包括参数、梯度和优化器状态,在混合精度训练【47】中,对于一个参数量为ψ的模型,这部分内存大约为16ψ,是固定的。剩余状态中,激活值(在前向传播过程中存储的中间结果和梯度)占了相当大的部分,其内存需求不仅依赖于模型参数大小ψ,还与输入批次大小b和序列长度s成比例。在长上下文场景中,激活值的内存消耗很容易超过模型状态,成为主要的内存瓶颈。

图2:LLM混合精度微调期间的内存分解。与固定的模型状态相比,激活值随输入批次大小和序列长度扩展,成为长上下文场景中的主要瓶颈。
图2:LLM混合精度微调期间的内存分解。与固定的模型状态相比,激活值随输入批次大小和序列长度扩展,成为长上下文场景中的主要瓶颈。

序列长度 (s) 4K 8K 16K 32K
激活内存/模型状态 1.0× 2.1× 4.2× 8.3×

Table 2: GPT-3 175B【8】在不同序列长度下的激活内存使用情况(与模型状态相比)。

段要点:LoRA的内存分析。LoRA通过冻结预训练模型参数并只更新注入的低秩矩阵,显著减少了梯度和优化器状态的内存消耗。然而,这并未减少激活内存。如图3所示,LoRA甚至会增加激活内存的使用,因为可训练的低秩矩阵深度嵌入在模型结构中,其梯度计算需要与vanilla微调几乎相同的遍历过程(遵循反向传播的链式法则)。其他PEFT方法也存在类似问题。

图3:一个用于展示PEFT激活内存使用的LoRA示例。除了共享部分,LoRA需要存储比vanilla更多的激活值。
图3:一个用于展示PEFT激活内存使用的LoRA示例。除了共享部分,LoRA需要存储比vanilla更多的激活值。

段要点:LongLoRA与影子激活现象。LongLoRA在LoRA的基础上提出了移位稀疏注意力(S2-Attn)机制。如图1(b)所示,S2-Attn将输入token分为两组,并在每组内独立执行注意力计算,通过对一半的注意力头进行token移位来实现跨组信息交换。虽然LongLoRA在计算上比LoRA更节省,但它未能提供额外的内存减少。这是因为LongLoRA及其他基于稀疏性的技术仅在token嵌入的隐藏维度上操作。虽然这些方法减少了与单个token相关的计算负担,但它们无法将token完全从计算中排除。然而,一旦一个token参与了计算,无论其使用程度如何,它的激活值都会被存储。我们将此现象称为影子激活(Shadowy Activation),类似于物体通过阻挡光线投下影子。在这里,“阻挡”指的是一个token对于token 1来说是不必要的(其激活可以被丢弃),但对于token 2来说是必需的(其激活必须被保留)。影子激活的存在使得激活内存瓶颈仍未得到解决。

3 洞察:上下文Token稀疏性

段要点:核心洞察——自然语言的冗余性。为了解决§ 2.2中讨论的挑战,我们引入了一个新的视角,即探索和利用LLM长上下文微调过程中的token级稀疏机制。我们的方法基于一个直观而深刻的观察:自然语言本质上是冗余的【60, 71】。具体来说,多项研究【38, 42, 61, 63, 72, 82】指出,标准的全注意力可以被有效地近似为仅关注一小部分查询、键和值之间的重要交互,这只涉及序列中的一部分token。值得注意的是,这种冗余性在长上下文场景中更为显著。如表3所示,随着序列长度的增加,注意力分数中重要交互的比例下降。这些洞察为通过识别并仅保留信息最丰富的token来优化LLM长上下文微调提供了机会。通过直接减少token的参与,影子激活的限制被自然地缓解,使得激活内存能够获得与计算节省相当的好处。

图4:跨不同模型、数据集、层和序列的注意力分数可视化(颜色越深表示值越高)。
图4:跨不同模型、数据集、层和序列的注意力分数可视化(颜色越深表示值越高)。

段要D点:上下文Token稀疏性的两个特征。我们将这种在LLM长上下文微调中的新颖稀疏机制称为上下文Token稀疏性(Contextual Token Sparsity)。通过全面评估,我们确定了该机制的两个关键独特特征:(1)逐Token的(Token-wise)。如图4所示,在不同模型和数据集中观察到注意力分数的网格状分布,证实了序列中的token嵌入具有不同的重要性水平。这一洞察使得排除价值较低的token成为可能,自然地导致内存使用和计算开销的减少。(2)上下文相关的(Contextual)。图4进一步表明,有价值token的分布会根据输入文本动态变化,并且在模型的不同层之间也存在差异,即使对于给定的模型和数据集也是如此。这种动态性强调了需要一个能够在运行时实时准确识别和高效利用这种稀疏性的系统。作为我们系统的基石,上下文token稀疏性首次将token级稀疏性引入到长上下文微调中,使LLM能够在有效参与更少token的情况下处理更大的上下文窗口。

序列长度 2K 4K 8K 16K
稀疏率 88.3% 91.5% 96.2% 98.7%

Table 3: 注意力分数在不同序列长度下的稀疏率(低于最大值0.3的比例)。

A2 方法细节

4.1 JENGA概览

段要点:JENGA系统概览。我们提出了JENGA,一个旨在通过系统地探索和利用上下文token稀疏性来增强LLM长上下文微调的高效系统。图5展示了JENGA的概览,它建立在三个关键技术之上:
* 信息驱动的Token消除(§ 4.2):为了确定一个token是否冗余,我们首先为给定token的信息量建立一个正式定义。基于此定义,JENGA利用一种基于分数的算法,在注意力块内动态识别和消除冗余token。该算法以分块方式执行,并通过采用层级特定的阈值进一步优化,确保了有效性和效率。此外,我们将此方法扩展到MLP块,以确保模型各组件的一致性。
* 上下文感知的模式预测(§ 4.3):JENGA采用一种基于神经网络的方法来预测token稀疏模式,绕过了计算成本高昂的全注意力分数。一旦充分训练,这些预测器可以根据上下文输入准确地近似token信息量。为了最小化预测器带来的开销,JENGA引入了一种弹性尺寸变换技术,优化了内存和计算效率。
* 高性能核优化(§ 4.4):JENGA深入研究了核级别的优化以最大化系统性能。为了最小化不必要的全局内存移动,JENGA引入了一种无置换策略,将token选择、token填充和残差相加直接融合到计算流水线中。此外,JENGA开发了一种基于分段的梯度计算方法,有效缓解了长上下文场景下的激活内存峰值。通过将这些核级别的优化与算法设计紧密结合,JENGA充分利用了上下文token稀疏性来增强LLM长上下文微调。

图5:JENGA概览。在每一层,token嵌入首先被划分为块并送入模式预测器(❷)。利用这些预测器预测的信息量分数,token消除算法(❶)有效地识别并仅保留信息最丰富的token进行处理。优化的核(❸)随后高效地执行token选择、计算、残差相加和填充。
图5:JENGA概览。在每一层,token嵌入首先被划分为块并送入模式预测器(❷)。利用这些预测器预测的信息量分数,token消除算法(❶)有效地识别并仅保留信息最丰富的token进行处理。优化的核(❸)随后高效地执行token选择、计算、残差相加和填充。

4.2 信息驱动的Token消除

段要点:引言与挑战。为了充分利用上下文token稀疏性,准确识别不同输入和层中的冗余token至关重要。丢弃信息丰富的token可能会影响模型精度,而保留过多的token则会导致资源效率低下。为了应对这些挑战,JENGA提出了一种信息驱动的算法,该算法在保持精度的同时动态识别和消除冗余token。

段要点:Token信息量定义。我们首先根据一个token与嵌入空间中其他token的交互来定义其信息量。在注意力机制中,注意力分数$S_{attn}$通常用于量化token之间的交互【25, 42, 72】。具体来说,注意力分数项$S_{ij}$,计算为$Q_i K_j$,表示token i和token j之间的交互。受此启发,我们通过考虑一个token与长上下文序列中所有其他token的交互来定义其信息量$I(T)$:

公式1
公式1

其中,求和操作聚合了序列中所有token i的注意力分数,排除了token j自身。

段要点:分块消除策略。在token信息量的概念基础上,下一步是根据它们的信息量分数消除冗余token。为了优化与硬件特性的对齐,JENGA以分块方式执行token消除。如图6所示,JENGA沿着token维度将注意力分数划分为多个分数块$B_S$。在跨注意力头聚合后,每个块内的最大值被选为该块的信息量分数$I(B_S)$。值得注意的是,在聚合过程中,JENGA只对正的注意力分数求和。由于注意力分数会经过softmax操作,负值对最终结果的影响微乎其微,但如果包含进来可能会抵消正值的影响。通过排除负分数,聚合过程保留了信息量的完整性,并确保了稳健的token消除。整个过程可以形式化为:

公式2
公式2

其中$S_{ij}^h$表示在注意力头h中token i和token j之间的注意力分数,$N_{head}$是注意力头的总数,作为缩放因子。

图6:Token消除算法。注意力分数首先在不同头之间聚合,并划分为多个分数块。每个分数块内的最大值被定义为其信息量分数,然后沿着列进行聚合。这些最终的分数与一个层级特定的阈值进行比较,以确定相应的token是否应被保留。
图6:Token消除算法。注意力分数首先在不同头之间聚合,并划分为多个分数块。每个分数块内的最大值被定义为其信息量分数,然后沿着列进行聚合。这些最终的分数与一个层级特定的阈值进行比较,以确定相应的token是否应被保留。

段要点:分块消除的有效性论证。这种分块token消除方法保留了长上下文序列的原始信息量,原因有二。首先,与长上下文输入序列相比,我们方法中使用的块大小相对较小。鉴于重要token在长上下文序列中的稀疏分布,大多数token块不包含重要token,可以安全地消除而不损害精度。其次,在重要token确实落入某个块的情况下,我们为token块设计的信息量分数$I(B_S)$能够有效地识别并保留这些块。这是因为重要token与不重要token的信息量分数$I(T)$差异足够大,使得重要token不会被块中的其他不重要token平均掉。因此,我们的方法确保任何包含重要token的块都会被保留,无论它包含多少这样的token。我们认为,维持微调精度比保留一些额外不重要token所带来的微不足道的计算成本更为关键。

段要点:层级特定阈值。token块$B_T$的信息量分数是通过聚合相应的分数块计算得出的,即$B_{T_n} = \sum_{m} B_{S_{mn}}$。然后将这些分数与一个阈值进行比较,以确定块内的token是否应被消除。特别地,JENGA通过采用层级特定的阈值进一步完善了token消除算法。关键洞察是LLM内部的不同层表现出不同的稀疏模式。如图7所示,token块的平均信息量分数在不同模型层之间差异很大,这表明在所有层上应用一个通用阈值是次优的。算法1概述了JENGA的方法,该方法首先根据分数分析为所有层初始化一个默认阈值,然后微调这些值以与每层的独特稀疏特性对齐。

图7:不同层中token块的平均信息量分数(归一化后)。
图7:不同层中token块的平均信息量分数(归一化后)。

段要点:扩展至MLP块。此外,我们将token消除扩展到MLP块。与注意力分数类似,JENGA利用MLP块内的中间激活来评估每个token的信息量。这种扩展可以看作是MLP块内被广泛研究的神经元稀疏性的一种变体【37, 41, 48】,确保了与各种MLP块结构的兼容性:对于基于ReLU的结构【2】,激活是ReLU层的输出;而对于基于SiLU的结构【20】,激活对应于门控投影(经过SiLU后)和上投影的逐元素乘积。这些技术使JENGA能够无缝地在不同的模型组件和配置中适应token消除。

算法 1: 层级特定阈值优化

输入: 模型层 L = {L1, L2, . . . , Ln}
输出: 层阈值 T = {T1, T2, . . . , Tn}
// 步骤 1: 阈值初始化
foreach layer Li ∈ L do
    // 对token块的平均分数
    Ti ← avg I(BT) ∀ BT ∈ Li;
// 步骤 2: 阈值微调
foreach layer Li ∈ L do
    // 用有限差分计算梯度
    Gi ← (acc(Ti + ε) - acc(Ti - ε)) / 2ε;
    // 根据梯度更新阈值
    Ti ← Ti + η · Gi;
return T

4.3 上下文感知的模式预测

段要点:引言与挑战。虽然确切的稀疏模式可以直接从全注意力分数中导出,但计算和存储这些分数的成本高得惊人,其复杂度随序列长度呈二次方增长。此外,由于上下文token稀疏性的动态性,最优稀疏模式只能在运行时确定,并随不同输入和层而变化。为了应对这些挑战,JENGA采用了一组轻量级神经网络作为预测器。通过将上下文嵌入作为输入,这些预测器能够准确而高效地推断稀疏模式。

段要点:基于神经网络的预测器。如图8所示,JENGA在每一层部署一对预测器,分别用于近似查询Q和键K的信息量分数。每个预测器由三个可训练的低秩矩阵组成,连续矩阵之间应用ReLU激活函数。预测器的输入是包含上下文信息的token嵌入X,这些嵌入被组织成块以与分块消除对齐。通过从每个块中提取代表性嵌入,预测器输出近似的信息量分数$\hat{I}(Q)$和$\hat{I}(K)$。然后将这些分数相乘,以近似注意力分数的信息量$\hat{I}(S_{attn})$:

公式3
公式3

当Q和K预测器训练良好时,$\hat{I}(S_{attn})$可以提供对准确信息量分数$I(S_{attn})$的紧密估计。

图8:模式预测过程。每一层都配备了两个预测器来分别近似Q和K。以token嵌入为输入(组织成token块),每个预测器为每个token块输出信息量分数,ˆI(BQm)和ˆI(BKn)。然后将这些分数相乘以计算分块注意力分数的信息量分数,ˆI(BSmn)。此外,采用弹性尺寸变换来独立地最小化每一层预测器的大小。
图8:模式预测过程。每一层都配备了两个预测器来分别近似Q和K。以token嵌入为输入(组织成token块),每个预测器为每个token块输出信息量分数,ˆI(BQm)和ˆI(BKn)。然后将这些分数相乘以计算分块注意力分数的信息量分数,ˆI(BSmn)。此外,采用弹性尺寸变换来独立地最小化每一层预测器的大小。

段要点:预测器的收敛与设计。在有限的训练数据集下,这些预测器可以快速收敛并表现出良好的预测性能,正如先前的研究【45, 62, 70】所证实的。特别地,JENGA中的预测器首先单独处理每个token,然后将这些单独的预测聚合成一个统一的结果。这种策略将预测器的尺寸限制在单个token块的维度,而不是整个长上下文序列,从而简化了设计并减轻了预测开销。

段要点:弹性尺寸变换。为了进一步减小预测器的大小,JENGA利用了一种弹性尺寸变换技术,该技术根据预测器各自的稀疏特性动态地修剪其中的神经元。受先前基于激活的修剪研究【37, 41, 69】的启发,该设计利用了ReLU激活函数的特性,该函数在预测器的中间激活中引入了大量的零元素。当一个激活元素为零时,其对应的神经元(即模型权重的行或列)变得不活跃,可以安全地被忽略。JENGA创新地将此机制集成到其定制设计的预测器中。具体来说,JENGA在训练期间跟踪中间激活元素的零频率,并定期修剪与最高零频率相关的神经元。无需依赖任何先验假设,弹性尺寸变换自适应地确定每个预测器的最优大小,有效减少了计算和内存开销。

段要点:全面的开销分析。我们最后分析了预测器在训练和推理过程中引入的开销。在离线训练中,主要瓶颈在于获取注意力分数的信息量$I(S_{attn})$。得益于分块方式,我们无缝地将我们的自定义训练核集成到最先进的FlashAttention【15, 16】中。这种集成避免了显式计算和存储完整的注意力分数。相反,$I(S_{attn})$是实时派生的,导致内存复杂度随序列长度线性增长。

在在线推理中,给定序列长度s、头维度h和块大小b,计算开销包括两部分:(1)预测$I(Q)$或$I(K)$的复杂度为$O(sh^2)$;(2)预测$I(S_{attn})$的复杂度为$O(s^2/b^2)$。在长上下文场景中,第二部分成为主导因素。然而,这个开销可以通过增加块大小b来有效缓解。在内存方面,主要开销来自预测器内的线性权重,复杂度为$O(bh^2)$。重要的是,这个复杂度相对于模型配置保持不变。最后,得益于弹性尺寸变换,计算和内存复杂度都因一个稀疏因子而降低,平均减少了50%。

4.4 高性能核优化

段要点:引言与挑战。JENGA专注于token级稀疏性,对原始微调动态的修改极小,从而能够无缝重用现有的优化计算流程。然而,JENGA中隐藏着两个关键挑战,影响其性能。首先,跨层稀疏模式的可变性需要迭代的token选择和填充,导致大量昂贵的全局内存移动。其次,LLM庞大的词汇量需要大量的激活内存来计算每个token的输出损失梯度,尤其是在长上下文场景中。JENGA在核级别集成了几种硬件高效的技术,有效缓解了这些瓶颈。

图9:朴素token移动与JENGA的比较。红线突出显示了朴素核产生的大量全局内存移动成本。JENGA通过核融合开发了一种无置换策略。
图9:朴素token移动与JENGA的比较。红线突出显示了朴素核产生的大量全局内存移动成本。JENGA通过核融合开发了一种无置换策略。

段要点:无置换的Token移动。上下文token稀疏性的动态特性导致不同层具有不同的稀疏模式,涉及不同的token子集。如图9所示,在每一层,一组信息量较少的token被消除,剩余的token被重新排列以作为注意力块的输入。然后,注意力输出被填充零以保持维度一致性,并最终与原始输入进行残差相加。token选择、token填充和残差相加的过程涉及大量的全局内存数据移动,这会产生高昂的内存访问延迟。

JENGA开发了一种无置换策略,将所有不必要的置换操作与注意力计算融合在一起。JENGA不是物化重新排列的token,而是直接从原始输入加载选定的token。此外,JENGA将注意力输出就地加到原始输入上,在一个步骤中同时完成token填充和残差相加。在反向传播期间,当计算注意力权重的梯度时,原始输入会被重新计算。由于输入嵌入矩阵可以通过从输出嵌入矩阵中减去自注意力输出来恢复,因此开销极小。这种简化的方法消除了不必要的内存分配和昂贵的全局内存移动,显著提升了系统性能。

图10:损失梯度计算期间的内存峰值,因词汇量大和上下文长而加剧。x轴表示一个微调周期的timeline。
图10:损失梯度计算期间的内存峰值,因词汇量大和上下文长而加剧。x轴表示一个微调周期的timeline。

图11:JENGA的两个可用扩展:(a) 二维稀疏性和 (b) 稀疏敏感的卸载。
图11:JENGA的两个可用扩展:(a) 二维稀疏性和 (b) 稀疏敏感的卸载。

段要点:基于分段的峰值削减。LLM通常是自回归的,预测给定所有先前token的下一个token的概率分布。在微调期间,对于输入序列中的每个token,模型生成下一个token的概率分布,并计算预测与真实值之间的损失。对于具有大词汇量和长上下文窗口的LLM,这个过程会导致激活内存使用的显著激增,如图10所示。虽然这些激活是暂时的,但由此产生的内存峰值提高了LLM微调的上限,对GPU内存资源提出了更严格的要求。现有实现【64, 65】通常使用列式并行将词汇表分布在多个GPU上。然而,这种优化仅在多GPU配置中有效,对单GPU设置没有好处。

为了解决这个问题,JENGA采用了一种基于分段的峰值削减策略,在最终损失计算期间将token序列划分为更小、可管理的段。JENGA不是对整个序列执行一次前向传播并保留所有中间激活,而是独立处理每个段,然后聚合它们的梯度。每个段的激活在相应的梯度计算完成后立即被丢弃。因此,当序列被划分为N个段时,激活内存峰值降低到1/N。这种方法极大地缓解了单个GPU上的内存压力,并与现有的多GPU优化兼容。

A7 补充细节

段要点:实现细节。我们用超过3000行Python和C++代码实现了JENGA。由于对原始微调动态的修改极小,JENGA与广泛的LLM架构兼容,无需任何代码更改。此外,JENGA可以无缝地与其他技术集成。

段要点:扩展1:二维稀疏性。如图11(a)所示,在JENGA中应用token级稀疏性之后,剩余的token可以进一步从现有的隐藏维度稀疏性技术中受益。这种跨两个维度的稀疏机制的自然结合,我们称之为二维稀疏性(2D-Sparsity),为模型的资源分配提供了更精细的控制,从而显著减少了激活内存和计算成本。

段要点:扩展2:稀疏敏感的卸载。JENGA通过将上下文token稀疏性纳入优化过程,增强了现有的基于卸载的技术【27, 30, 55, 56】。如图11(b)所示,我们开发了一种稀疏敏感的卸载策略,该策略能适应不同层之间变化的稀疏率。这种方法使得在CPU和GPU之间能够无缝传输更大量的数据,有效缓解了GPU内存的限制。

A4 实验环境

平台 GPU GPU 内存 NVLink CPU
A 8× A800 80 GB 400 GB/s Intel Xeon
B 8× A40 48 GB 112 GB/s Intel Xeon
C 4× 4090 24 GB N/A Intel Core

Table 4: 硬件平台配置。

家族 模型 参数 默认上下文窗口 (S)
OPT OPT-350M 350M 2K
OPT-1.3B 1.3B 2K
OPT-6.7B 6.7B 2K
Llama Llama2-7B 7B 4K
Llama2-13B 13B 4K
Llama3-8B 8B 8K

Table 5: 模型配置。

A4 实验结果

6.2 端到端性能

图12:在A800上的内存占用比较。
图12:在A800上的内存占用比较。

图13:JENGA在A800和A40上的端到端加速。
图13:JENGA在A800和A40上的端到端加速。

模型 方法 代码 常识 数学 摘要 问答
Llama2-7B-8k LoRA 19.33 53.68 12.37 25.13 41.52
JENGA 19.31 53.53 12.39 25.04 41.4

Table 6: 在LongBench基准上模型准确性的比较分析(越高越好)。

S 方法 PG PP
4K LoRA 16.51 2.12
JENGA 16.79 2.15
8K LoRA 15.34 2.05
JENGA 15.58 2.07
16K LoRA 14.19 1.96
JENGA 14.39 1.98

Table 7: PG19 (PG) 和 Proof-Pile (PP) 数据集的困惑度。

6.3 消融研究

图14:Llama2微调的性能分解:(a) 内存占用和 (b) 执行时间。
图14:Llama2微调的性能分解:(a) 内存占用和 (b) 执行时间。

图15:Llama2 7B(上)和OPT 6.7B(下)中各层的内存占用和相应阈值。
图15:Llama2 7B(上)和OPT 6.7B(下)中各层的内存占用和相应阈值。

图16:(a) 无重要token的块的比例和 (b) 一个token块内注意力分数的分布。
图16:(a) 无重要token的块的比例和 (b) 一个token块内注意力分数的分布。

图17:(a) 在LongAlign (LA) / RedPajama (RP) 上的训练损失曲线和 (b) 预测器的预测可视化。
图17:(a) 在LongAlign (LA) / RedPajama (RP) 上的训练损失曲线和 (b) 预测器的预测可视化。

0 5 10 15 20 25 30
0.84 0.84 0.84 0.84 0.84 0.84 0.84
0.28 0.30 0.31 0.30 0.29 0.28 0.27

Table 8: Llama2各层预测器的参数大小(百万),带或不带弹性尺寸变换。

图18:JENGA的无置换核的性能。
图18:JENGA的无置换核的性能。

图19:损失梯度计算中的内存使用峰值。x轴表示一个微调周期的timeline。
图19:损失梯度计算中的内存使用峰值。x轴表示一个微调周期的timeline。

6.4 扩展评估

图20:两个扩展带来的性能提升。
图20:两个扩展带来的性能提升。

6.5 可扩展性分析

JENGA在4个4090 GPU上表现出良好的强可扩展性,性能随GPU数量成比例增长,因为它最小化了token参与且没有引入额外的通信开销(图21)。

图21:JENGA的强可扩展性评估。
图21:JENGA的强可扩展性评估。

A5 结论

本文提出了JENGA,一个旨在优化LLM长上下文微调的高效系统。我们的方法在LLM长上下文微调中引入了一种新颖的稀疏机制,称为上下文Token稀疏性。为了系统地利用这一机制,我们开发了三项关键技术,分别用于识别、预测和利用这种稀疏性,实现了比现有最先进方法更优的内存节省和性能加速。压缩体现智能,而稀疏性是压缩的一种强大形式。我们期望JENGA能激励更广泛地探索稀疏性以推动LLM的发展。

A6 附录

摘要

段要点:附录内容概览。此工件(Artifact)伴随ATC'25论文《JENGA: Enhancing Long-Context Fine-tuning of LLMs with Contextual Token Sparsity》。它包含了复现实验结果(包括图表)所需的所有代码、脚本、数据集和预训练权重。该工件还提供了用于设置验证和快速结果演示的辅助脚本,并支持快速和全规模的实验复现工作流。

范围

段要点:可验证的声明。此工件允许用户验证论文中提出的所有经验性声明,包括所提JENGA方法在长上下文微调任务上的内存和运行时性能、预测准确性以及可扩展性。它可用于检查所有端到端结果、执行消融研究以及测试跨模型和数据集的泛化能力。

内容

段要点:工件组成。该工件包含以下组件:
* 模型和预测器权重 (checkpoints/): 预训练的模型和预测器权重。
* 数据集 (dataset/): 实验中使用的所有数据集。
* 日志文件 (logs/): 用于生成论文图表的原始日志。
*
输出图表 (output_figures/): 从日志复现的图表。
*
实验脚本 (scripts/): 用于复现论文中每个图表/表格的脚本。
*
源代码 (src/)**: JENGA的核心实现。

托管

段要点:资源获取。该工件托管在一个公共Github仓库中。所有资源,包括模型和数据集,可以从以下地址获取:
* 工件仓库: https://github.com/Pairshoe/Jenga-AE
* 模型权重: Models Link (peft_model.zip 和 predictor.zip)
* Hugging Face 模型: 详见表9。
* 数据集: Datasets Link (dataset.zip)

模型名称 Hugging Face 仓库 ID 目标目录
Llama2-7b-hf meta-llama/Llama-2-7b-hf checkpoints/base_models/Llama-2-7b-hf/
Llama2-13b-hf meta-llama/Llama-2-13b-hf checkpoints/base_models/Llama-2-13b-hf/
Llama3-8b-hf meta-llama/Llama-3-8b-hf checkpoints/base_models/Llama-3-8b-hf/
opt-350m facebook/opt-350m checkpoints/base_models/opt-350m/
opt-1.3b facebook/opt-1.3b checkpoints/base_models/opt-1.3b/
opt-6.7b facebook/opt-6.7b checkpoints/base_models/opt-6.7b/

Table 9: 所需基础模型的下载链接及其各自的目标目录。

要求

段要点:软硬件要求
* 软件: Python 3.10, PyTorch, Flash Attention (需单独安装)。
* 硬件 (完整实验):
* 最低配置: 1块 NVIDIA A800 或 A40 GPU。
* 可扩展性实验: 4块 NVIDIA 4090 GPU。

A.1 安装

段要点:安装步骤
1. 安装依赖: pip install -r requirements.txt
2. 安装 Flash Attention: pip install flash-attn --no-build-isolation
3. 从源码安装 Jenga: pip install -e .

A.2 实验工作流

段要点:环境设置验证
运行脚本以验证设置是否正确: bash hello_world.sh

段要点:快速复现
运行脚本从原始数据绘制图表: bash RUNME-a.sh

段要点:深度复现
* 一步执行: 一次性复现所有结果 (运行时约5小时)。

bash RUNME-b-a800.sh # for A800
bash RUNME-b-a40.sh # for A40
bash RUNME-b-4x4090.sh # for 4x4090
输出文件夹 复现脚本
output_figures/fig12 scripts/rep_fig12.sh
output_figures/fig13_a800 scripts/rep_fig13_a800.sh
output_figures/fig13_a40 scripts/rep_fig13_a40.sh
output_figures/fig14 scripts/rep_fig14.sh
output_figures/fig15 scripts/rep_fig15.sh
output_figures/fig18 scripts/rep_fig18.sh
output_figures/fig19 scripts/rep_fig19.sh
output_figures/fig20 scripts/rep_fig20.sh
output_figures/fig21 scripts/rep_fig21.sh

Table 10: 用于复现结果的生成输出文件夹和相应的脚本。