专栏名称: 极市平台
极市平台是由深圳极视角推出的专业的视觉算法开发与分发平台,为视觉开发者提供多领域实景训练数据库等开发工具和规模化销售渠道。本公众号将会分享视觉相关的技术资讯,行业动态,在线分享信息,线下活动等。 网站: http://cvmart.net/
目录
相关文章推荐
广告文案  ·  星巴克,重回一个定位 ·  2 天前  
51好读  ›  专栏  ›  极市平台

图解Mixtral 8 * 7b推理优化原理与源码实现

极市平台  · 公众号  ·  · 2024-03-13 22:00

正文

↑ 点击 蓝字 关注极市平台
作者丨 猛猿
来源丨大猿搬砖简记
编辑丨极市平台

极市导读

本文 焦点放在“Mixtral推理优化”这一块上,通过图解的方式,把代码的运作流程串起来,帮助大家更好理解原理和阅读源码。 >> 加入极市CV技术交流群,走在计算机视觉的最前沿

大家好,在写这篇文章时,本来是想打算介绍Mixtral 8 * 7b具体模型架构的。但是代码读着读着就发现:

  • 最精彩的MoE部分,其相关原理在之前的文章中已经详细介绍过
  • 整体来看Mixtral 8 * 7b的模型架构代码,写得非常清楚,几乎没有理解难点。

就在我以为Mixtral的代码已无更多可写时,我注意到了它在推理时用到的一些trick,具体为:

  • Sliding Window Attention (SWA,滑动窗口Attention)
  • Rolling Buffer Cache(也被称为Rotating Buffer Cache,即旋转式存储的KV cache)
  • Long-context Chunking (长上下文场景下的chunking策略,配合前两者食用)

这些trick的构思比较巧妙,同时代码实现并不好读,(特别是最后两个trick),表现在:

  • 没有注释。偶有注释举例的地方,例子举得并不好(进入了代码中assert非法分支,不适合用来做代码讲解。所以本文会给出更合适的例子做讲解)
  • 变量、class等命名较为晦涩
  • 所依赖的外部包(例如Xformers库)的官方文档给的介绍不够清晰
  • 逻辑较复杂

所以在这篇文章中 ,我们就把焦点放在“Mixtral推理优化”这一块上 ,同样通过图解的方式,把代码的运作流程串起来,帮助大家更好理解原理和阅读源码。在本文的最后一部分,给出一些源码阅读的hint(可能是大部分朋友在读Mixtral代码时感到最痛的点)。全文目录如下:

一、LLM推理两阶段
1.1 Prefill
1.2 Decode

二、Sliding Window Attention
2.1 原理
2.2 为什么能用滑动窗口

三、Rolling Buffer Cache
3.1 原理
3.2 "旋转"从何而来

四、Long-Context Chunking

五、Chunking全流程图解

六、一些关于源码的hint

一、LLM推理的两阶段

一个常规的LLM推理过程通常分为两个阶段:prefill和decode。

1.1 Prefill

预填充阶段。 在这个阶段中,我们 把整段prompt喂给模型做forwardi计算。如果采用 cache技术 ,在这个阶段中我们会把prompt过 后得到的 保存在cache_k和cache_v中 。这样在对后面的token 计算attention时,我们就不需要对前面的token重复计算 了,可以帮助我们节省推理时间。

在上面的图例中,我们假设prompt中含有 3 个token,prefill阶段结束后,这三个token相关的KV值都被装进了cache。

1.2  Decode

生成response的阶段 。在这个阶段中,我们 根据prompt的prefill结果,一个token一个token地生成response。

同样,如果采用了KV cache,则每走完一个decode过程,我们就把对应response token的KV值存入cache中,以便能加速计算。例如对于图中的t4,它与cache中t0~t3的KV值计算完attention后,就把自己的KV值也装进cache中。对t6也是同理。

