专栏名称: 数据派THU
本订阅号是“THU数据派”的姊妹账号,致力于传播大数据价值、培养数据思维。
目录
相关文章推荐
大数据与机器学习文摘  ·  26岁OpenAI举报人疑自杀!死前揭Cha ... ·  2 天前  
玉树芝兰  ·  什么样的 AI 产品会更受用户欢迎? ·  5 天前  
51好读  ›  专栏  ›  数据派THU

​​当Batch Size增大时,学习率该如何随之变化?

数据派THU  · 公众号  · 大数据  · 2024-12-16 17:53

正文

本文约5700字,建议阅读15分钟

本文从多个视角讨论了 “Batch Size 与学习率之间的 Scaling Law” 这一经典炼丹问题。


随着算力的飞速进步,有越多越多的场景希望能够实现“算力换时间”,即通过堆砌算力来缩短模型训练时间。

理想情况下,我们希望投入  倍的算力,那么达到同样效果的时间则缩短为 ,此时总的算力成本是一致的。这个“希望”看上去很合理和自然,但实际上并不平凡,即便我们不考虑通信之类的瓶颈,当算力超过一定规模或者模型小于一定规模时,增加算力往往只能增大 Batch Size。

然而,增大 Batch Size 一定可以缩短训练时间并保持效果不变吗?

这就是接下来我们要讨论的话题:当 Batch Size 增大时,各种超参数尤其是学习率该如何调整,才能保持原本的训练效果并最大化训练效率?我们也可以称之为 Batch Size 与学习率之间的 Scaling Law。

01 方差视角

直觉上,当 Batch Size 增大时,每个 Batch 的梯度将会更准,所以步子就可以迈大一点,也就是增大学习率,以求更快达到终点,缩短训练时间,这一点大体上都能想到。问题就是,增大多少才是最合适的呢?

02 二次方根

这个问题最早的答案可能是平方根缩放,即 Batch Size 扩大到  倍,则学习率扩大到  倍,出自 2014 年的《One weird trick for parallelizing convolutional neural networks》[1],推导原理是让 SGD 增量的方差保持不变。

具体来说,我们将随机采样一个样本的梯度记为 ,其均值和协方差分别记为  和 ,这里的  就是全体样本的梯度。当我们将采样数目增加到  个时,有
即增加采样数目不改变均值,而协方差则缩小到 。对于 SGD 优化器来说,增量为 ,其协方差正比于 ,而我们认为优化过程中适量的(不多不少的)噪声是有必要的,所以当 Batch Size  变化时,我们通过调整学习率  让增量的噪声强度即协方差矩阵保持不变,从得出
这就得到了学习率与 Batch Size 的平方根缩放定律,后来的《Train longer, generalize better: closing the generalization gap in large batch training of neural networks》[2] 也认同这个选择。

03 线性缩放

有意思的是,线性缩放即  在实践中的表现往往更好,甚至刚才说的最早提出平方根缩放的《One weird trick for parallelizing convolutional neural networks》[1] 作者也在论文中指出了这一点,并表示他也无法给出合理的解释。

某种程度上来说,线性缩放更符合我们的直观认知,尤其是像《Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour》[3] 那样,假设连续的  个 Batch 的梯度方向变化不大的话,那么线性缩放几乎是显然成立的。

不过,这个假设显然过强,放宽这个假设则需要将 SGD 跟 SDE(随机微分方程)联系起来,这由《Stochastic Modified Equations and Dynamics of Stochastic Gradient Algorithms I: Mathematical Foundations》[4] 完成,但首先用于指出学习率与 Batch Size 的缩放关系的论文应该是《On the Generalization Benefit of Noise in Stochastic Gradient Descent》[5]。

事后来看,这个联系的建立其实并不难理解,设模型参数为 ,那么 SGD 的更新规则可以改写成:

其中  即为梯度的噪声,到目前为止,我们还没有对这个噪声的分布做任何假设,只知道它的均值为 ,协方差为 。接下来我们假设这个噪声的分布是正态分布 ,那么上述迭代可以进一步改写成

这就意味着 SGD 的迭代格式  实际上在近似地求解 SDE:

因此,要想在  发生变化时,运行结果不产生明显变化,上述 SDE 的形式应该不变,这就得到了线性缩放 。这个过程中最关键的一步是,SDE 的噪声项步长是非噪声项的平方根,从而分离出一项  来。

