经典长短时记忆网络(LSTM)架构最早可以追溯到20世纪90年代,
因其独特的常量误差传递(constant error carousel,CEC)和门控(gating)机制而在处理各种时序序列数据任务中展示出了卓越的性能
,尤其是在早期的大型语言模型(LLM)中发挥了关键作用。然而,随着Transformer架构的出现,其高度可并行化运行的自注意力机制使得模型可以拓展到更大规模的应用中,导致LSTM的地位逐渐被取代。
近日,
LSTM的原作者Sepp Hochreiter带队对LSTM框架进行了全新升级,重点针对LSTM缺乏并行处理能力以及在存储容量和灵活性上的缺陷进行了改进
,提出了一种称为xLSTM的全新架构。xLSTM提出了两种新的内存单元设计:
一种是使用标量内存和标量更新的sLSTM,它引入了新的记忆混合技术;另一种是mLSTM,它使用矩阵内存并能完全并行计算,采用协方差更新规则
。
作者通过实验证明,xLSTM与最先进的Transformer模型和状态空间模型(SSM)相比,显示出了优越的性能和良好的可扩展性。这表明,
通过对传统LSTM进行扩展和改进,xLSTM能够在大规模的语言模型中与当前的主流技术竞争
,甚至可能在某些情况下提供更优的性能和效率。
论文题目:
xLSTM: Extended Long Short-Term Memory
论文链接:
https://arxiv.org/abs/2405.04517
一、引言
Sepp Hochreiter与Jürgen Schmidhuber于1997年发表了题为《Long short-term memory》的论文,
其核心思想是常量误差传递和门控机制,解决了当时RNN中梯度消失的难题
,该文章在谷歌学术上的引用量达到了10w+。
多年来,LSTM在多个领域取得了重大成功,例如自然语言处理、语音识别、机器翻译等,但是其仍存在三个重大缺陷:
1. 无法修正已存储的决策(revise storage decisions)
,即模型在遇到更加相似的匹配向量时,无法修改之前存储的信息,作者通过最近邻搜索问题来展示这一现象,下图左侧展示了LSTM模型在该问题上的均方误差 (MSE),
在给定参考向量的情况下,模型需要顺序扫描序列以查找最相似的向量,当发现更加相似的向量时,LSTM无法对已存储的值进行修改
,导致MSE偏高。
2. 模型的记忆容量有限,只能存储标量级别的单元状态(必须先对信息进行压缩)
,作者通过低频token预测任务(Rare Token Prediction)来说明这一问题,如上图右侧所示,作者展示了LSTM模型在Wikitext-103数据集上的预测困惑度情况,由于记忆容量受限,LSTM的表现较差。
3. 由于记忆混合(memory mixing)设计而导致模型无法并行化计算
,具体表现为,当前时间步的hidden states必须借助于上一个时间步的结果进行顺序计算。
二、本文方法
为了解决上述限制,本文在LSTM架构的基础上提出了全新的xLSTM架构家族,整体框架如下图所示。作者首先设计了一种全新的新的内存混合方式,
称为sLSTM(Scalar LSTM),通过指数门控激活函数(exponential gates)和全新的归一化技术,来增强LSTM修正其存储决策的能力
。此外设计了一种
矩阵记忆和协方差更新规则
,具体体现在mLSTM(Matrix LSTM)模块中。
与Transformer等流行架构类似,
作者将xLSTM的整体架构设置为残差网络模式,即将sLSTM 和 mLSTM 集成到残差模块中,然后堆叠这些残差模块来形成整体网络
。
2.1 sLSTM
为了提高LSTM动态修正已存储决策的能力,作者在sLSTM中引入了指数门控激活函数,
传统LSTM中的门控函数使用Sigmoid函数,其值域为(0,1),当序列较长时,重复相乘会导致权重过度饱和
。为解决这一问题,可以将输入门
改为使用指数函数:
其中,
为输入门的前向值,由当前输入
、上一隐状态
及相应权重
计算得到,再通过指数函数得到实际输入门值
。
指数函数能产生任意大的正值,这允许LSTM可以修正先前的存储决策
。同时,遗忘门也可使用指数函数:
因此,sLSTM的前向传播过程可以表示如下,其中粉色框表示指数激活函数:
其中
是sigmoid函数,
是tanh激活函数,
为了防止指数激活导致数值上溢,作者专门引入了一种stabilizer state
来保证网络计算过程的稳定性
:
经过指数激活和稳定化处理后,sLSTM可以使用
和
替代
和
,
这样操作既不会改变网络输出和损失对参数的梯度,同时也能使网络对先前决策进行更新
。
2.2 mLSTM: 矩阵记忆和协方差更新规则
mLSTM的创新之处在于其将标量记忆单元
扩展为矩阵形式
来提高记忆存储容量
,并且利用了
协方差更新规则
存储键值对:
其中
为
的矩阵记忆单元,
和
分别为值向量和键向量,通过它们的外积计算可以实现新键值对的存储。
mLSTM的前向传播过程如上图所示,
其中第二行和第三行展示了对记忆单元的读取过程
,其中
为查询向量,通过与矩阵
的相乘得到输出
。
协方差更新规则最大化了二值向量的可分离性,使mLSTM获得了优秀的存储和检索能力
。与sLSTM不同,mLSTM内部没有单元间的记忆混合,因此具有完全并行性。
2.3 xLSTM架构
在得到sLSTM和mLSTM模块后,作者开始构建全新的xLSTM网络,
为了提高网络的非线性表达能力,作者将sLSTM和mLSTM集成到残差模块中
,构建了两种类型的xLSTM模块,如下图所示:
1. 残差sLSTM模块(Post Up-projection):
输入 --> sLSTM --> 门控MLP --> 残差连接 --> 输出
残差sLSTM模块如上图左侧所示,输入先经过sLSTM提取特征,
再通过一个门控的前馈网络提高表达能力,最后与输入相加构成残差连接,这一设计类似于Transformer
。
2. 残差mLSTM模块(Pre Up-projection):
输入 --> 升维MLP --> mLSTM --> 降维MLP --> 门控输出 --> 残差连接 --> 输出
如上图右侧所示,输入首先通过一个MLP映射到更高维度的空间,然后在该空间中使用mLSTM提取特征。接着通过另一个MLP将特征映射回原始维度,并使用门控单元对特征进行选择性传递。
最后将门控输出与输入相加构成残差连接。这种设计思路借鉴了状态空间模型(Mamba)
,目的是让mLSTM能够在高维空间中充分发挥其矩阵记忆单元的优势。
三、实验效果
本文的实验重点对xLSTM在语言建模方面的性能进行了评估,
首先在15B tokens的SlimPajama语料库上对xLSTM与多种模型(Transformer、状态空间模型、RNN等)进行了比较
于所有实验,均使用符号 xLSTM[a: b] 表示基于 mLSTM 与基于 sLSTM 的 xLSTM 块的 a/b 比率。例如,xLSTM[7:1]表示在 8 个块中,7 个是基于 mLSTM 的块,1 个是基于 sLSTM 的块。
实验结果如上表所示,
在验证集上,xLSTM取得了最佳的perplexity分数
,优于目前最先进的Transformer(GPT-3、Llama[1])、State Space模型(H3、Mamba[2])以及RNN模型(RWKV系列[3])。具体来说,xLSTM[1:0]获得了13.43的perplexity(困惑度),xLSTM[7:1]获得13.48,而其他模型的最佳分数在13.70-21.83之间。
此外,作者对
指数门控和矩阵记忆等新技术进行了消融研究
,结果如上表所示,结果清楚地表明,这两个新颖的设计都对xLSTM的优异性能做出了重要贡献。
为了评估xLSTM的扩展性能,
作者在上图中展示了本文方法在不同参数规模下的scaling行为,可以看到,xLSTM在几乎所有规模上都表现优异
,且随着模型越大,其与其他方法的性能差距也越大。为了更加彻底评估xLSTM在大规模语言建模任务上的潜力,
作者直接扩充了训练语料库的规模,在300B tokens的SlimPajama数据集上训练了xLSTM、RWKV-4、Llama和Mamba四种模型,模型大小从125M到1.3B参数不等
。
上图展示了在1.3B参数规模下,
多个模型在不同上下文长度下的perplexity表现,与其他模型相比,xLSTM能够在更长的上下文长度(最长16384)下保持较低的perplexity
,展现出了更加优异的长距离依赖建模能力。
四、总结
作者在总结中提到,xLSTM的提出回答了这样一个简单的问题:“当我们将经典的LSTM架构扩展到数十亿量级的参数时,我们在语言建模方面能走多远?”借助于提出的
指数门控激活、矩阵记忆和协方差更新规则等全新技术
,xLSTM已经实现了接近于目前流行的Transformer和状态空间模型等架构。此外,
作者通过缩放定律实验也证明了,xLSTM完全具有进一步拓展的能力
,并且有可能对其他深度学习领域产生重大影响,例如强化学习、时间序列预测或物理系统建模。
参考资料
[1] H. Touvron, T. Lavril, G. Izacard, X. Martinet, M.-A. Lachaux, T. Lacroix, B. Rozière, N. Goyal, E. Hambro, F. Azhar, A. Rodriguez, A. Joulin, E. Grave, and G. Lample. Llama: Open and efficient foundation language models. ArXiv, 2302.1397, 2023.
[2] A. Gu and T. Dao. Mamba: Linear-time sequence modeling with selective state spaces. ArXiv, 2312.00752, 2023.
[3] B. Peng, E. Alcaide, Q. Anthony, et al. RWKV: Reinventing RNNs for the transformer era. ArXiv, 2305.13048, 2023.