由于Decode阶段的是逐一生成token的,因此它不能像prefill阶段那样能做大段prompt的并行计算,所以在LLM推理过程中,Decode阶段的耗时一般是更大的。

二、Sliding Window Attention

2.1 原理

从第一部分的介绍中,我们应该能感受到一点: LLM推理中的KV cache加速法,是非常典型的用“空间换时间”的操作。 随着seq_len变长,cache中存储的数据量也越来越大,对显存造成压力。

所以,我们自然而然想问:有什么办法能减缓cache的存储压力呢?

注意到, cache的存储压力之所以变大,是因为我们的Attention是causal decoder形式的,即每一个token,都要和它之前所有的token做Attention ,所以cache中存储的数据量才和seq_len正相关。如果现在我们转换一下思路, 假设每一个token只和包含其本身在内的前W个token做Attention,这样不就能把cache的容量维持在W吗? 而从直觉上来说,这样的做法也有一定的道理: 对当前token来说,距离越远的token,能提供的信息量往往越低, 所以似乎没有必要浪费资源和这些远距离的token做Attention。

这种Attention思路的改进,就被称为 Sliding Window Attention ,其中W表示窗口长度。这也是Mixtral 7b 和Mixtral 8 * 7b采用的方法,我们通过作者论文中的一张图,更清晰地来看下它和传统Attention的区别,这里W=3。

2.2 为什么能用滑动窗口

虽然滑动窗口的策略看起来很不错, 不过你一定有这样的疑惑:虽然距离越远的token涵盖的信息量可能越少,但不意味着它们对当前token一点用处都没有 。在传统的Attention中,我们通过Attention score,或多或少给这些远距离的token一定的参与度;但是在Sliding Window Attention中,却直接杜绝了它们的参与,这真的合理吗?

为了回答这个问题,我们来看一个例子,在本例中W=4,num_layers = 4,num_tokens = 10。

我们从layer3最后一个位置的token(t9)看起:

  • 对于 layer3 t9 ,它是由 layer2 t9 做sliding window attention得来的。也就是 layer3 t9 能看到 layer2 t6 ~ t9 的信息
  • 再来看 layer2 t6 ,它能看到 layer1 t3 ~ t6 的信息。也就是说对于 layer3 t9 ,它最远能看到 layer1 t3 这个位置。
  • 以此类推,当我们来到 layer0 时,不难发现,对于 layer3 t9 ,它最远能看到 layer0 t0 这个位置的信息。

欸你发现了吗! 对于 layer3 t9 ,虽然在每一层它“最远”只能看到前置序列中部分token,但是只要模型够深,它一定能够在某一层看到所有的前置tokens。

如果你还觉得抽象,那么可以想想CNN技术中常谈的“感受野” 。当你用一个固定大小的卷积窗口,对一张原始图片做若干次卷积,得到若干张特征图。越深的特征图,它的每一个像素点看到的原始图片的范围越广。 类比到我们的滑动窗口Attention上,从layer0开始,每往上走一层,对应token的感受野就往前拓宽W。

所以,Silding Window Attention并非完全不利用窗口外的token信息,而是随着模型层数的增加,间接性地利用起窗口外的tokens。

三、Rolling Buffer Cache

3.1 原理

当我们使用滑动窗口后,KV Cache就不需要保存所有tokens的KV信息了, 你可以将其视为一个固定容量(W)的cache,随着token index增加,我们来“滚动更新” KV Cache。

下图给出了Rolling Buffer Cache的运作流程:

在图例中,我们做推理时喂给模型一个batch_size = 3的batch,同时设W = 3。此时KV Cache的容量为 (batch_size, W) 。我们以第1条 prompt This is an example of ... 为例:

  • 在i时刻,我们对 an 做attention,做完后将 an 的KV值更新进cache中
  • 在 i + 1时刻,我们对 example 做attention,做完后将 example 的KV值更新进cache中。此时对于第1条prompt,它在KV cache中的存储空间已满。
  • 在 i + 2时刻,我们对 of 做attention,由于此时KV cache已满,所以我们将 of 的KV值更新进KV cache的0号位置,替换掉原来 This 的KV值。再后面时刻的token也以此类推。
  • 不难发现,prompt中第i个token在KV cache中的存储序号为: i % W

3.2 “旋转”从何而来

如果你读过Mixtral的源码,你可能会记得,在源码中管Rolling Buffer Cache叫Rotary Buffer Cache。 而“Rotary”这个词很值得我们关注:为什么叫“旋转”呢“

我们再回到3.1的图例中:

还是对于第一条数据,我们往上添两个单词,假设其为 This is an example of my last... 。现在来到了单词 last 上,我们需要对它计算Sliding Window Attention。

不难理解,在W=4的情况下, last 的Attention和 example of my last 相关。 现在我们把目光放到图中的KV Cache上:它的存储顺序似乎不太对,如果我们想对last做Attention,就要对当前KV Cache中存储的元素做一次“旋转”,将其转回正确的位置。

所以, Rotary的意思就是:通过某种规则,将Cache中的数据旋转回正确的位置,以便能正确做Attention。 这个规则在Mixtral源码中用一个 unrotate 函数来定义。在后文我们会详细看这个函数的运作方式。

四、Chunking

我们回忆一下目前为止Mixtral为了加速模型推理做的操作:

  • 使用KV Cache,加速Decode过程
  • 使用Sliding Window Attention和Rolling Buffer Cache,降低KV Cache存储压力

你可能已经发现, 这些以“空间换时间”的优化,都是针对Decode过程的。那么对于Prefill过程,我们能做什么优化呢?

相比于更耗时的Decode阶段, Prefill有一个更加突出的问题:long-context。过长的prompt会给显存带来压力。一个符合直觉的解决办法是:把prompt切成若干chunk,每次只喂给模型1个chunk,更新1次KV Cache。 这样我们虽然牺牲了一些Prefill计算的并行性(所有tokens一起计算),却能帮助我们节省显存压力(尤其是在采用sliding window attention的情况下,KV Cache的尺寸是固定的而不是随seq_len增长时)。

一般情况下,我们设 chunk_size = cache_window = sliding_window = W ,也就是chunk和cache的尺寸都和滑动窗口的尺寸保持一致,都设为W。对这个参数设置我们再说明下:一般满足cache_window = sliding_window,这个不难理解,因为cache中存的是attention感受野范围内的token。而chunk_size可以不等于这两者(源码中也提供了相关处理)。只是chunk_size和这两者相等时,无论是从计算逻辑还是空间利用率上,都是更好的选择(现在觉得抽象没关系,后文会提供具体的图例,大家可以感受下)。

好,现在我们来看一个chunking的图例(来自Mixtral论文),假设输入的prompt为 The cat sat on the mat and saw the dog go to ,同时 chunk_size = cache_window = sliding_window = 4

假设我们现在来到第三块chunk ,它包含的词为 the dog go to 。我们要对这个chunk中的每一个token计算滑动窗口Attention,同时把每个token的Xk, Xv值更新进KV Cache。

  • 图中row方向表示Xq ,即你可以把row方向 the dog go to 的每一个token,当成是这个token过Wq后的Xq值
  • 图中col方向表示Xk, Xv ,即你可以把col方向 The cat sat on the mat and saw the dog go to 的每一个token,当成是这个token过Wk,Wv后的Xk,Xv值,这些值存储在KV Cache中
  • 图中整个0/1数据块表示mask矩阵 。它表示row方向的Xq应该和col方向的哪些Xk,Xv值做attention。

现在我们已基本能理解这张图的含义, 不过还有一点很奇怪:在这个图下的Past, Cache, Current表示什么意思呢?