这一点我们在《生成扩散模型漫谈:一般框架之SDE篇》也有过评析,简单来说就是零均值的高斯噪声长期会有一定的抵消作用,所以必须增大步长才能将噪声效应体现出来。

以上结论都是基于 SGD 优化器得出的,论文《On the SDEs and Scaling Rules for Adaptive Gradient Algorithms》[6] 将它推广到了 RMSProp、Adam 等优化器上,结果是平方根缩放。

无独有偶,稍早一点的《Large Batch Optimization for Deep Learning: Training BERT in 76 minutes》[7] 在测试 Adam 及其变体 LAMB 时,也应用了平方根缩放。更多内容还可以参考博客《How to Scale Hyperparameters as Batch Size Increases》[8]。

04 直面损失

可以肯定的是,不管是平方根缩放还是线性缩放,它们都只能在局部范围内近似成立,因为它们都包含了“只要 Batch Size 足够大,那么学习率就可以任意大”的结论,这显然是不可能的。此外,前面两节的工作都围绕着方差做文章,但我们的根本任务是降低损失函数,因此以损失函数为导向或许更为本质。

05 单调有界

这个视角下的经典工作是 OpenAI 的《An Empirical Model of Large-Batch Training》[9],它通过损失函数的二阶近似来分析 SGD 的最优学习率,得出“学习率随着 Batch Size 的增加而单调递增但有上界”的结论。

整个推导过程最关键的思想是将学习率也视作优化参数:设损失函数是 ,当前 Batch 的梯度是 ,那么 SGD 后的损失函数则是 ,我们将最优学习率的求解视为优化问题:
这个目标显然很直观,就是选择学习率使得平均而言训练效率最高(损失函数下降得最快)。为了求解这个问题,我们将损失函数近似地展开到二阶:
这里的  就是 Hessian 矩阵,而  是损失函数的梯度,理想的目标函数是基于全量样本来求的,这也就是为什么它的梯度就是  的均值 。接着求期望,我们得到:
最后一项有少许技巧:

变换过程主要利用到了 。现在只要假定  的正定性,那么问题就变成了二次函数的最小值,容易解得:

这就得出了“随着  单调递增有上界“的结果,其中:
06 实践分析

当  时,,所以 ,即线性缩放,这再次体现了线性缩放只是小 Batch Size 时的局部近似;当  时, 逐渐趋于饱和值 ,这意味着训练成本的增加远大于训练效率的提升。

所以, 相当于一个分水岭,当 Batch Size 超过这个数值时,就没必要继续投入算力去增大 Batch Size 了。

对于实践来说,最关键的问题无疑就是如何估计  和  了,尤其是  直接关系到学习率的缩放规律和训练效率的饱和问题,二者的直接计算涉及到 Hessian 矩阵 ,其计算量正比于参数量的平方,在数亿参数量都算小模型的今天,计算 Hessian 矩阵显然是不现实的事情,所以必须寻找更有效的计算方式。

我们先来看 ,它的式子是 ,分子分母都有一个 ,这无疑有一种让我们将它们“约掉”的冲动。事实上简化的思路也是如此,假设  近似于单位阵的若干倍,那么得到:

 在计算上更为可行,并且实验发现它通常是  的一个良好近似,因此我们选择估计  而不是 。注意  只需要对角线上的元素,因此不用算出完整的协方差矩阵,只需要将每个梯度分量单独算方差然后求和。在数据并行场景,可以直接利用每个设备上算出来梯度来估计梯度方差。

需要指出的是,式(10)等结果实际上是动态的,也就是说理论上每一步训练的  都是不同的,所以如果我们希望得到一个静态的规律,需要持续训练一段时间,等到模型的训练进入“正轨”后计算的  才可靠的,或者也可以在训练过程中持续监控 ,以便判断当前设置与最优的差距。

至于 ,其实就没必要根据公式来估计了,直接在某个小 Batch Size 下对学习率进行网格搜索,搜出一个近似的 ,然后结合估计的  就可以反推出  了。

07 数据效率

从上述结果出发,我们还可以推导关于训练数据量和训练步数的一个渐近关系。推导过程也很简单,将(10)代入到损失函数中可以算得,在最优学习率下每一步迭代带来的损失函数减少量是:
其中 。接下来的重点是对这个结果的解读。

当  也就是全量 SGD 时,每一步损失函数减少量达到了最大的 ,这时候可以用最少的训练步数(记为 )达到目标点。

当  有限时,每一步的损失下降量平均只有 ,这意味我们需要  步才能达到全量 SGD 单步的下降量,所以训练的总步数大致上就是 

由于 Batch Size 为 ,所以训练过程消耗的样本总数则是 ,这是  的增函数,且当 ,这表明只要我们使用足够小的 Batch Size 去训练模型,那么所需要的总训练样本数  也会相应地减少,代价是训练步数  非常多。

进一步地,利用这些记号我们可以写出它们之间的关系是:

这就是训练数据量和训练步数之间的缩放规律,表明数据量越小,那么应该缩小 Batch Size,让训练步数更多,才能更有机会达到更优的解。

这里的推导是经过笔者简化的,假设了  和  在整个训练过程的不变性,如果有必要也可以按照原论文附录用积分更精细地处理动态变化的情形(但需要引入假设 ),这里就不展开了。

此外,由于 ,所以上式也提供了估计  的另一个方案:通过多次实验加网格搜索得到多个  对,然后拟合上式就可以估计出 ,继而计算 

08 自适应版

不得不说,OpenAI 不愧为各种 Scaling Law 的先驱之一,前述分析可谓相当精彩,并且结果也相当丰富,更难得的是,整个推导过程并不复杂,给人一种大道至简的本质感。

不过,目前的结论都是基于 SGD 来推的,对于 Adam 等自适应学习率优化器的适用性还不明朗,这部分内容由《Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling》[10] 完成。

09 符号近似

分析 Adam 的思路跟 SGD 一样,都是基于二阶展开,不同的是方向向量由  换成了一般的向量 ,此时我们有:

现在需要确定  以及计算相应的  和 。由于只需要一个渐近关系,所以跟《配置不同的学习率,LoRA还能再涨一点?》一样,我们选择 SignSGD 即  作为 Adam 的近似。这个近似的合理性体现在两点:
1. 无论  取何值,Adam 第一步的更新向量都是 
2. 当  时,Adam 的更新向量始终为 

为了计算  和 ,我们还需要跟“线性缩放” [11] 一节一样,假设  服从分布 ,而为了简化计算,我们还要进一步假设  是对角阵 ,即假设分量之间是相互独立的,这样一来我们可以独立地处理每一个分量。

由重参数得  等价于 ,因此:


这里的  是误差函数 [12],它是跟  类似的值域为  的  型函数,可以作为  的光滑近似。

但  本身没有初等函数表达式,所以我们最好找一个初等函数近似,才能更直观地观察变化规律,之前我们在《GELU的两个初等函数近似是怎么来的》就讨论过这个话题,不过那里的近似还是太复杂了(都涉及到指数运算),这里我们整个简单点的:

我们选择 ,使得这个近似在  处的一阶近似跟  的一阶近似相等。当然,都做了这么多重近似了,这个 c 的值其实已经不大重要,我们只需要知道存在这么个  就行了。基于这个近似,我们得到:

可以发现,Adam 跟 SGD 的一个明显区别是  这一步就已经跟  相关了。不过好在,此时的二阶矩更简单了,因为  的平方必然是 1,所以:

利用这些结果,我们就可以求得:

10 两个特例

相比 SGD 的式(10),Adam 的式(20)更为复杂,以至于无法直观看出它对  的依赖规律,所以我们从几个特殊例子入手。

首先考虑 ,此时 ,所以:
它跟 SGD 的  的区别是它关于梯度并不是齐次的,而是正比于梯度的 scale。

接着我们考虑  是对角阵的例子,即  时 ,此时:

这里求和的每一项关于  都是单调递增有上界的,所以总的结果也是如此。为了捕捉最本质的规律,我们可以考虑进一步简化 (这里开始跟原论文不一样):

这里的假设是存在某个跟  无关的常数 【比如可以考虑取全体  的某种均值,其实这里的  类似前面的 ,按照  的定义来估计也可以】,使得对任意  来说把  换成  都是一个良好近似,于是:

当  即  时,可以进一步写出近似:

这表明在 Batch Size 本身较小时,Adam 确实适用于平方根缩放定律。

11 涌现行为

如果我们将近似(24)应用到原始的式(20),会发现它存在一些全新的特性,具体来说我们有:

其中 ,以及:

注意  是  的单调递增函数,但式(27)最后的近似并不是  的单调递增函数,它是先增后减的,最大值在  取到。这意味着存在一个相应的 ,当 Batch Size 超过这个  后,最佳学习率不应该增大反而要减小!这便是原论文标题所说的 “Surge 现象”。

