专栏名称: Datawhale
一个专注于AI领域的开源组织,汇聚了众多顶尖院校和知名企业的优秀学习者,聚集了一群有开源精神和探索精神的团队成员。愿景-for the learner,和学习者一起成长。
目录
相关文章推荐
51好读  ›  专栏  ›  Datawhale

DeepSeek关键RL算法GRPO,手把手教你从头跑通!

Datawhale  · 公众号  ·  · 2025-03-02 16:11

正文

Datawhale分享

作者:Andriy Burkov,编译:机器之心

GRPO(Group Relative Policy Optimization)是 DeepSeek-R1 成功的基础技术之一

简单来说,GRPO 算法丢弃了 critic model,放弃了价值函数近似,转而通过组内样本的相对比较来计算策略梯度,从而有效降低了训练的不稳定性,同时提高了学习效率。

既然 GRPO 如此有效,那么,你知道如何从头开始实现 GRPO 吗?

近日,AI 工程师和技术作家 Andriy Burkov 发布了一份「从头开始写 GRPO 代码」的教程,其中介绍了如何基于 Qwen2.5-1.5B-Instruct 模型构建一个使用 GRPO 的分布式强化学习流程。

不过,在我们深入这份教程之前,先简单介绍一下它的作者。Andriy Burkov 算得上是 AI 领域的一位著名科普作家,在加拿大拉瓦尔大学取得了计算机科学博士学位,还曾发表过两本颇受欢迎的 AI 主题著作:《100 页语言模型书》和《100 页机器学习书》;书中一步步详实地介绍了相关概念,并附带了简明的实现代码。

接下来我们就来看看这份 GRPO 从头实现教程吧。

image.png

教程地址:https://github.com/aburkov/theLMbook/blob/main/GRPO_From_Scratch_Multi_GPU_DataParallel_Qwen_2_5_1_5B_Instruct.ipynb

从头编写 GRPO 代码
使用 Qwen2.5-1.5B-Instruct 的分布式实现

本教程将展示如何使用 GRPO 方法构建分布式强化学习(RL)流程,从而可以针对数学、逻辑和编程任务对语言模型进行微调。

首先需要明确,这些任务都存在一个唯一且正确的 ground truth 答案,可通过简单的字符串比较轻松加以验证。

GRPO 的发明者是 DeepSeek,最早是被用于微调 DeepSeek 的 R1 和 R1-Zero 模型 —— 它们可通过学习生成思维链(CoT)来更好地解决数学和逻辑问题。

本教程的目标是将通用语言模型 Qwen2.5-1.5B-Instruct 转换为数学问题求解器。我们将从头开始编写 GRPO 代码,然后将其与几个流行的库和工具集成起来,以实现分布式训练管道流程,包括:

  • PyTorch:用于张量运算和分布式训练。
  • Hugging Face Transformers:用于加载预训练的语言模型和 tokenizer。
  • FlashAttention2:优化的注意力机制,有助于减少内存使用量并提高训练速度。
  • Weights & Biases (wandb):用于实验跟踪、可视化和模型版本控制。

本教程分为几个部分。首先是基本设置和导入,然后是数据格式化和答案提取、数据集准备、评估函数、奖励函数、训练设置和执行,最后加载和测试模型。此过程中,我们将从头实现 GRPO 算法。

Part 1:基础设置和导入

首先是安装并导入所有必要的模块。下面是导入库的一段代码截图。

image.png
部分代码截图。完整代码块参见 GitHub。

运行上述代码(参考项目完整代码),可以执行以下任务:

  • 设置随机种子:set_random_seed 函数通过为 Python 的随机模块、NumPy 和 PyTorch 设置种子,确保可复现性;
  • 环境变量配置:设置 WANDB_API_KEY 和 WANDB_PROJECT 环境变量,以启用与 Weights & Biases 的实验跟踪;
  • 导入必要的库,包括 random、copy、re、torch 等等。

Part 2:数据格式以及答案提取

接下来,项目定义了数据格式,以及模型如何从输出和数据集中提取答案段落。

为了确保模型输出格式一致,项目还定义了一个系统提示。该提示指示模型生成包含 < reasoning > 和 < answer > 标签的输出。这一步通过两个函数完成:

  • extract_answer_from_model_output:此函数获取模型的输出文本,并提取 < answer > 标签内的内容;
  • extract_answer_from_dataset:此函数从 GSM8K 数据集中提取预期答案,该数据集使用 “####” 分隔符来分隔答案:

image.png
部分代码截图。完整代码块参见 GitHub。

Part 3:数据准备

该项目使用 GSM8K 数据集进行训练。项目使用了该数据集中的示例来训练模型,基于强化学习(RL)训练范式,让模型生成多个问题解答样本,之后作者将这些解答与 GSM8K 示例中的标准答案进行对比,如果匹配,就为 RL 算法(GRPO)提供高奖励,然后更新模型权重,以增加模型下次获得高奖励的可能性。

实验过程是这样的。首先从 Hugging Face 加载数据集,然后格式化每个示例,包括系统提示和用户提示。这段实现代码中还定义了两个辅助函数:prepare_dataset 以及 build_prompt。

image.png
部分代码截图。完整代码块参见 GitHub。

Part 4:评估函数

评估对于跟踪模型的进展至关重要。因此作者定义了一些函数,从而可以在一组示例上对模型进行评估。该项目的评估函数执行以下任务:

  • token 化提示并生成响应:模型的输出是在 token 化提示的基础上生成的。
  • 提取预测答案:从生成的响应中提取答案。
  • 将预测答案与预期答案进行比较:这种比较是通过精确匹配以及数值等价检查来完成的。

在这段代码中,两个辅助函数 _extract_last_number 和 _extract_single_number 被用来从文本中提取数字。评估函数 evaluate_model 使用这些辅助函数来确定预测答案是否正确:

image.png
部分代码截图。完整代码块参见 GitHub。

Part 5:奖励函数

在强化学习中,奖励函数是必不可缺的,作者定义了两个奖励函数:

correctness_reward:这个函数根据生成的答案是否正确来分配奖励。采用两种方式:精确的字符串匹配和数值等价检查,将模型输出的答案与预期答案进行比较。完全匹配会获得更高的奖励(2.0),而基于数值等价的匹配会获得较小的奖励(1.5)。

format_reward:这个函数鼓励模型遵循所需的类似 XML 的输出格式。它为生成文本中存在 < reasoning>、、 标签提供小额奖励。






请到「今天看啥」查看全文