专栏名称: 3DCV
关注工业3D视觉、SLAM、自动驾驶技术,更专注3D视觉产业的信息传播和产品价值的创造,深度聚焦于3D视觉传感器、SLAM产品,使行业产品快速连接消费者。
目录
相关文章推荐
湖北药监  ·  我国海洋经济总量首超十万亿元 ·  3 小时前  
宁夏药安早知道  ·  我国海洋经济总量首超十万亿元 ·  4 小时前  
甘肃省发改委  ·  我国海洋经济总量首超十万亿元 ·  6 小时前  
甘肃政务  ·  我国海洋经济总量首超十万亿元 ·  7 小时前  
地刊速览  ·  EPSL:古太平洋的缺氧事件 ·  昨天  
地刊速览  ·  EPSL:古太平洋的缺氧事件 ·  昨天  
51好读  ›  专栏  ›  3DCV

力压Transformer:全新架构Mamba详解

3DCV  · 公众号  ·  · 2024-05-13 18:21

正文

点击下方 卡片 ,关注 「3DCV」 公众号
选择 星标 ,干货第一时间送达

作者:知乎@绝密伏击|转自AI生成未来|编辑:3DCV

链接: https://zhuanlan.zhihu.com/p/684231320

背景

屹立不倒的 Transformer 迎来了一个强劲竞争者。

自 2017 年被提出以来,Transformer 已经成为 AI 大模型的主流架构,但随着模型规模的扩展和需要处理的序列不断变长,Transformer 的局限性也逐渐凸显。一个很明显的缺陷是:Transformer 模型中自注意力机制的计算量会随着上下文长度的增加呈平方级增长,比如上下文增加 32 倍时,计算量可能会增长 1000 倍,计算效率非常低。

为了克服这些缺陷,研究者们开发出了很多注意力机制的高效变体,但这往往以牺牲其有效性特为代价。到目前为止,这些变体都还没有被证明能在不同领域发挥有效作用。

而就在最近,一名为 Mamba 的架构似乎打破了这一局面。

与类似规模的 Transformer 相比, Mamba 具有 5 倍的吞吐量,而且 Mamba-3B 的效果与两倍于其规模的 Transformer 相当 。性能高、效果好,Mamba 成为新的研究热点。

图1 Mamba 在推理过程中的吞吐量对比

本文将详细的解读 Mamba 架构,由于 Mamba 是基于 SSM->HiPPO->S4->Mamba 演化过来的,而 HiPPO、S4、Mamba 的一作者都是卡内基梅隆大学机器学习系助理教授 Albert Gu。因此,本文将从标准 SSM 开始,逐步介绍 HiPPO、S4、Mamba。

图2总结了SSM、HiPPO、S4、Mamba的主要区别,以及各个模型的主要内容。本文内容也将按图中内容展开。

图2-2:HiPPO、S4、Mamba

一、现有架构问题

序列建模的核心问题是:同时解决 有效 高效 。有效是指能够选择性记忆历史信息,解决 长距离依赖 (Long-Range Dependencies,LRDs)问题;高效是指计算高效。

尽管传统的模型如循环神经网络(RNNs)、卷积神经网络(CNNs)和 Transformers 在处理长距离依赖方面有专门的变体,但它们在 处理超过 10000 步的极长序列时仍然面临挑战

1.1 Transformer 问题

Transformer 的一个主要优点是,无论它接收到多长的输入,它都使用序列中的所有 token 信息(无论序列有多长)来对输入数据进行处理。

图1-1:Transformer会查看过去所有 token

但是为了获得全局信息,注意力机制在长序列上非常耗费显存。注意力创建一个矩阵,将每个 token 与之前的每个 token 进行比较。矩阵中的权重由 token 对之间的相关性决定。

图1-2:Transformer 会计算每个 token 之间的 Attention

在训练过程中,Attention 计算可以并行化,所以可以极大地加快训练速度。但是在推理过程中,当生成下一个 token 时,我们需要重新计算整个序列的注意力。

