专栏名称: Hugging Face
The AI community building the future.
目录
相关文章推荐
BioArt  ·  【DeepSeek专栏】Nat ... ·  昨天  
BioArt  ·  Nat Cell Biol | ... ·  3 天前  
BioArt  ·  Cell Metab | ... ·  4 天前  
51好读  ›  专栏  ›  Hugging Face

社区供稿 | 使用 Firefly 在单卡V100 上对 Qwen1.5 进行 SFT 和 DPO,显著超越官方模型

Hugging Face  · 公众号  ·  · 2024-03-08 22:29

正文

01

简介

Firefly 是开源的大模型一站式训练框架,支持对各种大模型进行 预训练、指令微调、DPO ,支持 全量参数、LoRA、QLoRA 等训练方式。支持包括但不限于 Gemma、Qwen1.5、MiniCPM、Mixtral-8x7B、Mistral、Llama 等绝大多数主流的大模型。


项目链接: https://github.com/yangjianxin1/Firefly


模型权重:

https://hf.co/YeungNLP/firefly-qwen1.5-en-7b

https://hf.co/YeungNLP/firefly-qwen1.5-en-7b-dpo-v0.1


本文将分享我们使用 Firefly 项目对 Qwen1.5-7B 进行训练的实验。我们对训练数据进行 精细化筛选 ,然后 在单张 V100 上进行 SFT 和 DPO 。经过两阶段的训练,我们的模型 在 Open LLM Leaderboard 上的表现显著优于官方的 Qwen1.5-7B-Chat、Gemma-7B-it、Vicuna-13B 等模型。 Qwen1.5-7B-Chat 高 7.12 分,比 Gemma-7B-it 高 8.8 分


通义千问 Qwen1.5 是阿里巴巴在春节前开源的大模型,支持 32K 的上下文长度,该模型本质上是 Qwen2 的 beta 版本,按照官方的说法,后续将会有 Qwen2 的正式版本 。从评测结果来看, Q wen1.5 各个尺寸的模型都显著优于同量级的 Llama2。

在 2 月份的 SuperCLUE 大模型榜单中,Qwen1.5 也有非常优秀的表现,在开源模型中处于引领者的地位。

02

DPO 简介

大模型训练主要可以分为以下三大阶段:

  1. 预训练 : 使用超大规模文本对模型进行训练,训练任务为“预测下一个 token”,训练的数据量往往需要几万亿 token。

  2. SFT (指令微调) : 使用指令数据,让模型的输出格式与人类对齐,使其具备 chat 的能力。

  3. RLHF : 使用人类反馈或者偏好数据来训练模型,使模型的输出更加符合人类的价值观或者预期行为。


在 RLHF 阶段,以往的许多大模型,例如 Llama2、 InstructGPT 等,大多采用 PPO 来对模型进行价值观对齐训练。但是采用 PPO 进行 RLHF 存在流程繁琐、显存需求多(需要将策略网络、参考网络、critic 网络、奖励模型同时加载到显存中)等问题,这导致大部分普通玩家对其敬而远之。


使用 PPO 进行 RLHF 的主要流程大致如下:

  1. 构建奖励模型的训练数据 : 对于同一个 prompt 产生多个生成结果,对这些生成结果进行人工排序,两两一组,形成 chosen 和 rejected 的 pair。每条训练数据包含三个字段,prompt、chosen、rejected。

  2. 训练奖励模型 : 使用上述数据训练奖励模型,对于每条训练数据,训练目标为最大化 chosen 与 rejected 的奖励的差值。

  3. PPO 训练 : 使用奖励模型的反馈对语言模型进行训练。


上面描述的 PPO 流程复杂且冗长,而 DPO 则绕过了奖励模型的构建,可直接使用人类偏好数据对模型进行训练,且在训练时仅需加载策略网络和参考网络,极大地节省了显存占用。训练数据包含 三个字段,prompt、chosen、rejected。


DPO 损失函数的计算过程也极具对称性,其公式如下所示:

对于上述公式,根据对数运算法则进行变换,在代码实现中,其计算过程大致如下:

  1. 计算对数概率:将 prompt 分别与 chosen 和 rejected 进行拼接,然后分别输入策略网络和参考网络,得到 4 个对数概率。

  2. 计算策略网络的 diff:策略网络的 chosen 对数概率 - rejected 对数概率。

  3. 计算参考网络的 diff: 参考网络的 chosen 对数概率 - rejected 对数概率。

  4. 计算损失函数: 策略网络的 diff - 参考网络的 diff。

03

训练设置

在 Qwen1.5-7B 的基础上,我们进行了 SFT 和 DPO 两阶段的训练, 整个训练流程仅使用一张 V100 GPU ,采用 QLoRA 技术,在所有 Linear 层都添加 adapter 以提升训练效果。 两阶段均使用英文数据进行训练 。我们与 Qwen1.5 官方的对话模板保持一致:

<|im_start|>systemYou are a helpful assistant.<|im_end|><|im_start|>userhello, who are you?<|im_end|><|im_start|>assistantI am a AI program developed by Firefly<|im_end|>

使用 Firefly 对 Qwen1.5 进行 SFT 的启动命令:

python train.py --train_args_file train_args/sft/qlora/qwen1.5-7b-sft-qlora.json

在 SFT 阶段, 实验参数设置如下:

num_epochs: 1learning_rate: 2e-4total_train_batch_size: 32max_seq_length: 2048






请到「今天看啥」查看全文