作者/机构: Albert Gu (卡内基梅隆大学机器学习系), Tri Dao (普林斯顿大学计算机科学系)
本文旨在解决现有主流基础模型(Foundation Models)骨干架构 Transformer 在处理长序列时计算效率低下的问题。Transformer 的核心注意力机制虽然效果强大,但其在上下文窗口内的密集信息路由能力导致了两个根本性缺陷:无法对窗口外的内容进行建模,以及计算复杂度随序列长度呈二次方增长。尽管已有许多亚二次方时间复杂度的架构(如线性注意力、门控卷积、循环模型和结构化状态空间模型SSM)被提出,但它们在语言等重要模态上的性能未能超越注意力机制。
核心问题: 现有高效的序列模型(如SSM)缺乏一种关键能力——基于内容进行推理(content-based reasoning),这限制了它们在信息密集型数据(如文本)上的表现。
研究目标: 提出一种新的序列模型架构,既能保持 Transformer 的强大建模能力,又能实现随序列长度线性扩展的计算效率,从而成为适用于各类序列数据的通用基础模型骨干。
创新点/主要贡献:
选择机制 (Selection Mechanism):
硬件感知的高效算法 (Hardware-aware Algorithm):
Mamba 架构 (Architecture):
Mamba的核心特性:
实证成果: Mamba 作为一种通用的序列模型骨干,在多种模T态上均取得了最先进的性能。
- 合成任务: 轻松解决了如选择性复制(Selective Copying)和归纳头(Induction Heads)等对大语言模型至关重要的任务,并能无限外推(超过1M词元)。
- 音频和基因组学: 在音频波形和DNA序列建模方面,Mamba在预训练质量和下游任务指标上均优于SaShiMi、Hyena和Transformers等先前SOTA模型。
- 语言建模: Mamba是首个真正达到Transformer级别性能的线性时间序列模型。Mamba-3B模型在预训练和下游评估中均优于同等规模的Transformers,并与两倍于其规模的Transformers性能相当。
结构化状态空间序列模型(S4) 是一类新兴的深度学习序列模型,与RNN、CNN和经典状态空间模型广泛相关。它们受到一个特定连续系统的启发,该系统通过一个隐式潜在状态 $h(t) \in \mathbb{R}^N$ 将一维函数或序列 $u(t) \in \mathbb{R}$ 映射到 $y(t) \in \mathbb{R}$。
S4模型由四个参数(Δ, A, B, C)定义,通过两个阶段完成序列到序列的转换。
离散化 (Discretization):第一阶段通过固定的公式 $\bar{\mathbf{A}} = f_A(\Delta, \mathbf{A})$ 和 $\bar{\mathbf{B}} = f_B(\Delta, \mathbf{A}, \mathbf{B})$ 将“连续参数”(Δ, A, B)转换为“离散参数”($\bar{\mathbf{A}}, \bar{\mathbf{B}}$)。这对 $(f_A, f_B)$ 称为离散化规则,例如零阶保持(ZOH)规则如公式(4)所示。从机械角度看,离散化可视为SSM前向传播计算图的第一步。
计算 (Computation):参数从 (Δ, A, B, C) 转换为 ($\bar{\mathbf{A}}, \bar{\mathbf{B}}, \mathbf{C}$) 后,模型可以通过两种方式计算:线性循环(公式2)或全局卷积(公式3)。通常,模型在训练时使用高效并行的卷积模式,在自回归推理时切换到高效的循环模式。
线性时间不变性(LTI) 是公式(1)到(3)的一个重要特性,即模型的动态性在时间上是恒定的。换句话说,(Δ, A, B, C) 以及因此得到的 ($\bar{\mathbf{A}}, \bar{\mathbf{B}}$) 对所有时间步都是固定的。LTI特性与循环和卷积紧密相关。迄今为止,所有结构化SSM都是LTI模型,这是因为存在根本的效率约束。然而,本文的一个核心洞见是LTI模型在建模某些类型的数据时存在根本局限,本文的技术贡献在于移除LTI约束的同时克服效率瓶颈。
结构和维度 结构化SSM为了高效计算,需要对 A 矩阵施加结构,最流行的结构是对角结构,本文也采用此结构。在这种情况下,$\mathbf{A} \in \mathbb{R}^{N \times N}, \mathbf{B} \in \mathbb{R}^{N \times 1}, \mathbf{C} \in \mathbb{R}^{1 \times N}$ 矩阵都可以用 $N$ 个数字表示。对于一个批次大小为 $B$、长度为 $L$、通道数为 $D$ 的输入序列 $x$,SSM会独立应用于每个通道。此时,总的隐藏状态维度为 $D \times N$,在整个序列上计算它需要 $O(BDLN)$ 的时间和内存,这是本文在第3.3节中解决的根本效率瓶颈的根源。
通用状态空间模型 “状态空间模型”一词含义广泛,泛指任何具有潜在状态的循环过程。本文中,“SSM”特指结构化SSM或S4模型及其衍生物。
SSM架构 SSM是可集成到端到端神经网络架构中的独立序列变换模块。一些知名的SSM架构包括:
- 线性注意力 (Linear attention):可视为一种退化的线性SSM。
- H3:可以看作是在一个SSM两侧夹了两个门控连接的架构。
- Hyena: 与H3架构相同,但用一个MLP参数化的全局卷积替代了S4层。
- RetNet: 在架构中增加了一个额外的门,并使用一个更简单的SSM,允许一种替代的并行计算路径。
- RWKV: 一种基于另一种线性注意力近似的现代RNN,其核心“WKV”机制涉及LTI循环,可看作是两个SSM的比值。
序列建模的根本问题是将上下文压缩成一个更小的状态。 不同的序列模型在此问题上做出了不同的权衡。注意力机制之所以有效但低效,是因为它完全不压缩上下文(自回归推理需要存储完整的KV缓存)。相反,循环模型之所以高效,是因为它们具有有限的状态,但其效果受限于这个状态对上下文的压缩程度。
为了理解这一原则,本文关注两个合成任务示例。
- 选择性复制 (Selective Copying) 任务通过改变需要记忆的词元位置,要求模型具备内容感知能力,以记忆相关词元(彩色)并过滤掉不相关词元(白色)。
- 归纳头 (Induction Heads) 任务是一个被假设为解释大语言模型大部分上下文学习能力的关键机制。它要求模型具备上下文感知能力,以便在适当的上下文中产生正确的输出。
这些任务揭示了LTI模型的失败模式。 从循环角度看,它们恒定的动态性(如公式2中的($\bar{\mathbf{A}}, \bar{\mathbf{B}}$)转移)无法让它们从上下文中选择正确的信息,或以输入依赖的方式影响沿序列传递的隐藏状态。从卷积角度看,全局卷积虽然能解决常规复制任务(因为它只需时间感知),但在选择性复制任务上遇到困难,因为缺乏内容感知能力,无法处理输入输出之间变化的间距。
总结来说,序列模型的效率与效果权衡取决于它们压缩状态的好坏。 高效模型必须有小状态,而有效模型必须有包含所有必要上下文信息的状态。因此,本文提出构建序列模型的一个基本原则是选择性(selectivity):即具备上下文感知能力,以关注或过滤输入,并将其整合到一个序列状态中。选择机制控制信息如何沿序列维度传播或交互。
将参数变为输入依赖。在模型中引入选择机制的一种方法是,让那些影响序列交互的参数(例如RNN的循环动态或CNN的卷积核)依赖于输入。
算法1和2展示了本文使用的主要选择机制。核心区别在于,将几个参数 Δ, B, C 变为输入的函数,并对张量形状进行相应调整。特别地,这些参数现在有了一个长度维度 $L$,意味着模型从时间不变变成了时间变化。这使得模型不再等价于卷积(公式3),对其效率产生了影响,将在下一节讨论。
参数化的具体选择。本文具体选择 $\mathbf{B}(x_t) = \text{Linear}_N(x_t)$, $\mathbf{C}(x_t) = \text{Linear}_N(x_t)$, $\tau_\Delta(x_t) = \text{Broadcast}_D(\text{Linear}_1(x_t))$, 并且 $\Delta_t = s_\Delta(\tau_\Delta(x_t))$,其中 $s_\Delta = \text{softplus}$,$\text{Linear}_d$ 是一个参数化的投影,将输入投影到维度 $d$。选择 $s_\Delta$ 和 $\tau_\Delta$ 是基于与RNN门控机制的联系,将在3.5节中解释。
先前模型的动机回顾。我们首先回顾先前方法的动机,并概述我们如何克服它们的局限性。
- 从高层次看,像SSM这样的循环模型总是在表达能力和速度之间进行权衡:如3.1节所述,具有更大隐藏状态维度的模型应该更有效但更慢。因此,我们的目标是在不牺牲速度和内存成本的情况下最大化隐藏状态维度。
- 循环模式比卷积模式更灵活,因为后者(公式3)是通过展开前者(公式2)得到的。然而,这需要计算并物化形状为 (B, L, D, N) 的潜在状态 $h$,这个状态比输入 $x$ 和输出 $y$ 的形状 (B, L, D) 大得多(大了 $N$ 倍,即SSM状态维度)。因此,引入了更高效的卷积模式,它可以绕过状态计算,只物化一个大小仅为 (B, L, D) 的卷积核(公式3a)。
- 先前的LTI状态空间模型利用循环-卷积的双重形式,将有效状态维度增加了 $N$ 倍(约10-100),远大于传统RNN,且没有效率损失。
选择性扫描概述:硬件感知的状态扩展。选择机制旨在克服LTI模型的局限性;同时,我们因此需要重新审视SSM的计算问题。我们通过三种经典技术来解决这个问题:核函数融合(kernel fusion)、并行扫描(parallel scan)和重计算(recomputation)。我们有两个主要观察:
- 朴素的循环计算使用 $O(BDLN)$ FLOPs,而卷积计算使用 $O(BDL \log(L))$ FLOPs,前者的常数因子更低。因此,对于长序列和不太大的状态维度 $N$,循环模式实际上可能使用更少的FLOPs。
- 两个挑战是循环的顺序性和大的内存使用。为了解决后者,就像卷积模式一样,我们可以尝试不实际物化完整的状态 $h$。
核心思想是利用现代加速器(GPU)的特性,只在内存层级结构中更高效的层次上物化状态 $h$。具体来说,大多数操作(矩阵乘法除外)都受内存带宽的限制。这包括我们的扫描操作,我们使用核函数融合来减少内存IO的数量,从而与标准实现相比,速度得到显著提升。具体地,我们不是在GPU HBM(高带宽内存)中准备大小为 (B, L, D, N) 的扫描输入($\bar{\mathbf{A}}, \bar{\mathbf{B}}$),而是将SSM参数(Δ, A, B, C)直接从慢速HBM加载到快速SRAM中,在SRAM中执行离散化和循环,然后将大小为 (B, L, D) 的最终输出写回HBM。
为避免顺序循环,我们观察到尽管它不是线性的,但仍可以用一种工作高效的并行扫描算法来并行化【10, Blelloch 1990; 71, Martin and Cundy 2018; 98, Smith, Warrington, and Linderman 2023】。
最后,我们必须避免保存中间状态,这些状态是反向传播所必需的。我们谨慎地应用经典的重计算技术来减少内存需求:中间状态不被存储,而是在反向传播中当输入从HBM加载到SRAM时重新计算。因此,融合的选择性扫描层具有与使用FlashAttention的优化版Transformer实现相同的内存需求。融合核和重计算的细节在附录D中。完整的选择性SSM层和算法如图1所示。
与结构化SSM一样,选择性SSM是独立的序列变换,可以灵活地集成到神经网络中。H3架构是大多数知名SSM架构的基础,通常由一个受线性注意力启发的模块和一个MLP(多层感知器)模块交错组成。我们通过将这两个组件合并为一个,并进行同质化堆叠,简化了该架构(图3)。这一设计的灵感来自于门控注意力单元(GAU)【53, Hua et al. 2022】,它对注意力机制做了类似的事情。
该架构通过一个可控的扩展因子 E 来扩展模型维度 D。对于每个块,大部分参数($3ED^2$)位于线性投影中($2ED^2$用于输入投影,$ED^2$用于输出投影),而内部SSM贡献的参数较少。SSM的参数数量(用于Δ, B, C的投影,以及矩阵A)相对要小得多。我们重复这个块,并与标准的归一化和残差连接交错,以构成Mamba架构。我们总是在实验中固定 $E=2$,并使用两个堆叠的块来匹配Transformer中交错的MHA(多头注意力)和MLP块的$12D^2$参数。我们使用SiLU / Swish激活函数,这样门控MLP就变成了流行的“SwiGLU”变体【16, Chowdhery et al. 2023; 22, Dauphin et al. 2017; 95, Shazeer 2020; 105, Touvron et al. 2023】。最后,我们还额外使用了一个可选的归一化层(我们选择LayerNorm【4, J. L. Ba, Kiros, and Hinton 2016】),这是受到RetNet在类似位置使用归一化层的启发【100, Y. Sun et al. 2023】。
选择机制是一个更广泛的概念,可以以不同方式应用,例如应用于更传统的RNN或CNN,应用于不同的参数(例如算法2中的 A),或使用不同的变换 $s(x)$。
我们强调最重要的联系:RNN的经典门控机制是SSM选择机制的一个实例。RNN门控与连续时间系统离散化之间的联系是公认的【32, Funahashi and Nakamura 1993; 102, Tallec and Ollivier 2018】。实际上,定理1是对【40, Gu, Johnson, Goel, et al. 2021, Lemma 3.1】的改进,推广到了ZOH离散化和输入依赖的门(证明在附录C)。更广泛地看,SSM中的Δ可以被视为扮演了RNN门控机制的广义角色。与先前工作一致,我们认为SSM的离散化是启发式门控机制的原则性基础。
定理1。当 $N=1, A=-1, C=1, \tau_\Delta = \text{Linear}(x)$, 且 $s_\Delta = \text{softplus}$ 时,选择性SSM循环(算法2)的形式为:
如3.2节所述,我们对 $\tau_\Delta, s_\Delta$ 的具体选择正是源于这种联系。特别地,请注意,如果某个给定的输入 $x_t$ 应该被完全忽略(如在合成任务中必需的那样),所有 $D$ 个通道都应该忽略它,因此我们在通过广播操作重复Δ之前,将输入投影到1维。
我们阐述了选择机制的三个具体机理效果。
- 可变间距 (Variable Spacing)。选择性允许过滤掉可能出现在感兴趣输入之间的不相关噪声词元。这在选择性复制任务中得到了体现,但在常见的数据模态中普遍存在,尤其是在离散数据中——例如语言中“嗯”之类的填充词。这个特性之所以出现,是因为模型可以机械地过滤掉任何特定的输入 $x_t$,例如在门控RNN的情况下(定理1),当 $\bar{\mathbf{B}}_t \to 0$ 时。
- 过滤上下文 (Filtering Context)。经验观察表明,许多序列模型并不会随着上下文的增长而改善【96, F. Shi et al. 2023】,尽管原则上更长的上下文应该带来更好的性能。一种解释是,许多序列模型在必要时无法有效忽略不相关的上下文;一个直观的例子是全局卷积(以及一般的LTI模型)。另一方面,选择性模型可以随时重置其状态以移除多余的历史,因此其性能原则上随上下文长度单调提高(例如4.3.2节)。
- 边界重置 (Boundary Resetting)。在将多个独立序列拼接在一起的场景中,Transformer可以通过实例化特定的注意力掩码来保持它们的分离,而LTI模型则会在序列之间泄露信息。选择性SSM也可以在边界处重置其状态(例如 $\Delta_t \to \infty$,或者定理1中当 $\bar{\mathbf{A}}_t \to 1$ 时)。这些场景可能人为出现(例如,为了提高硬件利用率而将文档打包在一起)或自然发生(例如,强化学习中的情节边界【68, Lu et al. 2023】)。
此外,我们详细阐述了每个选择性参数的效果。
- Δ的诠释。总的来说,Δ 控制着对当前输入 $x_t$ 的关注或忽略程度。它推广了RNN的门控(例如定理1中的 $\bar{\mathbf{B}}_t$):从机理上讲,大的Δ会重置状态 $h$ 并专注于当前输入 $x$,而小的Δ会保持状态并忽略当前输入。
- A的诠释。我们注意到,虽然 A 参数也可以是选择性的,但它最终只通过与Δ的相互作用影响模型,即 $\bar{\mathbf{A}} = \exp(\Delta \mathbf{A})$(离散化公式4)。因此,Δ的选择性足以确保($\bar{\mathbf{A}}, \bar{\mathbf{B}}$)的选择性,并且是性能提升的主要来源。我们假设,在Δ之外(或替代Δ)使A具有选择性会产生类似的性能,为简单起见我们将其省略。
- B和C的诠释。如3.1节所述,选择性最重要的特性是过滤掉不相关的信息,以便将序列模型的上下文压缩到一个高效的状态中。在SSM中,将B和C修改为选择性的,允许更精细地控制是否让输入 $x_t$ 进入状态 $h_t$,或让状态进入输出 $y_t$。这可以被解释为允许模型根据内容(输入)和上下文(隐藏状态)分别调节循环动态。
实数与复数。大多数先前的SSM在其状态 $h$ 中使用复数,这对于在感知模态中许多任务上取得强劲性能是必要的【37, Gu, Goel, and Ré 2022】。然而,经验观察表明,在某些情况下,完全实值的SSM似乎效果不错,甚至可能更好【70, Ma et al. 2023】。我们默认使用实数值,这在我们除了一项任务外的所有任务中都表现良好;我们假设复数-实数的权衡与数据模态中的连续-离散谱有关,其中复数对连续模态(如音频、视频)有帮助,但对离散模态(如文本、DNA)则不然。
初始化。大多数先前的SSM也建议特殊的初始化,特别是在复数值情况下,这在几种情况下(如低数据量场景)有所帮助。我们对于复数情况的默认初始化是S4D-Lin,对于实数情况是S4D-Real【39, Gu, Gupta, et al. 2022】,这基于HIPPO理论【36, Gu, Dao, et al. 2020】。然而,我们预计许多初始化方法都能工作得很好,特别是在大数据和实值SSM的场景中;一些消融研究在4.6节中进行了考虑。
Δ的参数化。我们将对Δ的选择性调整定义为 $\tau_\Delta(x) = \text{Broadcast}_N(\text{Linear}_1(x))$,这是由Δ的机理所驱动的(3.5节)。我们观察到,这可以从维度1推广到更大的维度R。我们将其设置为D的一小部分,与块中的主要线性投影相比,使用的参数数量可以忽略不计。我们还注意到,广播操作可以被看作是另一个线性投影,初始化为特定的1和0模式;如果这个投影是可训练的,这将导致另一种选择 $\tau_\Delta(x) = \text{Linear}_R(\text{Linear}_R(x))$,可以看作是一个低秩投影。在我们的实验中,Δ参数(可看作一个偏置项)被初始化为 $s_\Delta^{-1}(\text{Uniform}([0.001, 0.1]))$,遵循了先前关于SSM的工作【41, Gu, Johnson, Timalsina, et al. 2023】。
备注3.1。为了在我们的实验结果中保持简洁,我们有时将选择性SSM缩写为S6模型,因为它们是具有选择机制并通过扫描计算的S4模型。
https://github.com/state-spaces/mamba开源。| 架构 | 内部层 | 准确率(%) |
|---|---|---|
| GatedMLP | S4 (LTI) | 2.0 |
| H3 | S4 (LTI) | 3.1 |
| Mamba | S4 (LTI) | 3.1 |
| GatedMLP | S6 (选择性) | 99.9 |
| H3 | S6 (选择性) | 100.0 |
| Mamba | S6 (选择性) | 100.0 |
| 表1: (选择性复制。) 架构与内部序列层组合的准确率。 |
| 模型 | L=64 | L=256 (训练) | ... | L=1,048,576 |
|---|---|---|---|---|
| MHA-Abs | ✓ | 99.8 | ... | ✗ |
| MHA-RoPE | ✓ | 99.9 | ... | ✗ |
| MHA-xPos | ✓ | 99.9 | ... | ✗ |
| H3 (S4) | ✓ | 99.8 | ... | 50.3 |
| Hyena | ✓ | 99.8 | ... | 50.2 |
| Mamba (S6) | ✓ | 100.0 | ... | 100.0 |
| 表2: (归纳头。) 模型在序列长度256上训练,并在64到1,048,576的序列长度上测试。 |
本文讨论了相关工作、局限性和一些未来方向。
无免费午餐:连续-离散谱。结构化SSM最初被定义为连续系统的离散化,对连续时间数据模态(如音频、视频)有很强的归纳偏置。如第3.1和3.5节所讨论,选择机制克服了它们在离散模态(如文本和DNA)上的弱点;但这反过来可能会影响它们在LTI SSM擅长的数据上的性能。本文对音频波形的消融实验更详细地研究了这种权衡。
下游可供性(Downstream Affordances)。基于Transformer的基础模型(特别是LLM)拥有丰富的生态系统,包括微调、适配、提示、上下文学习、指令调优、RLHF、量化等多种交互模式。一个特别令人感兴趣的问题是,像SSM这样的Transformer替代品是否具有类似的属性和可供性。
扩展性(Scaling)。本文的实证评估仅限于较小的模型尺寸,低于大多数强大的开源LLM(如Llama)以及其他循环模型(如RWKV和RetNet)的规模,后者已在7B参数规模及以上进行了评估。Mamba在这些更大规模下是否仍具竞争力尚待评估。我们还注意到,扩展SSM可能涉及本文未讨论的进一步工程挑战和模型调整。
本文为结构化状态空间模型引入了一种选择机制,使其能够在序列长度上线性扩展的同时执行上下文相关的推理。当集成到一个简单的无注意力架构中时,Mamba在多种领域取得了最先进的结果,其性能与强大的Transformer模型相当或更优。我们对选择性状态空间模型在为不同领域构建基础模型,特别是在需要长上下文的新兴模态(如基因组学、音频和视频)中的广泛应用感到兴奋。我们的结果表明,Mamba是成为通用序列模型骨干的有力竞争者。
本文的选择机制受到并与门控、超网络和数据依赖等概念相关。 但我们认为它是一个值得澄清的独特概念。
- 门控(Gating): 最初指LSTM和GRU等RNN的门控机制,控制信息如何随时间传播并沿序列长度维度交互。现在,“门控”一词的用法已泛化为任何乘法交互,即使不沿序列维度交互,这与原始RNN的意义非常不同。
- 超网络(Hypernetworks)和数据依赖(Data-dependence): 这些概念非常宽泛,几乎可以涵盖任何参数依赖于数据的模型,包括注意力机制和简单的GLU激活函数,因此信息量不大。
- 选择(Selection): 我们认为选择机制与传统RNN的门控机制关系最密切(定理1证明了后者是前者的一个特例)。我们使用“选择”一词来特指模型选择或忽略输入并促进数据沿序列长度维度交互的机理作用。
定理1的证明。考虑一个选择性SSM(算法2),其中 $N=1, A=-1, C=1, \tau_\Delta = \text{Linear}(x), s_\Delta = \text{softplus}$。对应的连续时间SSM(公式1)是一个漏积分器(leaky integrator)。离散化步长为 $\Delta_t = \text{softplus}(\mathbf{v} \cdot \mathbf{x}_t + b)$。应用零阶保持(ZOH)离散化公式:
$\bar{\mathbf{A}}_t = e^{\Delta_t A} = e^{-\Delta_t}$
$\bar{\mathbf{B}}_t = (\Delta_t A)^{-1}(e^{\Delta_t A} - I) B = (-\Delta_t)^{-1}(e^{-\Delta_t} - 1) \cdot 1 = \frac{1 - e^{-\Delta_t}}{\Delta_t}$
最终的离散循环(公式2a)为 $h_t = e^{-\Delta_t} h_{t-1} + \frac{1 - e^{-\Delta_t}}{\Delta_t} x_t$,这与门控RNN的形式相符。
在没有输入依赖的选择性时,SSM可以高效地实现为卷积。有了选择性,SSM不再等价于卷积,但我们可以利用并行关联扫描。我们通过核函数融合和重计算使SSM扫描在现代硬件(GPU)上既快速又内存高效。
- 速度:
- 标准实现方式需要在慢速的GPU HBM(高带宽内存)中物化大小为(B, L, D, N)的扫描输入,导致大量内存IO。
- 我们的融合核将离散化、扫描和与C的乘法合并为一个操作:
1. 从HBM读取SSM参数(Δ, A, B, C)到快速SRAM。
2. 在SRAM中离散化得到$\bar{\mathbf{A}}, \bar{\mathbf{B}}$。
3. 执行并行扫描,中间状态仍在SRAM中。
4. 与C相乘求和,得到最终输出并写回HBM。
- 这种方式将IO减少了 $O(N)$ 倍,实际操作中速度提升20-40倍。对于过长的序列,可分块处理。
- 内存:
- 为了避免内存爆炸,我们不在前向传播中保存大小为(B, L, D, N)的中间状态。
- 在反向传播中,我们使用重计算技术,当需要时重新计算这些中间状态。这不仅节省了存储空间,还因为避免了从HBM读取大量中间状态而加快了反向传播的速度。
- 通过对整个选择性SSM块(包括输入投影、激活、扫描等)进行重计算优化,其总内存需求与使用FlashAttention的优化版Transformer实现相当。