专栏名称: 深度学习自然语言处理
一个从大三就接触NLP的小小NLPer,本公众号每天记录自己的一点一滴,每篇文章最后也有托福单词等新知识,学技术同时,也一点一滴积累额外的知识。期待与你在知识的殿堂与你相遇!
目录
相关文章推荐
中核集团  ·  新春走基层 | 媒体镜头下的中核坚守者 ·  昨天  
中核集团  ·  别再做“法外之徒”了! ·  2 天前  
中核集团  ·  中核集团各系统召开2025年度工作会议② ·  3 天前  
中核集团  ·  中核集团各系统召开2024年度工作会议① ·  4 天前  
51好读  ›  专栏  ›  深度学习自然语言处理

可视化剖析与代码实践,带你一文掌握Mamba和SSM

深度学习自然语言处理  · 公众号  ·  · 2024-10-16 20:27

正文

整理:参天地化育

EMNLP2024分享会要开始啦!6大主题、2多主题,快来预约不错过

和其他具有长文本处理能力的组件一样,Mamba作为一种有可能超越Transformer的深度学习组件,最近一段时间以来,在很多应用中也确实获得了不错的表现。关于Mamba的介绍和报道层出不穷,而且受到了领域内的广泛关注,Mamba论文是23年12月在预印本上online,目前(24年4月)已经有190次引用!关于Mamba的科普帖子不少,我们打算用2篇公众号,向读者宣传一下Maarten Grootendorst大佬的这个技术科普博文A Visual Guide to Mamba and State Space Models。作者的文笔清晰,可视化做的非常棒,本来想翻译的,但是我想,AI翻译现在已经能较好的辅助阅读技术类文档和论文了,而且原文也很清晰容易理解,所以就不剥夺大家细读的乐趣了。本文是对这个很棒的博文的宣传和解读,稍加了一些个人思考。一切权力都归原作者所有。

这个博客引言部分大致介绍这篇博客的目的

In this post, I will introduce the field of State Space Models in the context of language modeling and explore concepts one by one to develop an intuition about the field. Then, we will cover how Mamba might challenge the Transformers architecture.

该博客分为如下几部分:Part I: The Problem with Transformers Part II: The State Space Model (SSM) Part III: Mamba — A Selective SSM

我们也按照作者的思路,在第一篇公众号的内容中介绍Transformer的缺陷,State Space Model的解决方案,下一篇公众号介绍Mamba和我自己的简单测试,代码测试这篇博文没有的。

开始正文吧。

Part I: The Problem with Transformers

总结来说:这部分分为四部分,论证了Transformer的优势与不足,从而提出研究新的架构和组件的必要性。

  • The core components of Transformers
  • A blessing with Training
  • And the curse with inference
  • Are RNN a solution?

The core components of Transformers

回顾了transformer是的输入输出和模型架构,这方面资料很多。关键概念token, encoder, decoder, attention mechanism,这里不再赘述。

A blessing with Training

由于注意力机制和Transformer架构带来的理解上下文和训练快的优势:

The advantages:

  • Transformer is capable of selectively and individually looking at past tokens
  • parallization enables training

And the curse with inference

Transformer在推理方面的不足,推理过程中,计算开销会随着token向量的长度线性增加!

Are RNN a solution?

RNN的问题,上下文遗忘,训练耗时;但是,

RNNs can do inference fast as it scales linearly with the sequence length! In theory, it can even have an infinite context length .

核心论点,Transformer训练快,但是推理慢,而且文本长度有限;RNN理论上可以应对无限文本,推理快,但是训练慢,容易遗忘上下文。

有无模型能结合两者的优势,实现既要,又要呢?

Part II: The State Space Model (SSM)

这部分作者有如下几个小标题:

  • What is State Space :抽象出State Space的概念,以及相关的一些计算。转态空间蕴含了上下文的信息。
  • What is a State Space Model: x,y,h
  • From a Continuous to a Discreate Signal:连续时间t如何转换为离散的k,实际上,我们采样的数据,都是离散的。
  • The Reccurent Representation (递推形式,对应推理)
  • The Convolution Representation(卷积形式,对应训练步骤)
  • The Three Representation(总结)
  • The Importance of Matrice A(这部分讲S4是怎么一回事)

What is State Space

抽象出State Space的概念,以及相关的一些计算。数据的latent state representation。

What is a State Space Model

,输入,输出,隐藏状态表示,如何实现,请看下图。

From a Contious to a Discrete Signal

离散化,因为实际中,我们大多数情况是对一个连续信号进行采样,因此,我们需要一个离散版本

这个版本的熟悉信号与系统的读者会立马想到LTI系统。

The Reurrent Representation

如果直接按照离散版本,得到的模型结构与RNN非常相似。

于是继承了RNN的优势与缺陷,推理步骤快,但是训练步骤慢。

The Convolution Representation

另外一种离散形式,如图,道尽一切。这里的卷积其实变为了了输入与Kernel的滑动窗口内积。

