channel_loss = {} for step, batch in enumerate(train_dataloader): batch = to_device(batch, device) channel = batch['channel'][0]
del batch['channel'] outputs = model(**batch) loss = outputs.loss
# Update channel loss if channel in channel_loss: channel_loss[channel][0] += loss.item() channel_loss[channel][1] += 1 else: channel_loss[channel] = [loss.item(), 1]
all_channel_loss = [None for _ in range(world_size)] torch.distributed.all_gather_object(all_channel_loss, channel_loss)
merged_channel_loss = {} for lst in all_channel_loss: for k, v in lst.items(): if k in merged_channel_loss: merged_channel_loss[k][0] += v[0] merged_channel_loss[k][1] += v[1] else: merged_channel_loss[k] = [v[0], v[1]]
for k,v in merged_channel_loss.items(): avg_loss = v[0] / v[1] if v[1] != 0 else 0.0 print_rank_0("The Channel {} loss is {}".format(k, avg_loss), args.global_rank)
# Log channel loss to TensorBoard if dist.get_rank() == 0: writer.add_scalar(f'Loss/channel_{k}', avg_loss, epoch * num_batches + step)
敲定 warmup 的数据比例后,选择一个顺眼的学习率和数据配比,就去开始训练和观察 channel loss 吧,在最理想情况下,我们期待得到一个这样的曲线:
domain_channel 的 loss 明显下降(新知识好学)
common_channel 的 loss 基本持平,极缓慢下降(理论上会选用作为底座的 model,通用能力已经很强了,这时候很难再让他的通用能力再进步一提升了,上文提到过 Qwen2 多训了 5T 通用数据但毫无收益)
结合 loss 曲线,我们再回过头来谈谈数据配比:post-pretrain 阶段最好的数据配比,就是沿用 pretrain 阶段的数据配比,很可惜,我们不可能获取到 Qwen、Llama 的 pretrain数据。因此,我们也别纠结数据去重了,大概率我们使用的 common 数据是人家已经训过的,我们尽可能去找质量最高的 common 数据喂给模型就可以了。
不过从 channel loss 上,我们大概率能观察和反推一些东西:
初始 loss 低:任务简单,或者模型已经训过这份数据。如果你使用的底座模型效果巨强,比如是 Qwen2-72B,Llama3-70B,你甚至可以断言这个数据的质量很高(能力差的小模型不能随便下定论)。当然,loss 低也有可能存在一种情况,那就是数据十分的脏,全都是重复 token 或者 固定 pattern;
初始 loss 高:好现象,说明模型没有见过这个数据。但也有数据质量很差的风险,最好再清洗下这个数据源;
loss 持平或缓慢下降:好现象,没有比这更好的现象了,基本就是我们蒙对了底座模型 pretrain 阶段使用的数据配比才会有的现象;
loss 快速下降:说明这个数据很容易学习,有可能是 domain 数据的特点比较显著,也有可能是数据比较脏,都是固定 pattern 或者具有明显的格式(提一句,Llama 说任何 markdown 数据都对模型性能有损失,所以有明显格式的数据要慎重使用);
common channel loss 下降明显:你的 common 数据显然不够 common,它相对模型来说有可能更像是 domain 数据,说明当前数据配比和 pretrain 的配比偏离有点远;
domain channel loss 下降明显:好事,鼓掌欢呼;
domain channel loss 不下降:初始 loss 低说明模型大概率已经训过这份 domain 数据了,初始 loss 高还不下降,可能是数据不够干净,也可能是数据比较难学,再多训会吧;
这篇 domain scaling law 的论文明确指出“domain能力“和”general 能力“是相互冲突的,也就回归到了我一开始说的:我们的目标不是提高通用能力,而是去损失尽量少的通用能力。
D-CPT:https://arxiv.org/pdf/2406.01375
D-CPT
这篇论文的结论都是比较 make sense 的:
小学习率,domain 学得快,通用忘得慢;
大学习率,domain 学得快,但到一定地步后就震荡,毕竟学习能力有限;
不同 size 的模型适合不同的学习率。
文章再多的内容我就不谈了,感兴趣的读者自己拜读一下即可,scaling law 的文章都相对晦涩一些,我还没有完全读懂,不敢班门弄斧。我引用这篇 sacaling law 论文的主要原因是,一是讴歌一下做 scaling law 的大佬们,二是想表达“学习率真的很重要”这一观点,不要因为大家都在强调数据质量的重要性,就忽略了炼丹的老本行。