图1-3:生成新 token 时需要重新计算整个序列的注意力

长度为 L 的序列生成 token 大约需要 L² 的计算量,如果序列长度增加,计算量会平方级增长。因此, 需要重新计算整个序列是 Transformer 体系结构的主要瓶颈

图1-4:Transformer 训练快、推理慢

1.2 RNN 的问题

图1-5:循环神经网络 RNN

在生成输出时,RNN 只需要考虑之前的隐藏状态和当前的输入。这样不会重新计算以前的隐藏状态,这正Transformer 不具备的。

这种结构可以让 RNN 进行 快速推理 ,并且理论上可以无限扩展上下文长度,因为每次推理只取一个隐藏状态和当前输入,内存占用非常稳定。

RNN 的每个隐藏状态都是之前所有隐藏状态的聚合。但是这里会有一个问题,在生成 token "Liang" 时,最后一个隐藏状态不再包含关于 token "Hello" 的信息。这会导致随着时间的推移,RNN 会忘记更久的信息,因为它只考虑前一个状态。

图1-6:只考虑前一个 hidden state

并且 RNN 的这种顺序性产生了另一个问题。训练不能并行进行,因为它需要按顺序完成每一步。

图1-7:RNN 训练不能并行

RNN的统一定义为:

其中 是每一步的输出,它由当前输入 和前一时刻输出 共同决定,而θ则是可训练参数。那么参数θ的梯度可以表示为:

可以看到,当前梯度依赖上个 token 的梯度。

与 Transformer 相比,RNN 的问题完全相反!它的 推理速度非常快,但不能并行化导致训练很慢

图1-8:RNN 和 Transformer对比

人们一直在寻找一种既能像 Transformer 那样并行化训练,能够记住先前的信息,又能在推理时时间是随序列长度线性增长的模型,Mamba 就是这样应运而生的 。解下来我们从 SSM 开始,逐步介绍 Mamaba。

二、状态空间模型 SSM

2.1 什么是 SSM

状态空间模型(State Space Models,SSM)由简单的方程(3)定义。它将一维输入信号 映射到N维潜在状态 ,然后再投影到一维输出信号

其中, 是状态转移矩阵, 是输入到状态的矩阵, 是状态到输出的矩阵,D是直接从输入到输出的参数(很多时候取 D = 0)。

2.2 SSM 架构

下图是 SSM 的架构,主要包含两个部分:状态更新方程和输出方程。

图2-1:SSM结构

SSM 可以简化为以下结构:

图2-2:简化的SSM结构

下面我们看一下更详细的结构,首先是状态更新,如下所示:

图2-3:状态更新详细结构

备注 :图中的输入 ,表示输入的信号是 D 维的。SSM 也可以用于处理多维信号输入。

然后是输出方程,详细机构如下所示:

图2-4:输出方程详细结构

2.3 SSM 例子:弹簧振子

下面举一个描述弹簧振子系统的 SSM 例子。

图2-5:弹簧振子

考虑一个质量为m的物体,它连接在一个劲度系数为k的弹簧上,并且受到阻尼系数为c的阻尼力作用。当物体从平衡位置偏离时,它会在弹簧力的作用下进行振动。我们可以用状态空间模型来描述这个系统的动态。

状态变量可以选择为物体的位移s(t)和速度v(t)。输入u(t)在这个例子中可以为零,因为我们没有外部力作用在物体上。输出y(t)可以是我们感兴趣的位移s(t)。

状态向量定义为:

输入向量为:

输出位移s(t)。弹簧振子的状态空间方程可以表示为:

在了解 SSM 基本概念之后,接下来我们介绍基于 SSM 的 HiPPO 架构。

三、HiPPO(High-order Polynomial Projection Operators)

HiPPO 是 Albert Gu 于2020年在论文 HiPPO: Recurrent Memory with Optimal Polynomial Projections 中提出的新架构。HiPPO 主要为了解决 如何在有限的存储空间中有效地解决序列建模的长距离依赖问题