当然这里还有一个限制, 是始终小于 1 的,如果 ,那么最优学习率与 Batch Size 的关系依旧是单调递增的。

关于 Adam 的 ,其实 OpenAI 在论文附录中曾不加证明地“猜测” Adam 的最优学习率应该是:

其中 。现在看来,这个形式只是 Hessian 矩阵对角线元素占主导时的近似结果,当非对角线元素的作用不可忽略时,则有可能涌现出 “Batch Size 足够大时学习率反而应该减小”的 Surge 现象。

如何直观地理解 Surge 现象呢?笔者认为,这本质上是自适应学习率策略的次优性的体现。仍以近似  为例, 越大  就越准, 则是 ,然而  是最科学的更新方向吗?

不一定,尤其是训练后期这种自适应策略可能还有负面作用。因此,当  取适当值时, 的噪声反而可能修正这种次优性,而  继续增大时噪声减少,反而减少了修正的机会,从而需要更谨慎地降低学习率。

12 效率关系

同 SGD 的分析一样,最后我们还可以考虑 ,将式(27)代入式(21),恢复记号  然后化简(化简过程不需要任何近似)得到:

其中:

注意这里  是一个新的记号,它不是 ,后者是由  反解出来的理论最优 Batch Size,结果是:

它们之间的关系是:

由于式(30)形式上跟 SGD 的式(13)是一样的,所以那一节的分析同样适用,因此同样可以导出式(14):

只不过现在 。这样一来,我们就有得到一种估计  和  的方案:通过多次实验得到多个  对,实验过程中还可以顺便估计 ,然后拟合上式得到 ,继而估计 ,最后由(31)式解出 

如果 ,那么不存在最优的 ,如果  则说明 Hessian 矩阵对角线元素占主导,此时适用于缩放规律(25),增大 Batch Size 总可以适当增大学习率;当  时,可以由(33)解出最优的 ,Batch Size 超出这个值学习率反而应该下降。

13 补充说明

需要指出的是,上面几节分析的出发点和最终结论,其实跟原论文《Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling》[10] 大同小异,但中间过程的近似处理有所不同。

原论文得到的大部分结论,都是在  假设下的近似结果,所以得到 Surge 现象几乎总会出现的结论,这其实是不大科学的。

最明显的是  这个假设的形式本身就有点问题,它右端是跟i相关的,我们总不能给每个分量都配一个单独的 Batch Size,所以为了得到一个全局的结果就只能是 ,但这未免有点苛刻了。

本文的做法则是引入近似(24),这可以看成是平均场近似,直觉上比逐点的假设  更为合理一些,所以原则上结论会更为精准,比如可以得到“即使 Hessian 矩阵的非对角线元素不可忽略,Surge 现象也不一定会出现”的结论(取决于  )。

特别地,这种精准性并没有牺牲简洁性,比如式(27)同样很简明清晰,式(30)形式也跟原论文一致,并且不需要额外的近似假设,等等。

最后,稍微感慨一下,OpenAI 对 SGD 的分析其实已经是 2018 年的工作了,而 Surge 现象这篇论文则是今年中才发布的,从 SGD 到 Adam 居然花了 6 年时间,这是让人比较意外的,大体是 OpenAI 的“威望”以及猜测(29),让大家觉得Adam已经没什么好做了,没想到 Adam 可能会有一些新的特性。

当然, 作为 Adam 的近似究竟有多合理、能多大程度上代表实际情况等问题,笔者认为还值得进一步思考。

文章小结


本文从多个视角讨论了 “Batch Size 与学习率之间的 Scaling Law” 这一经典炼丹问题,其中着重介绍了 OpenAI 基于损失函数的二阶近似的推导和结论,以及后续利用同样的思想来分析 Adam 优化器的工作。

参考文献

[1] https://papers.cool/arxiv/1404.5997

[2] https://papers.cool/arxiv/1705.08741

[3] https://papers.cool/arxiv/1706.02677

[4] https://papers.cool/arxiv/1811.01558

[5] https://papers.cool/arxiv/2006.15081

[6] https://papers.cool/arxiv/2205.10287

[7] https://papers.cool/arxiv/1904.00962

[8] https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/

[9] https://papers.cool/arxiv/1812.06162

[10] https://papers.cool/arxiv/2405.14578

[11] https://kexue.fm/archives/10542#线性缩放

[12] https://en.wikipedia.org/wiki/Error_function



编辑:王菁



关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。



新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU