23年12月UT Austin的论文“Early Weight Averaging meets High Learning Rates for LLM Pre-training”。
训练大语言模型(LLM)会产生巨大的成本;因此,任何加速模型收敛的策略都是有帮助的。本文研究了一个简单的想法-沿训练运行轨迹进行
检查点平均
-在训练过程中尽早提高收敛和泛化的能力。以高学习率训练的模型由于检查点平均观察到更高的增益。此外,在训练步骤中相当大的间隔采样检查点时,这些增益会被放大。这种训练方法优于传统训练和流行的检查点平均基线,例如指数移动平均 (EMA) 和随机移动平均 (SWA)。通过预训练LLMs来评估训练方案,由于一次批处理数量极大,高学习率本质上就是首选。具体来说,由 9B 个 tokens 组成的 OpenWebText 数据集上预训练不同大小的 nanoGPT-2 模型——小型 (125M)、中型 (335M) 和大型 (770M)。此外,还提供了公开可用的 Pythia LLM 的结果,范围从 1B 到 12B,这些结果在包含 207B tokens的 PILE-删除重复数据集上进行训练。代码可如下获取
github.com/sanyalsunny111/Early_Weight_Avg
建议在训练期间相对较早地以高学习率 (η) 执行模型权重的检查点平均。这一步骤背后的基本原理源于这样一个事实:检查点平均可以作为学习率(LR)衰减的替代,正如 Sandler 所证明的那样[33]。然而,这种替代 LR 衰减与优化过程中的权重更新无关,因为检查点平均是以事后方式进行的。利用这种简单的技术,在快速遍历 w2 的同时减轻了 w1 中的振荡,从而以更少的训练步骤实现增强的泛化,如图最近权重平均(LAWA)所示。
模型检查点权重平均的做法,广泛认为是在功能上类似于集成 [12, 40]。
在模型集成文献中,众所周知,多样的模型可以提高集成的性能[19]。
因此,可以合理地假设该原则也适用于模型平均。
定义一个模型在两个不同训练步骤中的多样性,即计算两个检查点之间的分歧量。
该方程计算来自同一支持集的样本数,该支持集中检查点 W1 与检查点 W2 二者的表现不一致。
Athiwaratkun 最近的一项研究。
已经证明较高的 LR 可以生成多样的模型检查点。
在训练步骤中对相距较远的检查点进行采样,可以进一步放大这种现象。
将这两种见解结合起来,可以在检查点中引入多样性。
LAWA 维护一个周期性采样检查点的先进先出 (FIFO) 队列,在两个连续样本之间有大量的干预步 (ν)。对 LAWA 进行细微的修改适应实验设置。具体来说,引入了 k_stepsize (ν) ,解耦间隔和 k,在训练运行中有效地采样远处的检查点。其LAWA算法Python风格伪代码如下:
本文进行所有的实验都利用自回归解码器-方式的大语言模型 (LLM),特别是 nanoGPT-2 和 Pythia LLM。使用三种不同尺寸的 nanoGPT-2 模型:小型 (125M)、中型 (355M) 和大型 (770M)。用 OpenWebText 数据集从头开始训练 nanoGPT-2 模型,其中包括 90 亿个训练tokens和 440 万个验证tokens。在整个实验过程中,保持一致的序列长度 1024 和每批次 131K 个tokens的固定批次大小,后者是GPU 容纳的最大批次大小。模型和预训练的配置改编自 Sophia 的 [23] AdamW 基线,并对学习率和批量大小进行了调整,以满足特定需求。值得注意的是,与[23]中的配置相比,训练所有模型的学习率高出十倍,批量大小高出两倍,其中学习率是通过网格搜索(grid search)调整的。将 LAWA 与原始预训练方法、EMA [35] 和 SWA [12] 进行比较,将其改编为LLMs。对于 EMA,根据 [14] 将衰减设置为 0.9,并在每一步更新 EMA 模型,这是标准做法。对于 SWA,坚持原来的预训练程序直到完成 75%之后,使用新的 SWA 调度程序(余弦退火)启动 SWA 训练。每 10 步计算一次 SWA 均匀的平均值。