HiPPO 通过函数逼近产生状态矩阵 A 的最优解,有效的解决了长距离依赖问题。

问题背景 :在处理序列数据时,一个核心问题是如何在增量方式下表示累积的历史信息。这涉及到如何在有限的存储空间中有效地更新和维护历史数据的表示。

HiPPO框架 :作者介绍了一个名为 HiPPO(High-order Polynomial Projection Operators)的通用框架,它通过将连续信号和离散时间序列投影到多项式基上,实现了在线数据压缩。

重要性度量 :HiPPO 框架考虑了一个度量,用于指定过去每个时间步的重要性。这个度量帮助HiPPO产生在线函数逼近问题的最优解。

理论贡献 :HiPPO 框架不仅提供了对现有记忆单元的简短推导,还推广了循环神经网络(如GRUs)中普遍存在的门控机制。

新的记忆更新机制 :作者提出了一个新的记忆更新机制(HiPPO-LegS),它能够随时间扩展以记住所有历史信息,避免了对时间尺度的先验假设。

理论优势 :HiPPO-LegS 具有时间尺度鲁棒性、快速更新和有界梯度的理论优势。

实验结果 :在基准测试中,HiPPO-LegS 在打乱的 MNIST 数据集上达到了98.3%的新最佳准确率。在一个新的轨迹分类任务中,HiPPO-LegS 在处理分布外时间尺度和缺失数据方面,比其他 RNN 和神经 ODE(一阶常微分方程)基线模型的性能提高了25-40%的准确率。

下面介绍 HiPPO 实现的具体细节。

3.1 HiPPO 架构:高阶多项式投影

3.1.1 HiPPO问题设置

问题定义

给定一个在时间t≥0上的输入函数u(t)∈R,需要在每个时间点操作累计历史 ,以便理解到目前为止看到的输入并对未来进行预测。

由于函数空间的庞大,无法完美记住整个历史,因此需要将其进行压缩,HiPPO 提出了将历史投影到有界维数的子空间的一半方法。

函数逼近与度量

为了评估逼近的质量,需要在函数空间中定义一个距离。任何在 上的概率度量μ可以为平方可积函数空间提供内积 ,从而诱导出一个希尔伯特空间 和相应的范数

为了选择合适的子空间,需要一个度量来量化历史的重要性。这个度量μ(t)随时间变化,支持在 上,因为 只在时间t之前定义。

多项式基展开

任何 N 维的函数子空间 G 都是逼近的合适候选。参数 N 对应于逼近的阶数,或者说压缩的大小;投影的历史可以通过G的任何基的N个系数来表示。

论文中使用多项式作为自然基,因此G是小于N阶的多项式的集合。

在线逼近

由于我们关心在每个时间 t 对 的逼近,我们也让度量 μ(t) 随时间变化。总体上,我们寻找一个 ,使得 最小。直观上,度量 μ 控制输入域各部分的重要性。

挑战

挑战在于如何在给定度量  μ(t) 的情况下以封闭形式解决优化问题,以及在 t -> ∞ 时如何线性地维护这些系数。

3.1.2 HiPPO 通用架构

通过连续动态系统计算投影

这部分是 HiPPO 的关键步骤,它涉及到将输入函数 u(t) 在时间 t 投影到一个多项式空间上,以便在线更新记忆表示。

投影的表示 :投影可以通过输入函数  u(t) 在时间 t 的限制 的 N 个系数来表示。这些系数是通过在多项式空间的基上展开 得到的。

正交多项式基 :为了选择合适的基,作者利用了正交多项式的性质。正交多项式为 u(t) 提供了一个自然的基,使得 的投影可以表示为这些基的线性组合。

系数的计算 :投影的系数x(t)是通过内积 计算得到的,其中 是正交多项式基的元素。

连续动态系统 :为了在线更新这些系数,作者提出了一个连续动态系统,这个系统描述了系数 x(t) 是如何随时间 t 变化的。这种动态系统可以表示为 ,其中 A(t), B(t)是依赖于时间的矩阵。

