大家好,今天想来看一个和zero3权重切分方式相关的问题。之所以想来谈这个问题,是因为当前一个由deepspeed team官方给出的、传播度非常广的zero3运作流程视频解说,和它实际的代码实现间存在显著差异。
一、ds官方给出的zero3运作流程原理
我相信很多朋友,应该是通过这个官方视频介绍来入门zero的。这个链接里涵盖了一条视频,介绍了zero3整个fwd和bwd的过程。
官方视频链接:https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/
我们从视频中截取一张图进行说明(看到图以后是不是觉得很眼熟):
这里我们以上图为例,稍微解释下视频的内容:
- 一共有4张gpu,每张gpu吃一个micro batch的数据,它们共同构成一个dp组。
- 图中的M0, M1, M2, M3分别表示一个完整模型的4个部分,按照视频里的说法,这里对模型执行的是inter-layer(层间切割)。举例来说,假设模型一共有16层,那么M0 = layer0~3,M1 = layer4~7,M2 = layer8~11,M3= layer12~15。
- 图中刻画的是zero3 fwd的某一阶段,在这个阶段中,维护着M3的gpu3将把M3 broadcast 到其余gpu上,然后各张卡利用M3做fwd。
二、zero3实际的代码实践
而在zero3的代码实践中,模型权重其实是通过intra-layer(层内切割)的方式被放到各个gpu上的,换了一种切割方法,就会引起整个运作流程和通信方式上的极大不同。
zero3做partition_param的核心代码见下面链接。这里提一句,zero的代码写得非常复杂缠绕,从zero3的入口一直到这段切割的核心代码,我经历了漫长的阅读和跳转旅程。所以这里只放出核心代码,如果大家有阅读上下文的需求,需要耐心点从入口处开始慢慢读起。
zero3 partition param代码:https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/partition_parameters.py#L1551
我把整个zero3 param初始化和切割的过程抽象成下面这张图:
- 首先,某一块param(比如一个nn.linear层)先传送到rank0上
- 接着,这块param被从rank0广播到各个gpu上。此时每张gpu都拥有完整的rank0。
- 接着,把param展平成1D张量,同时做padding。padding的意思就是,如果一个param中所含的元素数量不能被dp_group内的gpu数量整除,那么就需要做padding对齐。
- 计算每一块gpu上所维护的1D param的范围(即start_index, end_index),然后按照这个范围取出每块gpu上应该维护的param chunk,接着就可以释放不是自己所维护的param chunk了。
在这个intra-layer形式的zero3权重切分下,fwd和bwd中涉及的关键通信应该是:
- fwd时,dp_group内做1次all-gather,取回完整的权重。
- bwd时,dp_group内做1次all-gather,取回完整的权重,以便做梯度计算。
- bwd时,dp_group内对梯度做1次reduce-scatter,让每张卡拿回属于自己的梯度,以便做权重更新。
如果按照第一部分ds官方给出的流程图,zero3的在fwd和bwd过程中主要通信方式应该是broadcast,事实上它也是这么注明的。而除了通信方式外,如果真的按照官方流程图的方式来写zero3,那么实践上将会大有不同。总结来说,一个小小的参数partition方式的变化,可能造成对整个zero3的认知变化。
三、自我鞭尸及差异分析
在开始写这节前,我先鞭尸我自己。
我在23年年初写过一篇关于deepspeed zero的介绍(https://zhuanlan.zhihu.com/p/618865052),但我对deepspeed的第一次了解要追溯到20年左右,那时还不是LLM的时代。我首次了解zero时,它只实现出了zero1/2,还没有做出zero3,也即zero3当时应该是存在于zero论文中的一个demo概念。当年我没有养成看代码的习惯,非常依赖官网的教程和各类blog的理论解读。
所以上文里提供的这份官方视频,就是我zero3的入门教程,事实上直到今年,它应该还是很多人入门zero会看到的东西,在B站或者各类blog上,有无数对它直接搬运或者重制的介绍。
在我23年年初开始写zero介绍时,我依然没有看代码,头脑里也还是官网介绍的那套模式。但是,误打误撞地,当我准备画图时,我发现画不同layer的broadcast实在太麻烦了,所以我把整个模型抽象成一个整体(one layer),然后用all-gather和reduce-scatter的方式替换掉原来的broadcast,替换的原因是我觉得在不影响对整体通信量分析的情况下,前两者表达起来更加合乎逻辑和直观。然后又误打误撞的,我随手把切割化成了intra-layer(毕竟抽象成one layer了)。最终,虽然头脑里对zero3的认知还是错误的,但是各种机缘巧合竟然画成了代码实际实践的样子: