专栏名称: GiantPandaCV
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaCV  ·  免费 | 抢先试用此芯Armv9 AI ... ·  3 天前  
GiantPandaCV  ·  美团基于SGLang提供INT8无损满血版D ... ·  3 天前  
51好读  ›  专栏  ›  GiantPandaCV

图解大模型训练系列:序列并行2,DeepSpeed Ulysses

GiantPandaCV  · 公众号  · 3D  · 2024-11-05 16:42

正文

大家好,在序列并行系列中,我们已经介绍过了 Megatron SP ,今天这篇文章我们来看DeepSpeed Ulysses。

在正文开始前, 请允许我吐槽一下,DeepSpeed Ulysses继承了DS家一如既往的写作和coding风格:云里雾里,梦里心里,就是走不进你的脑子里 。所以虽然paper短小,coding改动也小,一 切都慷慨地开源了,但一切又好像没有开源 ,使整个理解过程变得过于眼鼻酸涩。举些例子来说:

  • Ulysses的卖点之一【通讯量】竟然用一两句话就写过去了😢。
  • Ulysses SP的核心操作All2All过程,竟然用一个标着All2All的红箭头就概括过去了😢。
  • Ulysses + zero3这种官方安利的训练方法,竟然没有一个图例😢。
  • 诸如此类。

所以本来想偷懒不看源码,最终又要从源码开始看起。那既然说起了代码,如果你也看过ds家的代码风格的话,那你应该懂我接下来没有记录下的这些眼泪(但ulysses相关的代码其实还好)。

尽管有以上种种,当真正了解Ulysses的设计思想后,还是要佩服它的简便和轻巧 ,如果有做序列并行的需求,我大概率会从Ulysses开始尝试改起。

话不多说,这篇文章会尽最大可能补充ulysses的细节,全文目录如下:

一、Ulysses整体运作流程

二、Megatron VS Ulysses
2.1 Megatron通讯量
2.2 Ulysses通讯量
(1)All2All操作
(2)Ulysses fwd
(3)Ulysses bwd
2.3 通讯量对比

三、Ulysses + Zero3

四、参考

一、Ulysses整体运作流程

Ulysses的整体运作流程如下图,我们来详细解释下。

设:

  • N = seq_len
  • d = hidden_size
  • P = gpu_num ,在后文的解读中,我们可以发现ulysses在实际操作中,其实是1张卡算1个/若干个head的结果,所以这里还应该满足head_num是P的整数倍,不过接下来我们为了表达简便, 统一把这个P直接理解成head_num。

我们来跟着这张图,走一遍ulysses的fwd过程。

(1)按seq维度切分输入数据。

对于输入X =(N, d),我们将其切分成若干个seq_chunk,作为各自gpu的输入,每个seq_chunk的尺寸为(N/P, d)。

(2)每张卡计算自己维护的seq_chunk的qkv值。

  • 由于ulysses本身不对模型做任何切割,所以每块gpu上保存有完整的模型 ,也就是完整的 矩阵,尺寸都为 (d, d)

  • 这里额外提一点,我们经常会听到ulysses可以配合zero3进行使用。在这种情况下,在进入正式计算前,每块gpu确实只保存部分模型(模型并行的形式),但实际计算时会做all-gather让每张卡拿回完整的模型再计算(数据并行的实质),所以我们依然可以理解成gpu上保存完整的模型。

  • 每块gpu上的seq_chunk正常和 相乘,得到 q/k/v_chunk = (N/p, d)

(3)针对q/k/v_chunk,所有卡间做一次All2All通讯,使得每张卡拿到所有seq的某1个head的q/k/v_chunk

  • 做这个All2All通讯前,每张卡上维护的 q/k/v_chunk = (N/p, d) ,可以理解成某个seq_chunk所有head的qkv值
  • 做这个All2All通讯后,每张卡上维护的 q/k/v_chunk = (N, d/P) ,可以理解成所有seq的某个head的qkv值。

我们以q_chunk为例,来具体看All2All是怎么实现这一点的(下图根据ulysses源码进行绘制,做了一点简化)

  • 如上图所示,这里我们假设有4块卡(4个head),则最终我们希望gpu0算head0的结果,gpu1算head1的结果...以此类推。我们用不同颜色的矩形表示计算不同head需要用到的q数据。

  • 我们从上图的最左侧位于gpu0上的q0看起,它表示seq_chunk0的q结果,尺寸为(N/P, d)。不难理解,如果我们将q0沿着d维度切成P块,那么每一块就表示为了计算出对应的head所需要的q结果。其余gpu上的q_chunk也是类推。

  • 现在我们执行All2All算法,你可以将它理解成是一种“转置式”地通信方法 :结合上图我们可以发现,各块卡第1列蓝色块现在都跑去gpu0,第2列绿色块现在都跑去gpu1...这就是我们说的“转置”的含义。

  • All2All结束后,我们还以gpu0为例, 它上面拥有P块(N/P, d/P)数据,表示所有seq在head0上的q结果,我们将其稍作reshape后,每块卡上最终维护的q_chunk就变成(N, d/P)。 每块卡上的k/v_chunk也是同理进行All2All通讯。

(4)每张卡拿到所有seq的某1个head的q/k/v_chunk后,我们正常执行Attention计算 ,最终每张卡上产出结果 chunk,尺寸为 (N, d/P)

(5)针对 chunk,所有卡间再做1次All2All通讯,最终单卡上维护的P chunk尺寸又变回(N/P, d) 这个All2All过程可以理解成是先前描述的All2All的反操作,作用过程相似,这里不再赘述。

(6) 单张卡上拥有完整的 矩阵,我们将P chunk和它相乘,得到最后的输出O chunk,尺寸为 (N/P, d)

(7) 进入MLP层,由于在MLP层中,不涉及token和token之间的相关性计算,所以各seq_chunk块可以独自计算。

(8) 重复上述过程,直到算到Loss 为止。

  • 这里我初步判定,每张卡上算出的Loss应该就是这块卡所维护的那个seq_chunk的Loss。因为我粗看了一遍ulysses的代码,发现目前它的核心是单独设计了一个能实现sp并行的DistributionAttention的模块,然后用这个模块替换掉之前的Attention Module,通过这样一个简单的替换实现了ulysses的基本功能。再考虑到seq_chunk在MLP计算时的独立性和数据并行的特性,最终单卡Loss应该就是seq_chunk Loss,这也意味着sp组的梯度需要做AllReduce通讯,这个我们放在后面对ulysses的通讯量分析中再说。

二、Megatron VS Ulysses

不难发现,Ulysses和Megatron在分布式计算attention上有某些相似之处:

  • Megatron通过tp ,显式地把Wq, Wk, Wv切分开,然后每张卡上计算所有seq的某个head的结果。

  • Ulysses通过sp+all2all ,在每张卡完整保存Wq, Wk, Wv的前提下,让每张卡上计算所有seq的某个head的结果。

那么在实现相似功能的情况下, Ulysses提出的一个重要卖点是:我的通讯量低 。所以接下来,就让我们来详细分析这一点。

(⚠️⚠️⚠️:如果看到下文时,发现对通讯量、激活值等等计算有疑问的朋友,可以先看这篇写 Megatron SP的文章 。)

2.1 Megatron通讯量

上图展示了megatron tp + sp下的整体运作流程。

对于Attention部分:

  • 在fwd的过程中,做了1次all-gather,1次reduce-scatter

  • 在bwd的过程中,做了1次reduce-scatter,1次all-gather (其实在bwd反向传播到g前,还需要做1次all-gather,只是这个通讯量可以被计算掩盖掉,也就是还在传播到g前还在做上层的链式推导计算时,就可以开始all-gather了,所以我们这边先忽略这个额外的all-gather,但是如果你想算进去也没事)

对于MLP部分:

  • 同样是2次all-gather + 2次reduce-scatter,同理还有1次额外的all-gather可以被bwd过程中的计算时间覆盖掉,所以我们还是不计算它。

综合Attention和MLP:

  • 最终在Megatron中, Attention + MLP的通讯量为4 all-gather + 4 reduce-scatter(额外还有2次可以被bwd计算覆盖掉的all-gather不算在这里),1个all-gather/1个reduce-scatter的通讯量约为Nd(忽略batch_size),所以Megatron Attn部分的通讯量约为8Nd

2.2 Ulysses通讯量

(1)All2All操作的通讯量

  • All2All操作前,每张卡上保存的数据大小为 (N*d)/P ,每个小数据块的大小为 (N*d)/(P*P)

  • 虽然对于单卡来说,它的通讯量涉及send和accept,但整个系统的通讯量可以理解成是每张卡的send总和(因为你的accept总来自别人的send,反之亦然),所以对于单卡通讯量我们也只看send就行。

  • 对于单卡来说,它的send量 = ,约为 (N*d)/P 也就是单卡1次All2All的通讯量约是(N*d)/P 【回顾一下,单卡做1次all-gather或reduce-scatter的通讯量是N*d】。

(2)Ulysses fwd通讯量

回顾第一部分Ulysses的fwd过程:

  • q/k/v_chunk各自做1次All2All通讯,则这里合起来做了3次All2All通讯
  • 各卡上原始的Attention结果 做了1次All2All通讯
  • 综上,Ulysses fwd过程一共做了4次All2All通讯

(注意,如果配套使用了zero操作,fwd过程还会涉及模型权重的all-gather,不过这里我们不考虑这一点,我们就假设是最朴素的ulysses,单卡上有完整的模型)

(3)Ulysses bwd通讯量

这块也是我觉得比较重要,但是论文里没有展开的地方(抹泪),所以我翻了翻源码(又抹泪),根据自己的理解大致描述下bwd的过程。

对于ulysses bwd的过程,对于以下两类通讯,由于理论上它们可以被bwd的计算时间覆盖,所以不计入总通讯量中:

  • 激活值的重计算 :我们知道链式推导的过程中我们会用到一些激活值(比如上图中的P),而为了节省显存,大部分框架都不会把这些激活值保存下来,只有在链式传导块传递到这个激活值上时,重新做fwd算出这个激活值。比如链式传导快到上图中的P上时,我们就需要重做fwd的All2All把P算出来。我们可以在用到P前做这件事,因此重算P的All2All通讯是可以被覆盖的。图中的 等等也是同理。

  • 梯度的AllReduce:

    • 假设我们现在要对 计算梯度,这里假设我们有两张卡,每张卡上维护某个seq_chunk的P_chunk结果,P_chunk的尺寸为(N/2, d),则我们有:

    • 每张卡上所维护的seq_chunk最终的loss为:







请到「今天看啥」查看全文