投影操作符 :作者定义了一个投影操作符 ,它将 映射到 G(多项式空间)中的 ,使得 最小化。这个操作符是 HiPPO 框架的核心。

系数提取操作符 :除了投影操作符,作者还定义了一个系数提取操作符 ,它将多项式g(t) 映射到其对应的系数 x(t)。

在线更新 :通过这个连续动态系统,HiPPO 框架能够在线更新记忆表示,即随着新数据的到来,系统能实时地调整系数x(t)。

在线函数逼近

图3-1:HiPPO框架

图2-6展示了 HiPPO 框架,首先需要找到投影 ,将输入u(t) 投影到多项式空间;然后将投影通过一组系数 来表示,这些系数捕捉了函数u(t) 的历史信息;使用连续时间下的一阶常微分方程来表示系数 如何随时间 t 动态变化;最后,将连续时间的动态变换转化为离散时间的递归关系(比如双线性变换),这允许 HiPPO 在每个时间步 k 更新系数

3.1.3 高阶投影:度量方法以及 HiPPO 动态系统

作者定义了两种度量方法,分别是 LegT 和 LagT。LegT 度量为最近的历史信息分配均匀的权重,表示如下:

LagT 度量使用指数衰减的方式来衡量历史信息的重要性,表示如下:

对于 LegT 和 LagT,系数x(t)可以使用 ODE(一阶常微分方程)来表示:

其中 A 和 B 是与度量 μ(t) 相关的矩阵。这个 ODE 描述了系数x(t)如何随时间 t 和输入函数 u(t) 变化。

备注:公式(9)是 HiPPO 框架的关键部分, 具体推导可以参看论文中的附录 D

对于 LegT 度量,矩阵 A 和 矩阵 B 可以表示如下:

对于 LagT 度量,可以表示如下:

3.1.4 HiPPO 框架中的连续时间动态转换为离散时间递归关系

由于我们处理的输入往往是离散的,因此我们需要将公式(9)的 ODE 离散化。ODE 离散化是一种常用的数据技术,它将连续时间的常微分方程转换为离散时间的差分方程。这通常涉及到选择一个合适的时间步长(或步长Δt),并使用数值方法(如欧拉方法、双线性)来近似连续微分。

图3-2:连续信号离散化

使用双线性离散化,如下所示:

结合公式(9)和公式(11),我们可以得到离散化的状态更新公式,表示如下:

离散化之后的 SSM 结构可以表示如下:

图3-3:离散化 SSM

在每个时间步长,我们计算当前输入( )如何影响前一个状态( ),然后计算预测输出( )。

图3-4:每个时间步的计算

这种表示看起来是不是有点熟悉?其实他的处理方法和RNN一样。

图3-5:离散化后和RNN类似

3.2 HiPPO-LegS

HiPPO-LegS 是作者基于新的度量提出的全新架构,具有时间鲁棒性、有界梯度、有界近似误差、长时间记忆等效果。

新的度量表示为

在新的度量下,矩阵 A 和矩阵 B 可以表示如下:

具体推导在论文的附录 D.3 部分。

更好的学习长期依赖

HiPPO-LegS 是专门为记忆而设计的,它通过其独特的结构和更新机制来避免梯度消失问题。LegS 通过使用Legendre 多项式作为基函数,并结合时间尺度不变的度量,来保持梯度的稳定性。

对于任何时间 ,HiPPO-LegS 在时间 的输出相对于时间 的输入的梯度范数为 ,这意味着梯度随着时间的增加而减小,但是衰减的速度比 RNN 的指数级慢的多。

这个性质使得 HiPPO-LegS 能够有效地缓解 RNN 中的梯度消失问题。即使在长序列中,梯度也不会迅速衰减到0,这有助于网络在训练中更好地学习长期依赖。

近似有界误差

HiPPO-LegS 在时间 t 的近似误差

其中 N 是多项式的最高阶。这表明随着多项式的阶 N 的增加,误差逐渐减小。

3.3 实验

