专栏名称: 新智元
智能+中国主平台,致力于推动中国从互联网+迈向智能+新纪元。重点关注人工智能、机器人等前沿领域发展,关注人机融合、人工智能和机器人革命对人类社会与文明进化的影响,领航中国新智能时代。
目录
相关文章推荐
爱可可-爱生活  ·  【[52星]N8loom:基于树结构的前缀缓 ... ·  2 天前  
爱可可-爱生活  ·  【Stanford CS236 Deep ... ·  2 天前  
黄建同学  ·  学习-20250205192620 ·  2 天前  
机器之心  ·  AI「视觉图灵」时代来了!字节OmniHum ... ·  2 天前  
量子位  ·  DeepSeek华为火线联手!硅基流动首发即 ... ·  5 天前  
51好读  ›  专栏  ›  新智元

RNN回归!Bengio新作大道至简与Transformer一较高下

新智元  · 公众号  · AI  · 2024-10-25 13:03

正文



新智元报道

编辑:alan
【新智元导读】 近日,深度学习三巨头之一的Yoshua Bengio,带领团队推出了全新的RNN架构,以大道至简的思想与Transformer一较高下。

在Transformer统治的AI时代之下,

散落在世界各地的「RNN神教」信徒,一直相信并期待着RNN回归的那天:

毕竟,凭借强大的顺序和上下文感知能力,RNN曾在各种任务中表现惊艳。

直到后来遭遇了反向训练的瓶颈,因Scaling Law而跌落神坛。

然而,人们并没有忘记RNN。

RWKV、Mamba、xLSTM等RNN衍生模型接连出现,欲挑战Transformer之霸主地位。

就在近日,又有重量级人物下场——

深度学习三巨头之一的Yoshua Bengio,带领团队推出了全新的RNN架构,以大道至简的思想与Transformer一较高下。

论文地址:https://arxiv.org/pdf/2410.01201v1

研究人员对传统的两种RNN架构LSTM和GRU,进行了大刀阔斧的改造,从中诞生了两个新模型:minLSTM和minGRU。

这俩极简主义的版本到底怎么样?咱们先看疗效。

首先是RNN最大的问题:训练速度。

上图展示了几种模型在T4 GPU上训练花费的时间,以及新模型带来的加速比。横轴为输入数据的序列长度,批量大小为64。

可以看到,相比于原版的LSTM和GRU,minLSTM、minGRU和Mamba的运行时间不会随序列长度而增加(后3个模型的线在左图中重叠了)。

当序列长度为4096时,新架构相对于传统版本达到了1300多倍的加速比!

相当于原版GRU需要3年才能做完的事情,minGRU一天就搞定了。

那么对线Transformer的战绩如何?

在本文测试的语言建模任务中,minGRU和minLSTM分别在600步左右达到最佳性能点。

相比之下,Transformer需要比minGRU多花大概2000步,训练速度慢了约2.5倍。

对此,YC上的网友表示:「我非常喜欢这个新架构的简单性」。

毕竟,俗话说的好,「最好的PR是那些删除代码的PR」。

模型架构

下面来感受一下极简模型的诞生过程。

首先,这是传统的RNN架构:

LSTM在RNN的每个cell中加入了比较复杂的门控:

三个门控(input gate、output gate、forget gate)和输入的分量,都通过线性投影和非线性激活函数来得出,并且依赖于上一个时刻的隐藏状态ht-1。

这些值再经过线性和非线性计算,得到本时刻的输出ct和隐藏状态ht。

GRU在LSTM的基础上做了一些简化:

少了显式计算ct,用于门控的项也缩减到2个,相应的参数量和计算量也减少了。

那么我们就从相对简单的GRU入手,开始改造。

改造的目的是使RNN能够应用并行扫描(Parallel Scan)算法,解决自身训练困难的问题。

简单来说,就是将网络中的计算改造成vt = at ⊙ vt−1 + bt的形式。

minGRU

第一步,公式中含有对之前隐藏状态ht-1的依赖,没办法用并行扫描,所以把ht-1直接删掉。

ht-1没了,负责调控ht-1的rt也没用了,删掉。

第二步,双曲正切函数(tanh)负责限制隐藏状态的范围,并减轻因sigmoid(σ)而导致的梯度消失。

但是现在ht-1和rt都没了,tanh也失去了存在的意义,删掉。

那么最终,minGRU就是下面这三个公式:

相比于原版,参数量和计算量再次减少,最重要的是能够使用并行扫描来显著加快训练速度。

minLSTM

经过上面的叙述,minLSTM的由来就很好理解了。

首先还是去除隐藏状态的依赖:

接着是拿掉相关的tanh:

最后,为了保证LSTM输出的尺度与时间无关,以及hidden state在缩放上与时间无关,还需要删掉output gate。

output gate没了,ct也就没必要单独存在了,删掉;剩下的两个门控通过归一化来调配hidden state进入的比例。

——emmm......好像变成GRU了,算了不管了。

最终改造好的minLSTM是下面这个样子:

Were RNNs All We Needed?

全新的RNN搞出来了,能打Transformer吗?

别急,先打内战证明价值。

除了传统的RNN(LSTM和GRU),这里特别关注与Mamba的比较。

首先是训练上的提升:

实验在批次大小64的情况下改变序列长度,测量了模型执行前向传递、计算损失和向后传递计算梯度的总运行时间以及内存占用。

在运行时间方面,minLSTM、minGRU与Mamba实现了类似的效率。

序列长度为512时的运行时间(超过100次的平均值),分别为 2.97、2.72和2.71毫秒;序列长度为4096时,运行时间分别为3.41、3.25和3.15。

相比之下,LSTM和GRU的运行时间随序列长度线性增加。所以序列长度为512时,minGRU和minLSTM的训练加速了175倍和235倍;序列长度为4096时,加速比达到了1324和1361。

内存方面,利用并行扫描算法时会创建更大的计算图,所以minGRU、minLSTM和Mamba ,比传统RNN需要更多的内存(大概多出88%)。

——但这并不重要,因为对于RNN来说,训练时间才是瓶颈。

去除隐藏状态的效果

minLSTM和minGRU的训练效率是通过降低它们的门控对先前隐藏状态的依赖来实现的。

尽管单层minLSTM或minGRU的门控只与输入有关,而与时间无关,但是在深度学习中,模型是通过堆叠模块来构建的。

从第二层开始,minLSTM和minGRU的门也将与时间相关,从而对更复杂的函数进行建模。







请到「今天看啥」查看全文