Firefly
是开源的大模型一站式训练框架,支持对各种大模型进行
预训练、指令微调、DPO
,支持
全量参数、LoRA、QLoRA
等训练方式。支持包括但不限于 Gemma、Qwen1.5、MiniCPM、Mixtral-8x7B、Mistral、Llama 等绝大多数主流的大模型。
项目链接:
https://github.com/yangjianxin1/Firefly
模型权重:
https://hf.co/YeungNLP/firefly-qwen1.5-en-7b
https://hf.co/YeungNLP/firefly-qwen1.5-en-7b-dpo-v0.1
本文将分享我们使用 Firefly 项目对 Qwen1.5-7B 进行训练的实验。我们对训练数据进行
精细化筛选
,然后
在单张 V100 上进行 SFT 和 DPO
。经过两阶段的训练,我们的模型
在 Open LLM Leaderboard 上的表现显著优于官方的 Qwen1.5-7B-Chat、Gemma-7B-it、Vicuna-13B
等模型。
比
Qwen1.5-7B-Chat 高 7.12 分,比
Gemma-7B-it 高 8.8 分
。
通义千问 Qwen1.5 是阿里巴巴在春节前开源的大模型,支持 32K 的上下文长度,该模型本质上是 Qwen2 的 beta 版本,按照官方的说法,后续将会有
Qwen2 的正式版本
。从评测结果来看,
Q
wen1.5 各个尺寸的模型都显著优于同量级的 Llama2。
在 2 月份的 SuperCLUE 大模型榜单中,Qwen1.5 也有非常优秀的表现,在开源模型中处于引领者的地位。
大模型训练主要可以分为以下三大阶段:
-
预训练
: 使用超大规模文本对模型进行训练,训练任务为“预测下一个 token”,训练的数据量往往需要几万亿 token。
-
SFT (指令微调)
: 使用指令数据,让模型的输出格式与人类对齐,使其具备 chat 的能力。
-
RLHF
: 使用人类反馈或者偏好数据来训练模型,使模型的输出更加符合人类的价值观或者预期行为。
在 RLHF 阶段,以往的许多大模型,例如
Llama2、
InstructGPT 等,大多采用 PPO 来对模型进行价值观对齐训练。但是采用 PPO 进行 RLHF 存在流程繁琐、显存需求多(需要将策略网络、参考网络、critic 网络、奖励模型同时加载到显存中)等问题,这导致大部分普通玩家对其敬而远之。
使用 PPO 进行 RLHF 的主要流程大致如下:
-
构建奖励模型的训练数据
: 对于同一个 prompt 产生多个生成结果,对这些生成结果进行人工排序,两两一组,形成 chosen 和 rejected 的 pair。每条训练数据包含三个字段,prompt、chosen、rejected。
-
训练奖励模型
:
使用上述数据训练奖励模型,对于每条训练数据,训练目标为最大化 chosen 与 rejected 的奖励的差值。
-
PPO 训练
: 使用奖励模型的反馈对语言模型进行训练。
上面描述的 PPO 流程复杂且冗长,而 DPO 则绕过了奖励模型的构建,可直接使用人类偏好数据对模型进行训练,且在训练时仅需加载策略网络和参考网络,极大地节省了显存占用。训练数据包含
三个字段,prompt、chosen、rejected。
DPO 损失函数的计算过程也极具对称性,其公式如下所示:
对于上述公式,根据对数运算法则进行变换,在代码实现中,其计算过程大致如下:
-
计算对数概率:将 prompt 分别与 chosen 和 rejected 进行拼接,然后分别输入策略网络和参考网络,得到 4 个对数概率。
-
计算策略网络的 diff:策略网络的 chosen 对数概率 - rejected
对数概率。
-
计算参考网络的 diff:
参考网络的
chosen 对数概率
-
rejected
对数概率。
-
计算损失函数:
策略网络的 diff -
参考网络的 diff。
在 Qwen1.5-7B 的基础上,我们进行了 SFT 和 DPO 两阶段的训练,
整个训练流程仅使用一张 V100 GPU
,采用 QLoRA 技术,在所有 Linear 层都添加 adapter 以提升训练效果。
两阶段均使用英文数据进行训练
。我们与 Qwen1.5 官方的对话模板保持一致:
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
hello, who are you?<|im_end|>
<|im_start|>assistant
I am a AI program developed by Firefly<|im_end|>
使用 Firefly 对 Qwen1.5 进行 SFT 的启动命令:
python train.py --train_args_file train_args/sft/qlora/qwen1.5-7b-sft-qlora.json
在 SFT 阶段,
实验参数设置如下:
num_epochs: 1
learning_rate: 2e-4
total_train_batch_size: 32
max_seq_length: 2048