将 HiPPO 和 RNN 相结合,当前状态 不仅和上一个状态 有关,还和 HiPPO 状态 有关,如下所示:

模型结构如下:

图3-6:HiPPO和RNN结合

下面是pMINIST 数据集上的结果,可以看到 LegS 的效果要好于 LagT 和 LegT,同时 HiPPO 的效果好于之前的其它模型。

图3-7:HiPPO实验结果

备注: pMNIST(permuted MNIST)是一个经过修改的MNIST数据集 ,它用于测试和评估机器学习模型在处理序列数据和学习长期依赖关系方面的能力。在 pMNIST 中,原始 MNIST 图像的像素被重新排列。这意味着图像的像素不再是按照自然顺序(从左到右,从上到下)呈现,而是按照一个固定的、随机的排列顺序。这种排列方式使得模型必须学习像素之间的长期依赖关系,而不能简单地依赖于局部空间结构。

四、S4 (Structured State Space Model)

S4 是 HiPPO 的后续工作,论文名称为:Efficiently Modeling Long Sequences with Structured State Spaces。

S4 的主要工作是将 HiPPO 中的矩阵 A(称为 HiPPO 矩阵 )转换为正规矩阵(正规矩阵可以分解为对角矩阵)和低秩矩阵的和,以此提高计算效率。

S4 通过这种分解,将计算复杂度降低到了O(N + L),其中 N 是 HiPPO 矩阵的维度,L 是序列长度。

在处理长度为 16000 的序列的语音分类任务中,S4 模型将专门设计的语音卷积神经网络(Speech CNNs)的测试错误率降低了一半,达到了1.7%。相比之下,所有的循环神经网络(RNN)和 Transformer 基线模型都无法学习,错误率均在70%以上。

下面我们就来介绍一下这篇工作。

4.1 HiPPO 解决了长期依赖

作者讨论了如何处理长距离依赖(Long-Range Dependencies,LRDs)的问题,LRDs 是序列建模中的一个关键挑战,因为它们涉及到在序列中跨越大量时间步的依赖关系。

作者指出,基本的 SSM 在实际应用中表现不佳,特别是在处理 LRDs 时。这是因为线性一阶常微分方程(ODEs)的解通常是指数函数,这可能导致梯度在序列长度上呈指数级增长,从而引发梯度消失或爆炸的问题。

为了解决这个问题,作者利用了 HiPPO 理论。 HiPPO 理论指定了一类特殊的矩阵 A,当这些矩阵被纳入 SSM 的方程中时,可以使状态 x(t) 能够记住输入 u(t) 的历史信息。这些特殊矩阵被称为 HiPPO 矩阵,它们具有特定的数学形式,可以有效地捕捉长期依赖关系。

HiPPO 矩阵的一个关键特性是它们允许 SSM 在数学和实证上捕捉 LRDs 。例如,通过将随机矩阵 A 替换为 HiPPO 矩阵,可以在序列 MNIST 基准测试上显著提高 SSM 的性能。

HiPPO 矩阵表示如下:

4.2 在线推理:使用递归形式

S4 在推理时,使用公式(12)的递归形式,每次只需要和上一个状态进行计算,具有和 RNN 相似的推理效率。

4.3 训练 S4:卷积表示

由于离散时间 SSM 的递归性质,它在硬件上进行训练时存在效率问题。因此,作者将离散时间 SSM 的 递归方程转换为离散卷积的形式 。通过展开递归方程,可以得到一个卷积核,这个卷积核可以用来在序列数据上应用卷积操作。这种转换允许 SSM 利用快速傅里叶变换(FFT)等高效的卷积计算方法,从而在训练过程中提高计算效率。

上面式子可以转化为卷积的形式:

其中, 是一个与 SSM 的参数(A, B, C)相关的卷积核,可以通过离散傅里叶变换(DFT)和逆变换(IDFT)来计算。这种卷积表示不仅在理论上是可行的,而且在实践中也是非常有效的,因为它允许在保持模型性能的同时,显著减少训练过程中的计算和内存需求。