我们牢记一点:只有1个KV cache(也可以理解成只有1个用于存放Xk值的cache_k,和1个用于存放Xv值的cache_v) 。当我们遍历到某个chunk时,我们取出当前的cache和这个chunk做attention计算,然后再把这个chunk相关的KV值按Rolling Buffer Cache的方式更新进这个cache中。

回到我们的例子上,现在我们位于第3块chunk上,此刻cache中存储的Xk, Xv值,即是 上图中间块 维护的 the mat and saw 因此只有中间块的最底下被标上了“cache”,因为它才是此时真正的cache。 最左侧past块 维护的则是 前一个时刻的cache 最右侧的current块 维护的 the dog go to 即将被更新进cache的Xk, Xv值 。这就是past, cache和current的含义。

注意到虽然图中画出了past块,但这并不意味着计算第3块时要把past块也取出(此时past块代表的cache早就被更新了)。论文中这样画只是更方便我们了解cache更新迭代和计算的过程。( 悄悄吐槽下, 虽然论文中的这些图画得很好很精练,但是少了很多关键信息的文字介绍,容易给人造成似懂非懂的感觉)

五、Chunking推理全流程图解

我们用图解的方式把整个推理流程串一遍,好知道代码在做一件什么事情

5.1 输入数据

假设推理时 batch_size = 3 ,且有 chunk_size = cache_size = sliding_window = 4 ,则这个batch的prompts可表示成下图(每个方块表示1个token,同色方块属于同个prompt):

(1) chunk0

  • 我们首先将chunk0送入模型,此时KV cache为空

  • 对chunk中的每个token计算Xq,Xk,Xv,用于计算SWA(Sliding Window Attention)。图中刻画了计算时用到的mask矩阵 。在Mixtral源码中使用Xformers库的相关API来完成Attention相关的计算(这个库的好处是加速Attention计算)。 BlockDiagonalCausalMask (全称是BlockDiagonalCausalLocalAttentionMask)是这个库下提供的一种mask方法,它可以这样理解:

    • block :将矩阵进行分块(block),之后在每一个块内 单独 做Attention计算
    • diagonal causal :每一个block内做对角线mask

Xformers官方文档在这一块的介绍不太全面,对初次使用Xformers的朋友其实不太友好,所以在这里我做了可视化,方便后续大家对代码的理解。

  • chunk0的SWA计算完毕后,我们将每个token对应的Xk, Xv值存入cache。 在源码中,我们会通过一个规则确定每个token的KV值在KV cache中的存储位置,这样也方便我们做unrotate操作(见本文3.2部分)时能把cache中存储的元素旋转回正确的位置。
  • 最后,对于KV cache,它的 position序号 的排布顺序是从左至右,从上到下的,即:
Cache position index:

0 | 1 | 2  | 3
4 | 5 | 6  | 7
8 | 9 | 10 | 11

(2) chunk1

  • 对于chunk1中维护的tokens,我们正常计算他们的xq,xk,xv。
  • 取出当前KV Cache中存储的KV值,和chunk计算出来的KV值进行拼组,计算SWA (如图所示,mask矩阵的row行,每个色块由两部分组成:当前cache + 当前chunk)
  • 在计算SWA的mask矩阵时,我们同样采用Xformers库,这时调用的是 BlockDiagonalCausalLocalAttentionFromBottomRightMask 类,和chunk0调用的 BlockDiagonalCausalLocalAttentionMask 相比,它的主要不同在“FromBottomRight”上,也就是 对于每个block,它从右下角开始以窗口长度为W(本例中W=4)的形式设置mask矩阵。
  • 计算完chunk1的SWA后,我们将chunk1的KV值更新进KV Cache中

(3) chunk2

最后我们来看chunk2,这个chunk比较特殊,因为在这个chunk内,每一个prompt维护的序列长度是不一样的,3个prompt维护的tokens分别为 [[8, 9, 10, 11], [8, 9], [8]]

  • 同样,我们计算chunk2的每个tokens的Xq,Xk,Xv






请到「今天看啥」查看全文