The Three Representations

SSM有,三种形式,连续形式,递归形式,卷积形式。在实际中,递归形式用于模型训练,递归形式用于模型推理。

The Importance of Matrix A

由于SSM中的矩阵A控制隐藏状态的转移,因此其十分重要。

如果我们要实现长程记忆,可以用一个叫HIPPO矩阵来作为A

这种具有long-range dependencies的SSM称为S4。S4 SSM具有如下三种特性,从而 具备了快速训练,快速推理,长文本信息处理和存储潜能

  • State Space Models
  • HiPPO for handling long-range dependencies
  • Discretization for creating recurrent and convolution representations

在前面介绍中,尽管S4 model 具备了快速训练,快速推理,长文本信息处理和存储潜能 ,但是相较于transformer,SSM还是没有一种类似注意力的机制, 使得模型能够有效从数据中提取与任务有关的信息,专注于完成给定任务 。而Mamba解决了这个问题。

Mamba的贡献是通过如下两点,从算法和硬件加速的两个角度,实现了状态空间模型的注意力机制,从而赋予了SSM以提取实现完成指定的任务的能力。

  1. A selective scan algorithm , which allows the model to filter (ir)relevant information
  2. A hardware-aware algorithm that allows for efficient storage of (intermediate) results through parallel scan , kernel fusion , and recomputation .

这部分作者有如下几个小标题:

  • What Problem does it attempt to Solve?
  • Selective Retain information
  • The Scan Operation
  • Hare-aware Algorithm
  • Mamba block

我们还是按照作者的顺序进行介绍

What Problem does it attempt to Solve?

在这一部分,作者通过对selecive copying和induction head两种任务来说明S4的不足.

Selective copying:根据给定的token,从序列提取相关的信息。

如图所示: 但是我们知道S4是LTI的,其支配状态转移的参数 均为固定值,因此不会随着的任务提示词的变化而变化。直白的说,假如用S4训练了一个模型,其可以完成从一篇输入的论文全文中提取题目,那么你再让它从输入的论文全文中提取作者信息它输出的还是论文题目。显然这是很蠢的。

Induction heads: 从输入的模式中发现规律。常见问题问答任务,例如下图所示:

总结来说,这都是因为SSM的模型结构是这样的:

但是这些任务对Transformer来说非常自然。因为其 矩阵会根据输入的序列改编其参数,并且专注于特定的任务。关于注意力机制,可以看李沐老师的《动手学深度学习》边写代码边理解。

因此需要改进SSM中的 ,使其可以随着输入而改变。

Selectively Retain Information

Mamba就是希望能够在较小的状态空间,可以实现和Transformer一样强大的能力。作者接着回顾了模型的输入,和S4模型的结果

Mamba增加了序列长度参数L和batch size两个参数,改进了步长,矩阵B和矩阵C。这样可以实现根据context awareness,但是不改变矩阵A,这样能保持状态数目不会太大。

由于B是随着输入动态变化的,在状态空间的改变过程中中,只有递归形式了。作者借鉴了parallel scan的思想,实现了状态转移计算部分的加速。

parralel更为细节的中文版科普见Mamba.py:扫描和并行扫描 - seed42的文章 - 知乎

Hardware-aware Algorithm

GPU有两种内存SRAM和DRAM两者之间的通信是耗时的,作者利用了一种叫做kernerl fusion的策略,从硬件上实现加速。细节可以看Mamba原文。

硬件这部分需要的背景知识比较多,相关中文介绍请看理解GPU中的各种内存类型,为什么SRAM和L1Cache被视为同一个内存区域 - Tim在路上的文章 - 知乎

Mamba block

Mamba作者在改进了SSM之后,设计了Mamba的block块,与配套函数。

其中最为核心的部分是Selective SSM, 它具备我们之前提过的如下几点特性:

  • Recurrent SSM created through discretization
  • HiPPO initialization on matrix A to capture long-range dependencies
  • Selective scan algorithm to selectively compress information
  • Hardware-aware algorithm to speed up computation

模拟输入输出展示:

最终是总结。Mamba实现了训练和推断的双重加速。

个人看法:尽管Mamba现在获得了领域的广泛关注,但是其能否真的超越Transformer,成为下一代大模型的替代组件,还有待进一步的深入研究。


附:

模仿《动手学深度学习》里transformer一节写的英语-法语翻译的"MambaFormer":

import mathimport pandas as pdimport torchfrom torch import nnfrom d2l import torch as d2limport inspectfrom mamba_ssm import Mambaclass PositionWiseFFN(nn.Module):    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,**kwargs):        super(PositionWiseFFN, self).__init__(**kwargs)        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)        self.relu = nn.ReLU()        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)    def forward(self, X):        return self.dense2(self.relu(self.dense1(X)))    class AddNorm(nn.Module):    def __init__(self, normalized_shape, dropout, **kwargs):        super(AddNorm, self).__init__(**kwargs)






请到「今天看啥」查看全文