作者在这一节中还讨论了如何计算 SSMn卷积核,这是他们技术贡献的关键部分。通过这种卷积表示,SSM 可以被有效地训练,同时保持其在处理长距离依赖(LRDs)方面的能力。这种表示形式为 SSM 在各种序列建模任务中的应用提供了灵活性,包括图像处理、语音识别和时间序列分析等。

图4-1:SSM 卷积核形式

下面是一个具体的例子,如何使用卷积核生成输出。

图4-2:使用卷积核生成输出

卷积的一个主要好处是它可以并行训练。但是由于核大小是固定,它们的推理不如 RNN 快速并且对序列长度有限制。

图4-3:递归 SSM 和 卷积 SSM 的对比

这里可以使用一个简单的技巧,即根据任务选择表示。在训练过程中使用可以并行化的卷积表示,在推理过程中,我们使用高效的循环表示。

图4-4:递归推理、卷积训练

4.4 为什么对角化可以减少 SSM 计算复杂度

为了进一步提升计算效率,作者讨论了对角化在计算离散时间状态空间模型(SSM)中的应用,以及为什么直接应用对角化方法在实践中并不可行。

对角化是一种线性代数技术,它可以将一个矩阵转换为对角形式,从而简化矩阵的乘法和其他运算。在 SSM 的上下文中,对角化可以显著减少计算复杂度,因为对角矩阵的幂运算(如在递归方程中出现的)可以通过简单的元素指数运算来完成。

下面我们解释下,为什么对角化可以减少 SSM 计算复杂度。

首先,我们引入论文中的定理 3.1

(Lemma 3.1):共轭是 SSM 中的等价关系,即:

也就是将矩阵 A,B,C 变为 , CV,最后得到的输出y保持不变。那么如果矩阵 是对角矩阵,则输出y的计算复杂度将从O( )变成O(N)。只要 Lemma 3.1 成立,我们就能使用对角化技术,降低计算复杂度。下面我们看一下 Lemma 3.1 的证明。

证明可以从 SSM 的两种表达形式出发。首先,有两个 SSM,其状态分别用 x 表示。

第一个 SSM:

第二个 SSM:

通过 V 将第二 SSM 乘以 V 后,变成如下形式:

可以看到当 ,两个 SSM 变得一样, 即 。因此,这两个 SSM 计算的相同的操作符 u -> y,只是通过 V 对状态 x 进行了变换。

通过共轭将 A 转换为其它形式,理想情况下这种形式结构更清晰,并且允许更快的计算。例如,如果 是对角矩阵,那么所需的计算将变得容易得多。

备注 :Lemma 3.1 是非常重要的结论,这意味着只要矩阵 E 和矩阵 A 相似(即满足 ),那么就可以使用 Lemma 3.1 中的方法替换矩阵 A,B,C,替换后的输出 y 保持不变。

4.5 直接对角化 HiPPO 矩阵导致数值溢出

上面提到,如果 A 能对角化,那么 SSM 递归计算的复杂度将从 变成 O(N)。

然而,作者指出,直接对角化 HiPPO 矩阵(用于处理长距离依赖的特殊矩阵)会导致数值问题。这是因为 HiPPO 矩阵的对角化涉及到的矩阵元素在状态大小 N 增大时会呈指数级增长,这使得对角化在数值上变得不稳定和不可行。

Lemma 3.2 将 HiPPO 矩阵 A 直接对角化为矩阵 ,那么 ,因此直接对角化会导致数值溢出。

下面证明下这个结论。

首先我们可以找到矩阵 A 的一个相似矩阵,表示如下:

其中:

那么可以找到一个可逆矩阵:

使得 是对角矩阵。比如考虑 N = 3,那么:

即然无法直接对 A 矩阵进行对角化,那么是否可以将其转化为低秩矩阵或者其它可以对角化的矩阵?解下来我们介绍下如何将其转换为正规矩阵+低秩矩阵。

