作者: Jimmy T.H. Smith*, 1, 2, Andrew Warrington*, 2, 3, Scott W. Linderman2, 3
*同等贡献。
1 斯坦福大学,计算与数学工程研究所。
2 斯坦福大学,吴 Tsai 神经科学研究所。
3 斯坦福大学,统计系。
{jsmith14,awarring,scott.linderman}@stanford.edu.
本文旨在解决机器学习中高效建模长序列的挑战性问题,即序列中相隔数千个时间步的观测值可能共同编码了解决任务的关键信息。尽管已有利普希茨循环神经网络(RNN)、卷积神经网络(CNN)和高效 Transformer 等多种方法,但它们在处理极长序列任务时仍表现不佳。Gu等人(2021a)提出的结构化状态空间序列(S4)层通过结合线性状态空间模型(SSM)、HiPPO框架和深度学习,在长程序列建模任务上取得了显著的性能提升。
本文在此基础上,引入了一个新的状态空间层——S5层,它在两个主要方面简化并改进了S4层:
1. 从多SISO到单MIMO结构:S4层使用一个包含许多独立的单输入单输出(SISO)SSM的集合,而S5层则用一个多输入多输出(MIMO)SSM替代之。
2. 采用并行扫描替代卷积:S4层依赖于卷积和频域方法来高效处理序列,这需要一个非平凡的卷积核计算过程。相比之下,S5层使用了一种高效且被广泛实现的并行扫描(parallel scan)算法,使其能够完全在时域内以循环方式进行计算,从而省去了计算卷积核的需要。
本文建立了S4和S5之间的数学联系,并利用这一联系为S5模型开发了有效的初始化和参数化方案。研究发现,虽然S4使用的特定HiPPO-LegS矩阵无法进行数值稳定的对角化以用于S5,但其对角近似(即HiPPO-N矩阵)能够取得相当的性能,这与最近的DSS和S4D层的工作一致。本文将Gu等人(2022)的一个理论结果扩展到MIMO设置,为在S5中使用对角近似提供了理论依据。
最终设计的S5层具有多个优点:
* 计算效率:其计算复杂度与S4相当,在序列长度上呈线性关系。
* 实现简单:实现方式直接明了(如附录A所示)。
* 灵活性:能够高效处理时变SSM和不规则采样的观测数据,而这对于S4的卷积实现是难以处理的。
* 卓越性能:在多个长程序列建模任务上取得了最先进的性能。在长程竞技场(LRA)基准测试中平均准确率达到87.4%,在其中最困难的Path-X任务上达到了98.5%的准确率。
连续时间线性SSM。作为S4和S5层的核心组件,连续时间线性SSM由一个微分方程定义。给定输入信号$u(t) \in R^U$,潜状态$x(t) \in R^P$和输出信号$y(t) \in R^M$,一个线性连续时间SSM由以下方程定义:
该模型由状态矩阵$A \in R^{P \times P}$、输入矩阵$B \in R^{P \times U}$、输出矩阵$C \in R^{M \times P}$和直通矩阵$D \in R^{M \times U}$参数化。
离散化。对于一个固定的步长$\Delta$,可以使用欧拉法、双线性变换法或零阶保持(ZOH)等方法对SSM进行离散化,从而定义一个线性递推关系:
其中,离散时间参数是连续时间参数的函数,具体形式由离散化方法决定。更多关于离散化方法的信息可参见【28, A first course in the numerical analysis of differential equations by Arieh Iserles, 2009】。
并行扫描操作。我们使用并行扫描来高效计算离散化线性SSM的状态。给定一个二元结合运算符•(即(a • b) • c = a • (b • c))和一个包含L个元素的序列$[a_1, a_2, ..., a_L]$,扫描操作(有时称为全前缀和)返回序列:
并行化线性递推。计算一个长度为L的离散化SSM的线性递推$x_k = \bar{A}x_{k-1} + \bar{B}u_k$(如公式2所示)是扫描操作的一个具体实例。根据【6, Prefix sums and their applications by Guy Blelloch, 1990】第1.4节的讨论,假设有L个处理器,并行化离散SSM中潜状态转移的线性递推可以在$O(T_M \log L)$的并行时间内完成,其中$T_M$表示矩阵-矩阵乘法的成本。对于一个通用矩阵$A \in R^{P \times P}$,$T_M$为$O(P^3)$,这在深度学习中可能成本过高。然而,如果$A$是一个对角矩阵,并行时间将变为$O(P \log L)$,且仅需$O(PL)$的空间。值得注意的是,高效的并行扫描是以工作高效的方式实现的,因此使用对角矩阵的并行扫描总计算成本为$O(PL)$次操作。更多关于并行扫描的信息请参见附录H。
S4层结构。S4层(【20, Efficiently modeling long sequences with structured state spaces by Albert Gu, Karan Goel, and Christopher Re, 2021, ICLR】)定义了一个非线性的序列到序列变换,将输入序列$u_{1:L} \in R^{L \times H}$映射到输出序列$u'_{1:L} \in R^{L \times H}$。一个S4层包含一个由H个独立的单输入单输出(SISO)SSM组成的集合,每个SSM具有N维状态。每个S4 SSM应用于输入序列的一个维度,从而实现从每个输入通道到每个预激活通道的独立线性变换。然后,对预激活值应用一个非线性激活函数。最后,应用一个逐位置的线性混合层来组合这些独立的特征并产生输出序列$u'_{1:L}$。附录中的图4a展示了将S4层视为独立SSM集合的视图。图2a则展示了S4的另一种视图,即一个具有大小为HN的状态和块对角状态、输入及输出矩阵的大型SSM。
HiPPO初始化与DPLR参数化。每个S4 SSM都利用HiPPO框架(【18, HiPPO: Recurrent memory with optimal polynomial projections by Albert Gu et al., 2020, NeurIPS】)进行在线函数逼近,通过使用一个HiPPO矩阵(通常是HiPPO-LegS矩阵)来初始化状态矩阵。实验证明,这种方法能带来强大的性能(【21, Combining recurrent, convolutional, and continuous-time models with linear state space layers by Albert Gu et al., 2021, NeurIPS】; 【20, Efficiently modeling long sequences with structured state spaces by Albert Gu, Karan Goel, and Christopher Re, 2021, ICLR】),并且可以被证明是在一个无限长的指数衰减度量下近似长程依赖(【23, How to train your HIPPO: State space models with generalized orthogonal basis projections by Albert Gu et al., 2023, ICLR】)。虽然HiPPO-LegS矩阵不能被稳定地对角化(【20, Efficiently modeling long sequences with structured state spaces by Albert Gu, Karan Goel, and Christopher Re, 2021, ICLR】),但它可以表示为一个正规加低秩(NPLR)矩阵。其正规部分,被称为HiPPO-N并记为$A_{NormalLegS}$,是可以对角化的。因此,HiPPO-LegS可以通过共轭变换转化为对角加低秩(DPLR)形式,S4利用这种形式来推导出一个高效的卷积核计算方法。这也推动了S4的DPLR参数化。
S4层的双重实现模式。高效应用S4层需要两种独立的实现方式:循环模式和卷积模式。对于在线生成任务,SSM像其他RNN一样以循环方式迭代。然而,当整个序列可用且观测值是均匀间隔时,会使用更高效的卷积模式。该模式利用了线性递推可以表示为输入与每个SSM的卷积核之间的一维卷积。然后可以应用快速傅里叶变换(FFT)来高效地并行化这一过程。附录中的图4a展示了S4层用于离线处理的卷积方法。值得注意的是,虽然原则上并行扫描可以使循环方法用于离线场景,但将并行扫描应用于所有H个N维SSM通常比S4实际使用的卷积方法昂贵得多。
S4层的可训练参数。每个S4层的可训练参数包括H个独立的、可学习的SSM参数副本以及混合层的$O(H^2)$个参数。对于每个$h \in \{1, ..., H\}$的S4 SSM,给定一个标量输入信号$u^{(h)}(t) \in R$,S4 SSM使用一个输入矩阵$B^{(h)} \in C^{N \times 1}$,一个DPLR参数化的转移矩阵$A^{(h)} \in C^{N \times N}$,一个输出矩阵$C^{(h)} \in C^{1 \times N}$,以及一个直通矩阵$D^{(h)} \in R^{1 \times 1}$,来产生一个信号$y^{(h)}(t) \in R$。为了将S4 SSM应用于离散序列,每个连续时间SSM都使用一个常数时间尺度参数$\Delta^{(h)} \in R^+$进行离散化。每个SSM的可学习参数是时间尺度参数$\Delta^{(h)} \in R^+$,连续时间参数$B^{(h)}$, $C^{(h)}$, $D^{(h)}$,以及由向量$\Lambda^{(h)} \in C^N$和$p^{(h)}, q^{(h)} \in C^N$参数化的DPLR矩阵(分别代表对角矩阵和低秩项)。为简化符号,我们将离散时间点k的S4 SSM状态的拼接表示为$x_k^{(1:H)} = [(x_k^{(1)})^T, ..., (x_k^{(H)})^T]^T$,H个SSM的输出表示为$y_k = [y_k^{(1)}, ..., y_k^{(H)}]^T$。
本节将介绍S5层,详细描述其结构、参数化和计算方式,并特别关注这些方面与S4层的区别。
采用单一MIMO SSM。S5层用一个多输入多输出(MIMO)SSM替换了S4中的SISO SSM集合(或大型块对角系统),如公式(1)所示。该MIMO SSM具有潜状态大小P,输入和输出维度为H。此MIMO SSM的离散化版本可以应用于一个向量值输入序列$u_{1:L} \in R^{L \times H}$,通过潜状态$x_k \in R^P$产生一个向量值的SSM输出(或预激活值)序列$y_{1:L} \in R^{L \times H}$。然后对该序列应用一个非线性激活函数,生成层输出序列$u'_{1:L} \in R^{L \times H}$。具体结构见图2b。与S4不同,S5不需要一个额外的逐位置线性层,因为特征已经在MIMO SSM内部混合。值得注意的是,与S4层中块对角SSM的HN潜状态大小相比,S5的潜状态大小P可以显著更小,这使得使用高效的并行扫描成为可能,我们将在3.3节中讨论。
对角化系统以实现并行扫描。S5层MIMO SSM的参数化是为了能够使用高效的并行扫描。如2.2节所述,要使用并行扫描高效计算线性递推,状态矩阵必须是对角矩阵。因此,我们对系统进行对角化,将连续时间状态矩阵写为$A = V \Lambda V^{-1}$,其中$\Lambda \in C^{P \times P}$是包含特征值的对角矩阵,而$V \in C^{P \times P}$对应于特征向量。这样,我们可以将公式(1)中的连续时间潜动力学对角化为:
重参数化系统。通过定义$\tilde{x}(t) = V^{-1}x(t)$,$\tilde{B} = V^{-1}B$和$\tilde{C} = CV$,我们得到一个重参数化的系统:
这是一个具有对角状态矩阵的线性SSM。
离散化和可学习参数。这个对角化的系统可以使用一个时间尺度参数$\Delta \in R^+$通过ZOH方法进行离散化,得到另一个对角化系统,其参数为:
在实践中,我们使用一个可学习的时间尺度参数向量$\Delta \in R^P$(见4.3节),并限制直通矩阵D为对角矩阵。因此,S5层的可学习参数为:$\tilde{B} \in C^{P \times H}$,$\tilde{C} \in C^{H \times P}$,diag(D) $\in R^H$,diag($\Lambda$) $\in C^P$,以及$\Delta \in R^P$。
初始化。先前的工作表明,深度状态空间模型的性能对状态矩阵的初始化很敏感(【21, Combining recurrent, convolutional, and continuous-time models with linear state space layers by Albert Gu et al., 2021, NeurIPS】;【20, Efficiently modeling long sequences with structured state spaces by Albert Gu, Karan Goel, and Christopher Re, 2021, ICLR】)。我们在2.2节讨论过,为了高效应用并行扫描,状态矩阵必须是对角矩阵。我们还在2.3节提到,HiPPO-LegS矩阵无法被稳定地对角化,但HiPPO-N矩阵可以。在第4节,我们将S5的动力学与S4联系起来,以说明为什么在MIMO设置中使用类似HiPPO的矩阵进行初始化也可能效果很好。我们的实验支持了这一点,发现对角化HiPPO-N矩阵能带来良好性能,并在附录E中进行了消融实验以与其他初始化方法进行比较。我们注意到,DSS(【25, Diagonal state spaces are as effective as structured state spaces by Ankit Gupta, Albert Gu, and Jonathan Berant, 2022, NeurIPS】)和S4D(【22, On the parameterization and initialization of diagonal state space models by Albert Gu, Karan Goel, Ankit Gupta, and Christopher Ré, 2022, NeurIPS】)层在SISO设置中也通过使用HiPPO-N矩阵的对角化取得了优异的性能。
共轭对称性。具有实数项的可对角化矩阵的复特征值总是成共轭对出现。我们通过使用一半数量的特征值和潜状态来强制实现这种共轭对称性。这确保了输出为实数,并将并行扫描的运行时间和内存使用量减少了一半。这个思想在【22, On the parameterization and initialization of diagonal state space models by Albert Gu, Karan Goel, Ankit Gupta, and Christopher Ré, 2022, NeurIPS】中也有讨论。
全时域循环计算。与S4层的大型HN有效潜状态大小相比,S5层的较小潜状态维度(P)使得在整个序列可用时能够使用高效的并行扫描。因此,S5层可以作为一种在时域中高效运行的循环模型,用于在线生成和离线处理。并行扫描和连续时间参数化还允许高效处理不规则采样的时间序列和其他时变SSM,只需在每一步提供一个不同的$\bar{A}_k$矩阵。我们在6.3节利用这一特性将S5应用于不规则采样的数据。相比之下,S4层的卷积方法要求系统是时不变的并且观测是规则间隔的。
计算复杂度匹配。S5的一个关键设计目标是匹配S4在在线生成和离线循环方面的计算复杂度。下面的命题保证了如果S5的潜状态大小$P = O(H)$,它们的复杂度在同一数量级。
命题1。给定一个具有H个输入/输出特征的S4层,一个具有H个输入/输出特征且潜状态大小$P = O(H)$的S5层,在运行时间和内存使用方面与S4层具有相同的数量级复杂度。
证明。证明见附录C.1。
实证对比。我们还在附录C.2中通过实证比较支持了这一命题。
我们现在建立S5和S4动力学之间的关系。在4.1节,我们展示在特定条件下,S5 SSM的输出可以被解释为由一个特定S4系统计算的潜状态的投影。这一解释为S5使用HiPPO初始化提供了动机,我们将在4.2节详细讨论。在4.3节,我们讨论了建立动力学关系所需的条件如何进一步启发初始化和参数化选择。
简化假设。我们在一些简化假设下比较S4和S5的动力学:
1. 假设1:我们只考虑H维到H维的序列映射。
2. 假设2:我们假设每个S4 SSM的状态矩阵是相同的,$A^{(h)} = A \in C^{N \times N}$。
3. 假设3:我们假设每个S4 SSM的时间尺度是相同的,$\Delta^{(h)} = \Delta \in R^+$。
4. 假设4:我们假设S5中使用与S4中相同的状态矩阵A(也参照假设2)。注意这也指定了S5的潜状态大小$P=N$。我们还假设S5的输入矩阵是S4使用的列输入向量的水平拼接:$B = [B^{(1)} | ... | B^{(H)}]$。
S5输出与S4潜状态的关系。我们稍后将讨论放宽这些假设,但在这些条件下,推导S4和S5动力学之间的关系是直接的:
命题2。考虑一个S5层,其状态矩阵为A,输入矩阵为B,以及某个输出矩阵C(参照假设1);以及一个S4层,其中每个H个S4 SSM的状态矩阵为A(参照假设2, 4),输入向量为$B^{(h)}$(参照假设4)。如果S4和S5层以相同的时间尺度离散化(参照假设3),那么S5 SSM产生的输出$y_k$等价于H个S4 SSM潜状态的线性组合,$y_k = C_{equiv}x_k^{(1:H)}$,其中$C_{equiv} = [C \cdot\cdot\cdot C]$。
证明。证明见附录D.2。
关系解读。重要的是,S5 SSM的输出不等于块对角S4 SSM的输出。相反,它们等价于修改了输出矩阵为$C_{equiv}$的块对角S4 SSM的输出。然而,在这些假设下,底层的状态动力学是等价的。回顾一下,用HiPPO初始化S4的动力学是其性能的关键(【20, Efficiently modeling long sequences with structured state spaces by Albert Gu, Karan Goel, and Christopher Re, 2021, ICLR】),命题2中建立的关系为S5使用HiPPO初始化提供了动机,我们现在就来讨论这一点。
使用HiPPO-N近似HiPPO-LegS。理想情况下,根据上述解释,我们应该用精确的HiPPO-LegS矩阵来初始化S5。不幸的是,如2.3节所述,该矩阵不能被稳定地对角化,而这对于S5使用的高效并行扫描是必需的。然而,【25, Diagonal state spaces are as effective as structured state spaces by Ankit Gupta, Albert Gu, and Jonathan Berant, 2022, NeurIPS】和【22, On the parameterization and initialization of diagonal state space models by Albert Gu, Karan Goel, Ankit Gupta, and Christopher Ré, 2022, NeurIPS】的实验表明,移除低秩项并用对角化的HiPPO-N矩阵进行初始化仍然表现良好。
理论依据。【22, On the parameterization and initialization of diagonal state space models by Albert Gu, Karan Goel, Ankit Gupta, and Christopher Ré, 2022, NeurIPS】为在单输入系统中使用这种正规近似提供了一个理论 justifications:在状态维度无限大的极限下,使用HiPPO-N状态矩阵的线性常微分方程(ODE)产生的动力学与使用HiPPO-LegS矩阵的ODE相同。利用线性性质,将这个结果扩展到S5使用的多输入系统是直接的:
推论1(【22】中定理3的扩展)。考虑如附录B.1.1中定义的$A_{LegS} \in R^{N \times N}$,$A_{NormalLegS} \in R^{N \times N}$,$B_{LegS} \in R^{N \times H}$,$P_{LegS} \in R^N$。给定向量值输入$u(t) \in R^H$,常微分方程$dx'(t)/dt = A_{NormalLegS}x'(t) + 1/2 B_{LegS}u(t)$在$N \to \infty$时收敛于$dx(t)/dt = A_{LegS}x(t) + B_{LegS}u(t)$。
HiPPO-N初始化的动机。我们在附录D.3中包含了这个扩展的简单证明。这个扩展为使用HiPPO-N初始化S5的MIMO SSM提供了动机。注意S4D(S4的对角扩展)也使用相同的HiPPO-N矩阵。因此,在命题2的假设下,一个S5 SSM实际上产生的输出等价于由S4D的SSM产生的潜状态的线性组合。我们在第6节的实验结果表明,用HiPPO-N矩阵初始化的S5的性能与用HiPPO-LegS矩阵初始化的S4一样好。
放宽对S4的约束。我们现在重新审视命题2所需的假设,因为它们只关联了一个受约束的S5版本和一个受约束的S4版本。关于假设2,【20, Efficiently modeling long sequences with structured state spaces by Albert Gu, Karan Goel, and Christopher Re, 2021, ICLR】报告说,具有绑定状态矩阵的S4模型仍然可以表现良好,尽管允许不同的状态矩阵通常会产生更高的性能。同样,根据假设3,要求所有S4 SSM使用单一的标量时间尺度是限制性的。S4通常为每个SSM学习不同的时间尺度参数(【23, How to train your HIPPO: State space models with generalized orthogonal basis projections by Albert Gu et al., 2023, ICLR】)以捕捉数据中的不同时间尺度。
通过增大潜状态P来放宽约束。为了放宽这些假设,注意假设4将S5的维度限制为$P=N$,而N通常远小于输入的维度H。命题1确定了S5可以通过$P=O(H)$来匹配S4的复杂度。通过允许更大的潜状态大小,可以放宽假设2和3,如附录D.4所讨论。我们还讨论了这种放宽如何为在对角线上使用HiPPO-N矩阵的块对角初始化提供了动机。最后,为了进一步放宽绑定的时间尺度假设,我们注意到在实践中,我们发现通过学习P个不同的时间尺度(每个状态一个)可以提高性能。关于这一经验发现的进一步讨论和消融研究见附录D.5和E.1。
与S4及其扩展的关系。S5与S4及其其他扩展的关系最为直接,我们已对此进行了深入讨论。然而,已有文献使用了与本文所开发思想类似的方法。例如,先前的工作研究了用连接有非线性层的线性RNN堆栈来近似非线性RNN,同时也使用了并行扫描(【45, Parallelizing linear recurrent neural nets over sequence length by Eric Martin and Chris Cundy, 2018, ICLR】)。【45】表明,几种高效的RNN,如QRNNs(【7, Quasi-recurrent neural networks by James Bradbury et al., 2017, ICLR】)和SRUs(【37, Simple recurrent units for highly parallelizable recurrence by Tao Lei et al., 2018, EMNLP】),都属于可以利用并行扫描的线性代理RNN类别。Kaul(【31, Linear dynamical systems as a core computational primitive by Shiva Kaul, 2020, NeurIPS】)也使用并行扫描来近似由离散时间单输入多输出(SIMO)SSM堆栈组成的RNN。
S4/S5的性能优势来源。然而,S4和S5是唯一显著优于其他可比较的SOTA非线性RNN、Transformer和卷积方法的模型。我们在附录E.2的消融研究表明,这种相对于先前并行化线性RNN尝试的性能提升很可能是由于连续时间参数化和HiPPO初始化。
我们现在通过实验比较S5层与S4层及其他基线方法的性能。我们将S5层作为S4层的直接替代品使用。模型架构包括一个线性输入编码器、堆叠的S5层和一个线性输出解码器(【20, Efficiently modeling long sequences with structured state spaces by Albert Gu, Karan Goel, and Christopher Re, 2021, ICLR】)。在所有实验中,我们选择S5的维度以确保其计算复杂度与S4相似,遵循3.3节讨论的条件,并具有可比的参数数量。我们展示的结果表明,S5层在性能和效率上都能与S4层相媲美。我们在附录中包含了进一步的消融研究、基线比较和运行时比较。
实验内容。LRA基准(【60, Long Range Arena: A benchmark for efficient transformers by Yi Tay et al., 2021, ICLR】)是一套包含六个序列建模任务的测试集,序列长度从1,024到超过16,000。该套件专为评测架构在长程建模任务上的性能而设计(更多细节见附录G)。
实验结果与分析。表1展示了S5在LRA上与其他方法的性能对比。在序列长度上具有线性复杂度的模型中(最著名的是S4、S4D以及同期的Liquid-S4【26】和Mega-chunk【43】),S5取得了最高的平均分。最显著的是,S5在Path-X任务上取得了所有模型(包括Mega【43】)中的最高分,该任务的序列长度是基准测试中最长的。
表1:LRA基准任务的测试准确率(【60】)。✗表示模型性能未超过随机猜测。我们在附录中的表7提供了包含完整引用和误差棒的扩展表格。我们遵循【20】和【22】中报告的程序,报告S4、S4D(如【20】; 【22】所报)和S5在三个种子上的平均值。粗体分数表示最高性能,<u>下划线</u>分数表示第二名。我们还包括了同期方法Liquid-S4(【26】)和Mega(【43】)的结果。与S4方法和S5不同,最好的Mega模型保留了Transformer的$O(L^2)$复杂度。
实验内容。Speech Commands数据集(【66, Speech Commands: A dataset for limited-vocabulary speech recognition by Pete Warden, 2018, arXiv】)包含35个词汇中不同朗读者朗读单词的高保真录音。任务是分类说出的是哪个单词。
实验结果与分析。如表2所示,S5的性能优于基线模型,优于之前的S4方法,并与同期的Liquid-S4方法(【26】)性能相当。由于S4和S5方法是连续时间参数化的,这些模型可以应用于不同采样率的数据集而无需重新训练,只需通过新旧采样率之比全局重新缩放时间尺度参数$\Delta$即可。表2还展示了将在16kHz数据上训练的最佳S5模型应用于8kHz采样(通过抽取)的语音数据上的结果,无需任何额外微调。S5在该指标上也优于基线方法。
表2:35分类Speech Commands任务的测试准确率(【66】)。我们在附录中提供了带误差棒的扩展表格,表8。训练样本为1秒16kHz的音频波形。最后一列表示在8kHz下的零样本测试(通过朴素抽取构建)。如【22】中所述,报告了三个随机种子的平均值。基线InceptionNet到S4D-Lin的性能引自【22】。
实验内容。我们研究的最后一个应用突显了S5如何自然地处理在不规则间隔接收到的观测数据。S5通过在每一步为离散化提供不同的$\Delta_t$值来实现这一点。我们使用由【4】和【57】提出的摆锤回归示例,如图3所示。输入序列是一系列L=50的24×24像素图像,这些图像被相关噪声过程破坏,并从一个持续时间T=100的连续轨迹中不规则采样。目标是摆锤角度的正弦和余弦,该角度遵循一个非线性动力系统。速度是未观测的。我们匹配了【57】的架构、参数数量和训练过程。
实验结果与分析。表3总结了该实验的结果。S5在回归任务上的性能优于CRU,获得了更低的平均误差。此外,在相同的硬件上,S5明显快于CRU。
表3:在保留测试集上,摆锤回归任务的回归MSE ×10⁻³(均值±标准差)和相对应用速度。基线mTAND到CRU的性能引自【57】。我们在附录中提供了扩展表格(表9)和更多细节。CRU(我们的运行)和S5的结果是二十个种子的平均值。
实验内容与结果。附录F.4中的表10展示了S5在其他常见基准测试上的结果,包括序列MNIST、排列序列MNIST和序列CIFAR(彩色)。我们看到S5的性能大致与S4相匹配,并优于一系列最先进的基于RNN的方法。
我们介绍了用于长程序列建模的S5层。S5层修改了S4层的内部结构,并用一种纯粹的、利用并行扫描的循环时域方法取代了S4使用的频域方法。S5在保持S4计算效率的同时实现了高性能。S5还提供了进一步的机会。例如,与卷积S4方法不同,并行扫描解锁了高效、简便地处理参数随时变化的时变SSM的能力。第6.3节展示了这一能力在处理可变采样率序列时的应用。同期开发的方法Liquid-S4【26】使用了一种输入依赖的双线性动力系统,突显了时变SSM的更多潜力。更通用的MIMO SSM设计也将有助于与经典概率状态空间建模以及近期在并行化滤波和平滑操作方面的工作(【56, Temporal parallelization of Bayesian smoothers by Simo Särkkä and Ángel F García-Fernández, 2020, IEEE Transactions on Automatic Control】)建立联系。更广泛地说,我们希望S5层的简单性和通用性能够扩大状态空间层在深度序列建模中的应用,并催生新的公式和扩展。
import jax
import jax.numpy as np
parallel_scan = jax.lax.associative_scan
def discretize(Lambda, B_tilde, Delta):
""" 离散化一个对角化的、连续时间的线性SSM
参数:
Lambda (complex64): 对角状态矩阵 (P,)
B_tilde (complex64): 输入矩阵 (P, H)
Delta (float32): 离散化步长 (P,)
返回:
离散化后的 Lambda_bar (complex64), B_bar (complex64) (P,), (P,H)
"""
Identity = np.ones(Lambda.shape[0])
Lambda_bar = np.exp(Lambda * Delta)
B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde
return Lambda_bar, B_bar
def binary_operator(element_i, element_j):
""" 线性递推并行扫描的二元运算符。假设A是对角矩阵。
参数:
element_i: 包含位置i的A_i和Bu_i的元组 (P,), (P,)
element_j: 包含位置j的A_j和Bu_j的元组 (P,), (P,)
返回:
新的元素 (A_out, Bu_out)
"""
A_i, Bu_i = element_i
A_j, Bu_j = element_j
return A_j * A_i, A_j * Bu_i + Bu_j
def apply_ssm(Lambda_bar, B_bar, C_tilde, D, input_sequence):
""" 给定一个LxH的输入,计算离散化SSM的LxH输出。
参数:
Lambda_bar (complex64): 离散化的对角状态矩阵 (P,)
B_bar (complex64): 离散化的输入矩阵 (P, H)
C_tilde (complex64): 输出矩阵 (H, P)
D (float32): 直通矩阵 (H,)
input_sequence (float32): 输入特征序列 (L, H)
返回:
ys (float32): SSM输出 (S5层的预激活值) (L, H)
"""
# 准备初始化并行扫描所需的元素
Lambda_elements = np.repeat(Lambda_bar[None, ...], input_sequence.shape[0], axis=0)
Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence)
elements = (Lambda_elements, Bu_elements) # (L, P), (L, P)
# 使用并行扫描计算给定输入序列的潜状态序列
_, xs = parallel_scan(binary_operator, elements) # (L, P)
# 计算SSM输出序列
ys = jax.vmap(lambda x, u: (C_tilde @ x + D * u).real)(xs, input_sequence)
return ys
def apply_S5_layer(params, input_sequence):
""" 给定LxH的输入序列,计算S5层的LxH输出序列。
参数:
params: 包含连续时间SSM参数的元组
input_sequence: 输入特征序列 (L, H)
返回:
S5层的输出序列 (L, H)
"""
Lambda, B_tilde, C_tilde, D, log_Delta = params
Lambda_bar, B_bar = discretize(Lambda, B_tilde, np.exp(log_Delta))
preactivations = apply_ssm(Lambda_bar, B_bar, C_tilde, D, input_sequence)
return jax.nn.gelu(preactivations)
def batch_apply_S5_layer(params, input_sequences):
""" 给定BxLxH的输入序列,计算S5层的BxLxH输出序列。
参数:
params: 包含连续时间SSM参数的元组
input_sequences: 一批输入特征序列 (B, L, H)
返回:
一批S5层的输出序列 (B, L, H)
"""
return jax.vmap(apply_S5_layer, in_axes=(None, 0))(params, input_sequences)
HiPPO-LegS矩阵。这里我们提供补充3.2节初始化讨论的额外细节。【23】通过将输入相对于一个无限长、指数衰减的度量进行分解,解释了S4在使用HiPPO-LegS矩阵时捕捉长程依赖的能力。HiPPO-LegS矩阵和相应的SISO输入向量定义为:
注意,在4.2节中,推论1中使用的输入矩阵$B_{LegS} \in R^{N \times H}$是由H个$b_{LegS} \in R^N$的副本拼接而成的。
NPLR/DPLR形式。【20】的定理1表明,【18】中的HiPPO矩阵$A_{HiPPO} \in R^{N \times N}$可以表示为正规加低秩(NPLR)形式,由一个正规矩阵$A_{NormalHiPPO} = V\Lambda V^* \in R^{N \times N}$和一个低秩项组成:
对于酉矩阵$V \in C^{N \times N}$,对角矩阵$\Lambda \in C^{N \times N}$,以及低秩分解$P, Q \in R^{N \times r}$。这个等式的右边显示HiPPO矩阵可以共轭变换为对角加低秩(DPLR)形式。
HiPPO-LegS的分解。因此,HiPPO-LegS矩阵可以表示为正规HiPPO-N矩阵和一个低秩项$P_{LegS} \in R^N$(【17】)的形式:
其中
S5的默认和块对角初始化。我们的默认设置是将S5层的状态矩阵设为$A = A_{NormalLegS} \in R^{P \times P}$,并取其特征分解$A = V \Lambda V^{-1}$来初始化$\Lambda$。然后使用V和$V^{-1}$来初始化$\tilde{B}$和$\tilde{C}$,如下所述。
如4.3节所述,我们还发现,在许多任务上,将S5状态矩阵初始化为块对角形式可以提升性能,其中对角线上的每个块等于$A_{NormalLegS} \in R^{R \times R}$,这里的R小于状态维度P,例如当对角线上使用4个块时$R = P/4$。然后我们对这个矩阵进行特征分解来初始化$\Lambda$,以及$\tilde{B}$和$\tilde{C}$。我们注意到,即使在这种情况下,$\tilde{B}$和$\tilde{C}$仍然以密集形式初始化,并且没有约束要求A在学习过程中保持块对角。在附录G的超参数表中,J超参数表示用于初始化的对角线上的HiPPO-N块的数量,其中J=1表示我们使用了默认的单一HiPPO-N矩阵初始化。我们在附录D.4中进一步讨论了这种块对角初始化的动机。
B和C的初始化。通常,我们使用初始状态矩阵对角化得到的特征向量来显式初始化输入矩阵$\tilde{B}$和输出矩阵$\tilde{C}$。具体来说,我们采样B和C,然后初始化(复数)可学习参数$\tilde{B}$为$\tilde{B} = V^{-1}B$,$\tilde{C}$为$\tilde{C} = CV$。
D的初始化。我们通过从标准正态分布中独立采样每个元素来初始化$D \in R^H$。
$\Delta$的初始化。先前的工作(【25】; 【23】)发现时间尺度参数的初始化非常重要。这在【23】中有详细研究。我们遵循S4的做法对这些参数进行采样,从区间$[\log \delta_{min}, \log \delta_{max})$上的均匀分布中采样$\log \Delta \in R^P$的每个元素,其中默认范围是$\delta_{min} = 0.001$和$\delta_{max} = 0.1$。唯一的例外是Path-X实验,我们从$\delta_{min} = 0.0001$和$\delta_{max} = 0.1$初始化,以考虑【23】中讨论的更长的时间尺度。
在图4中,我们展示了S4和S5层在高效、并行的离线处理中的计算细节比较。
命题1。给定一个具有H个输入/输出特征的S4层,一个具有H个输入/输出特征且潜状态大小$P = O(H)$的S5层,在运行时间和内存使用方面与S4层具有相同的数量级复杂度。
证明。我们首先考虑整个序列可用的情况,比较S4层的卷积模式和S5层的并行扫描。然后考虑在线生成的情况,此时两种方法都以循环方式运行。
因此,在两种情况下,S4和S5都具有同数量级的计算复杂度和内存需求。
S4, S4D与S5的运行时对比。表4提供了S4、S4D和S5在LRA任务的不同序列长度上的运行时性能(速度和内存)的经验评估。我们比较了S5的JAX实现与基于【55】的S4和S4D的JAX实现。为了公平比较,我们修改了S4和S4D的实现以支持共轭对称性和双向性。所有比较都在一个16GB NVIDIA V100 GPU上进行,并将S4D作为基准。
表4:在三个不同序列长度的LRA任务上,使用C.2节描述的参数化对S4、S4D和S5的运行时性能进行基准测试。对于速度,> 1×表示比S4D基线快。对于内存,< 1×表示使用的内存少于S4D基线。每个指标的第五行显示了用于表1中LRA结果的实际S5模型的性能,该模型使用了表11中报告的架构。
不同S5配置的分析。我们考虑了三种S5配置。前两种配置(第3、4行)显示了当S5的潜状态大小变化时运行时指标如何变化,而其他架构选择与S4相同。第3行中,$P=H$,这经验性地支持了附录C.1中的复杂度论证。第4行中,$P=N$,这对应于表5消融研究中与S4/S4D性能相似的受限S5版本。这两种配置的运行时结果都支持4.3节的论点,即S5的潜状态大小可以增加而保持与S4相当的计算效率。
最终S5模型的性能。最后,我们包括了第三种S5配置(第五行,斜体),它使用了表11中的最佳架构维度,并用于表1中相应的LRA结果。
结论。这项经验研究的总体结论是,S5和S4/S4D的运行时间和内存使用大体相似,正如正文中的复杂度分析所建议的。
我们现在更详细地描述S4和S5架构之间的联系。这种联系使我们能够开发性能更强的架构,并从现有工作中扩展理论结果。
我们将此分析分为三部分:
1. D.2节:我们证明命题2。利用系统的线性特性,我们确定由S5 SSM计算的潜状态等价于由H个SISO S4 SSM计算的潜状态的线性组合,并且S5 SSM的输出是这些状态的进一步线性变换。
2. D.3节:我们对【22】提供的证明进行简单扩展。原始证明表明,在SISO情况下,当N很大时,由(不可对角化的)HiPPO-LegS矩阵产生的动力学可以被(可对角化的)HiPPO-LegS矩阵的正规部分忠实地近似。我们将此证明扩展到MIMO设置。
3. D.4节:我们通过明智地选择S5状态矩阵的初始化,展示S5可以实现多个独立的S4系统并放宽所做的假设。
在以下各节中,除非另有说明,我们将使用以下假设:
1. 假设1:我们只考虑H维到H维的序列映射。
2. 假设2:我们假设每个S4 SSM的状态矩阵是相同的,$A^{(h)} = A \in C^{N \times N}$。
3. 假设3:我们假设每个S4 SSM的时间尺度是相同的,$\Delta^{(h)} = \Delta \in R^+$。
4. 假设4:我们假设S5中使用与S4中相同的状态矩阵A(也参照假设2)。注意这也指定了S5的潜状态大小$P=N$。我们还假设S5的输入矩阵是S4使用的列输入向量的水平拼接,$B = [B^{(1)} | ... | B^{(H)}]$。
命题2证明。
对于单个S4 SSM,离散化的潜状态可以表示为输入序列$u_{1:L} \in R^{L \times H}$的函数:
对于一个S5层,潜状态可以表示为:
其中我们索引$B = [B^{(1)} | ... | B^{(H)}]$和$u_i = [u_i^{(1)}, ..., u_i^{(H)}]^T$。
核心观察。我们观察到:
这个结果直接源于(13)和(14)的线性。这表明(在上述假设下)MIMO S5 SSM的状态等价于H个不同SISO S4 SSM状态的总和。
S5输出的推导。然后我们可以考虑输出矩阵C的影响。对于S5,输出矩阵是一个单一的密集矩阵:
将(15)代入(16)中,可以将MIMO S5 SSM的输出用H个SISO S4 SSM的状态表示:
记H个S4 SSM状态向量的垂直拼接为$x_k^{(1:H)} = [(x_k^{(1)})^T, ..., (x_k^{(H)})^T]^T$,我们看到S5 SSM的输出可以表示为:
因此等价于由H个S4 SSM计算的HN个状态的线性组合。
S4与S5输出矩阵的比较。为进行比较,我们假设每个S4 SSM的输出向量是S5输出矩阵的一行,即$C = [(C^{(1)})^T | ... | (C^{(H)})^T]^T$。每个S4 SSM的输出可以表示为:
我们可以定义在S4中作用于整个潜空间的有效输出矩阵(图2a中标记为C的虚线框)为:
通过检查(19)和(21),我们可以具体地表示出两层使用的等效输出矩阵的差异:
在S4中,有效输出矩阵由主对角线上的独立向量组成。相比之下,S5使用的有效输出矩阵则是在H个S4 SSM之间绑定了密集的输出矩阵。因此,S5可以被解释为仅仅定义了与S4使用的不同的对H个独立SISO SSM的投影。
推论1证明。命题2表明,用HiPPO-LegS矩阵初始化S5可能会产生良好性能。然而,HiPPO-LegS矩阵不能稳定地对角化。推论1允许我们用可对角化的HiPPO-N矩阵初始化MIMO SSM,以近似HiPPO-LegS矩阵,并期望性能相当。
【22】的定理3表明,对于标量输入信号,当$N \to \infty$时有以下关系:
我们希望将此扩展到向量值输入信号的情况。我们首先回顾(15),它表明MIMO S5 SSM的潜状态是H个SISO S4 SSM潜状态的总和。这在连续时间中同样适用:
因此,我们可以定义S5状态的导数为:
将(23)代入其中可得:
这种等价性为用可对角化的HiPPO-N矩阵初始化S5状态矩阵提供了动机,并表明我们可以期待看到类似的性能提升。
放宽假设。这里我们讨论如何放宽假设4中对S5潜状态大小的约束,以帮助放宽对绑定的S4 SSM状态矩阵(假设2)和时间尺度(假设3)的假设,以及命题2导致的绑定输出矩阵。
块对角S5 SSM。我们从S5 SSM状态矩阵是块对角的情况开始。考虑一个潜状态大小为$JN = O(H)$的S5 SSM,其A是块对角矩阵,B和C是密集矩阵,并有J个不同的时间尺度参数$\Delta \in R^J$。由于状态矩阵是块对角的,该系统有一个潜状态$x_k \in R^{JN}$,可以划分为J个不同的状态$x_k^{(j)} \in R^N$。然后我们可以将该系统划分为J个不同的子系统,并用其中一个$\Delta^{(j)}$来离散化每个子系统,得到以下离散化系统:
其中$A^{(j)} \in R^{N \times N}$,$B^{(j)} \in R^{N \times H}$和$C^{(j)} \in R^{H \times N}$。这个划分的系统也可以看作是J个独立的N维S5 SSM子系统,整个系统的输出是J个子系统输出的总和:
从命题2可知,这J个S5 SSM子系统中的每一个的动力学都可以与命题2中一个不同的S4系统的动力学相关联。这些S4系统中的每一个都有其自己的一组绑定的S4 SSM(参照假设2, 3)。重要的是,这J个S4系统中的每一个都可以有其自己的状态矩阵、时间尺度参数和输出矩阵在其H个S4 SSM中共享。因此,一个JN维S5 SSM的输出可以等价于命题2中J个不同S4系统潜状态的线性组合。这一事实为选择用几个HiPPO-N矩阵在块上初始化块对角S5状态矩阵提供了动机,而不是仅仅用一个更大的HiPPO-N矩阵进行初始化。在实践中,我们发现块对角初始化在许多任务上都能提高性能,见附录E。
学习P个时间尺度。最后,我们仔细研究时间尺度参数$\Delta$的参数化。如4.3节所述,S4可以为每个S4 SSM学习一个不同的时间尺度参数,这可能使其能够捕捉数据的不同时间尺度。此外,时间尺度的初始化可能很重要(【23】;【25】),而限制只采样一个初始参数可能导致初始化不佳。上一节的讨论为可能学习J个不同的时间尺度参数(每个子系统一个)提供了动机。然而,在实践中,我们发现在使用P个不同的时间尺度参数(每个状态一个)时性能更好。一方面,这可以简单地看作是为对角化系统中的每个特征值学习一个不同的缩放(见公式6)。另一方面,这可以看作是增加了初始化时采样的时间尺度参数的数量,有助于对抗初始化不佳的可能性。当然,系统可以通过将所有时间尺度设置为相同来学习只使用一个时间尺度。在附录E的消融研究中有进一步的讨论。
实验内容。表5展示了在LRA任务上进行的一项消融研究的结果,以更好地了解S5在不同设置下的性能。我们考虑了3个版本的S5。
1. 受限S5 (P=N, 标量∆): 使用与S4/S4D相同的通用架构,但S5潜状态大小P等于S4中每个SSM的N=64,并仅使用单个标量时间尺度参数。这基本上是我们在命题2中考虑的S5版本。
2. S5 (P=N, 向量∆): 与第一个版本完全相同,只是将时间尺度参数化为向量$\Delta \in R^N$。
3. 完整S5: 使用我们主要结果中报告的无约束版本,参数设置如超参数表11所示。这些模型允许使用较少的输入/输出特征H(以确保与S4基线有相似的参数数量),并通常使用更大的潜状态大小P > N。此外,我们对是否使用块对角初始化以及使用的块数进行了搜索。
结果分析。我们观察到,受限版本的S5在大多数任务上表现良好,但在Image和ListOps上难以与S4基线相媲美。使用向量时间尺度参数的版本在所有任务上都比标量版本有统一的提升。完整的S5模型在LRA任务上都受益于块对角初始化(见表11)。
表5:LRA基准任务的消融研究(【60】)。S4的结果取自【22】; 【20】。注意,此表中所有模型的总参数数量是相当的,因此性能变化不能归因于模型参数数量的大幅增加。
实验内容。我们进行了一项进一步的消融研究,以深入了解S5与先前并行化线性RNN(在第5节讨论)之间的差异,重点关注其区分特征:连续时间参数化和HiPPO初始化。我们比较了状态矩阵的不同初始化方法:随机高斯、随机反对称和HiPPO-N。此外,为了与更接近先前并行化线性RNN工作的设置进行比较,我们还考虑了S5的直接离散时间参数化,该参数化在训练期间不执行重复离散化,也不学习时间尺度参数$\Delta$。
结果分析。表6展示了这项消融研究的结果。主要结论是,唯一在所有任务上持续表现良好,并能解决Path-X任务的方法,是使用连续时间参数化和HiPPO初始化的S5方法。我们还注意到,离散时间/HiPPO-N矩阵配置由于稳定性问题难以训练,通常需要低得多的学习率。
表6:S5初始化和参数化消融研究。✗表示模型性能未超过随机猜测。
实验内容与分析。【22】提出了几种替代对角化HiPPO-N矩阵的对角矩阵,包括S4D-Inv和S4D-Lin矩阵。他们在LRA任务上进行了一项消融实验,简单地用S4D-Inv和S4D-Lin矩阵替换对角化的HiPPO-N矩阵,同时保持所有其他因素不变。我们在表7中包含了这些结果。在表7中,我们还包括了在S5中进行类似消融的结果,即用这些矩阵代替HiPPO-N矩阵来初始化S5,同时保持所有其他因素不变。两种矩阵在大多数任务上都表现良好,除了S4D-Lin矩阵在Path-X上的表现。有趣的是,其中一次运行达到了96.79%,但其他运行在此任务上并未超过随机猜测。未来对这些及其他矩阵的探索是未来工作的一个有趣方向。
表7:LRA基准任务的测试准确率(【60】)。✗表示模型性能未超过随机猜测。引用指的是原始模型。Transformer到Performer的结果来自【60】。我们遵循【20】; 【22】报告的程序,报告S4、S4D(如【20】; 【22】所报)和S5在三个种子上的平均值,括号内为标准差。
表8:Speech Commands分类任务(【66】)。35路关键词识别的测试准确率。训练样本为1秒16kHz采样的音频波形,或长度为16000的一维序列。最后一列表示在8kHz下的零样本测试,其中样本通过朴素抽取构建。报告了三个随机种子的平均值,括号内为标准差。基线InceptionNet到S4D-Lin的性能引自【22】。
实验内容。我们还评估了两个消融实验:
* S5-drop:使用相同的S5架构,但去除了对样本间间隔$\Delta_t$的依赖,即$\Delta_t=1.0$。我们预计这个网络表现会很差,因为它不知道观测之间经过了多长时间。
* S5-append:使用相同的S5架构,但将积分时间步长附加到三十维图像编码中,然后输入到密集的S5输入层。理论上,我们预计这个网络会和S5表现得一样好。然而,要做到这一点,需要S5网络学会处理时间,这可能很困难,尤其是在更复杂的领域。
结果分析。我们将这些消融实验的结果包含在表9的底部。注意,我们的S5实验会受益于JAX编译,但这不足以解释运行时间的差异。
表9:摆锤回归任务的测试MSE ×10⁻³和运行时。基线mTAND到CRU的性能引自【57】,为五个随机种子的均值和标准差(括号内为标准差)。我们重新运行了CRU(标记为CRU (our run))并在二十个随机种子上运行了我们的S5方法。我们报告了在保留测试集上的MSE误差的均值和方差,模型是使用验证集MSE选择的。
表10展示了像素级一维图像分类的结果和引用。
表10:像素级一维图像分类的测试准确率。引用指的是原始模型;附加引用表示该基线报告的出处。
并行扫描的组件。计算通用并行扫描需要定义两个对象:
* 扫描将操作的初始元素。
* 用于组合元素的二元结合运算符•。
线性递推的并行扫描定义。为了计算长度为L的线性递推$x_k = \bar{A}x_{k-1} + \bar{B}u_k$,我们将定义L个初始元素$c_{1:L}$,其中每个元素$c_k$是元组:
这些$c_{1:L}$将在扫描前预先计算。创建了扫描要操作的元素列表后,我们定义用于此线性递推的二元运算符•为:
其中$q_k$表示运算符的输入元素,可能是初始元素$c_k$或某个中间结果,*表示矩阵-矩阵乘法,@表示矩阵-向量乘法,+表示逐元素加法。
并行计算示例 (L=4)。考虑系统$x_k = \bar{A}x_{k-1} + \bar{B}u_k$和一个长度L=4的输入序列$u_{1:4}$。假设$x_0=0$,期望的潜状态为:
顺序扫描需要4个步骤。而并行扫描可以通过两步完成:
1. 第一步 (并行):计算$r_2 = c_1 \bullet c_2$和中间结果$q_4 = c_3 \bullet c_4$。
2. 第二步 (并行):计算$r_1 = r_0 \bullet c_1$, $r_3 = r_2 \bullet c_3$和$r_4 = r_2 \bullet q_4$ (其中$r_0=(I,0)$)。
这展示了并行扫描如何将顺序依赖性减少到对数级别,从而在长序列上实现显著加速。
结合性证明。最后,我们证明该二元运算符•是结合的:
左侧展开为:
右侧展开为:
两侧相等,证明了结合性。