4.6 S4 参数化:正规矩阵+低秩矩阵

虽然矩阵 A 不能直接对角化,但是可以表示为正规矩阵+低秩矩阵。

Theorem 1 :HiPPO 矩阵 A 可以表示为正规矩阵+低秩矩阵的形式,即:

其中 是对角矩阵, P, Q 是低秩矩阵。

下面简单证明下这个定理。

已知 HiPPO 矩阵 A 可以表示为:

那么 表示为:

虽然这个矩阵不是 反对称矩阵 (反对称矩阵可以对角化),但是可以表示为 ,其中 S 是反对称矩阵,矩阵 A 可以重新表示为:

由于 S 可以对角化,因此 也可以对角化,因此:

其中 是低秩矩阵。

备注 :之前整个矩阵 A 加上了 ,所以最后需要减去,而这一块更好是 。因为 第 n 行第 k 列为

这样我们就将矩阵 A 转换为了正规矩阵+低秩矩阵的形式。下面我们看一下转换之后的递归计算和卷积计算的复杂度。

4.7 S4 的计算复杂度

经过正规矩阵+低秩矩阵分解后,我们再来考虑 S4 的计算复杂度有什么变化。我们同时考虑推理时递归计算的复杂度以及训练时卷积计算的复杂度。

先给出结论:

  • S4 的递归计算复杂度为 。其中 MVM 表示矩阵向量乘法(Matrix-Vector Multiplications)。
  • S4 的卷积复杂度从 降低到 Cauchy 矩阵-向量乘法,空间复杂度为O(N + L)。

可以看到,递归计算的复杂度没有变化,而卷积的复杂度从 降低到 Cauchy 矩阵-向量乘法。

这里的 Cauchy 矩阵-向量乘法复杂度表示如下:

如果Cauchy 矩阵-向量乘法按照精确计算,那么 S4 的卷积复杂度为 ,也是小于之前的复杂度

解下来我们详细介绍 S4 计算复杂度的分析过程,首先介绍递归计算复杂度。

递归计算复杂度

公式(26)表示矩阵A可以分解为 ,结合前面的 Lemma 3.1,那么矩阵A可以变换为 ,可以表示为 的形式。

由于:

先考虑

然后计算

其中 也是对角矩阵,因此上式中的 相当于对低秩矩阵乘法 做了缩放,计算复杂度仍然为O(N)。

现在,我们可以将公式(30)重新表示为下面的形式:

公式(12)的 SSM 可以重新表示为:

可以看到矩阵 和矩阵 中的乘法运算都是矩阵-向量乘法(Matrix-Vector Multiplications,MVM)。因为它们都是对角矩阵+低秩矩阵,因此计算复杂度为O(N)MVM。

卷积计算复杂度

这一块就不再具体介绍了,感兴趣的可以直接去看原论文,在论文的附录 C.3 有详细的分析过程。

结论就是卷积的复杂度为: Cauchy 矩阵-向量乘法。

最后作者对比了 S4 和原始卷积、递归、Attention 之间的计算复杂度,可以看到 S4 是最低的,如下图所示:

图4-5:计算复杂度对比

图中 L 表示序列长度,B 表示 batch size,H 表示隐藏维度。







请到「今天看啥」查看全文


推荐文章
湖北药监  ·  我国海洋经济总量首超十万亿元
3 小时前
宁夏药安早知道  ·  我国海洋经济总量首超十万亿元
4 小时前
甘肃省发改委  ·  我国海洋经济总量首超十万亿元
6 小时前
甘肃政务  ·  我国海洋经济总量首超十万亿元
7 小时前
地刊速览  ·  EPSL:古太平洋的缺氧事件
昨天
地刊速览  ·  EPSL:古太平洋的缺氧事件
昨天
程序员大咖  ·  前沿开发团队的面试过程
8 年前
十点读书会  ·  穿越到民国过大年,是怎样一种体验?
8 年前
东方时代环球时事解读  ·  东方时事解读音频文字 : 152:何时半渡而击?
8 年前