本文描述DeepSeek的三个模型的学习过程,其中DeepSeek-R1-Zero模型所涉及的强化学习算法,是DeepSeek最核心的部分之一会重点展示。
随着DeepSeek的火爆使用,其背后的训练技术也值得深入学习,整体DeepSeek相关的训练过程如下图所示。
其中主要涉及以下三个模型,其中DeepSeek-R1-Zero模型所涉及的强化学习算法,是DeepSeek最核心的部分之一,本次我们主要重现的也是这个部分。
是在基础模型DeepSeek-V3上进行强化学习(RL)后得到了DeepSeek-R1-Zero模型。该模型学会了如何推理、创建思维链序列,并具备自我验证和反思等能力。尽管DeepSeek-R1-Zero的学习能力令人惊叹,但它存在语言混合、可读性差等严重问题。
首先使用数千个思维链(CoT)序列示例形式的冷启动数据,在DeepSeek-V3上进行监督微调(SFT),目的是为强化学习创建一个更稳定的起点,解决DeepSeek-R1-Zero存在的问题。接着进行强化学习,并设置奖励机制,以促进语言一致性,增强在科学、编码和数学等任务上的推理能力。然后,再次进行监督微调,这次加入了非推理重点的训练示例,帮助模型保留写作、角色扮演等更多通用能力。最后,再次进行强化学习,以更好地符合人类偏好。最终得到了一个拥有6710亿参数的高性能模型。
他们基于Qwen和Llama架构,对参数在15亿 - 700亿之间的较小模型进行微调,得到了一组更轻量、更高效且推理能力更强的模型。这极大地提高了开发人员的可及性,因为许多提炼后的模型可以在他们的设备上快速运行。
强化学习(TRL):主要采用了huggingface提供的grpo_trainer方案(参考链接:https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb)
数据集:主要通过数据集gsm8k进行训练
GPU: 单张A10,显存24G
模型:Qwen2.5-0.5B-Instruct
# 基于目前最新的vllm 0.7.2进行验证
pip install vllm -U
# 基于目前最新的trl 0.15.1进行验证
pip install trl -U
import re
import torch
from modelscope import AutoTokenizer, AutoModelForCausalLM
from modelscope.msdatasets import MsDataset
from trl import GRPOConfig, GRPOTrainer
SYSTEM_PROMPT = """
You need to answer in XML format, include and , respond in the following format:
...
...
"""
XML_COT_FORMAT = """\
{reasoning}
{answer}
"""
def extract_xml_answer(text: str) -> str:
answer = text.split("")[-1]
answer = answer.split("")[0]
return answer.strip()
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip()
def get_gsm8k_questions(split="train") -> MsDataset:
data = MsDataset.load('modelscope/gsm8k', subset_name='main', split=split)
data = data.map(lambda x: {
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': extract_hash_answer(x['answer'])
})
return data
dataset = get_gsm8k_questions()
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
q = prompts[0][-1]['content']
extracted_responses = [extract_xml_answer(r) for r in responses]
print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}",
f"\nExtracted:\n{extracted_responses[0]}")
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
def int_reward_func(completions, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
# def strict_format_reward_func(completions, **kwargs) -> list[float]:
# pattern = r"\n\n.*?\n\n\n.*?\n\n$"
# responses = [completion[0]["content"] for completion in completions]
# matches = [re.fullmatch(pattern, r, re.DOTALL) for r in responses]
# return [0.5 if match else 0.0 for match in matches]
def strict_format_reward_func(completions, **kwargs) -> list[float]:
pattern = r"\n.*?\n\n\n.*?\n"
responses = [completion[0]["content"] for completion in completions]
# 新增调试日志
matches = []
for idx, r in enumerate(responses):
print(f"\n--- Processing response {idx} ---")
print("Raw content:", repr(r)) # 使用 repr() 显示转义字符
match = re.fullmatch(pattern, r, re.DOTALL)
print("Match result:", bool(match))
matches.append(match)
return [0.5 if match else 0.0 for match in matches]
def soft_format_reward_func(completions, **kwargs) -> list[float]:
pattern = r".*?\s*.*?"
responses = [completion[0]["content"] for completion in completions]
matches = [re.fullmatch(pattern, r, re.DOTALL) for r in responses]
return [0.5 if match else 0.0 for match in matches]
def count_xml(text) -> float:
count = 0.0
if text.count("\n") == 1:
count += 0.125
if text.count("\n\n") == 1:
count += 0.125
if text.count("\n\n") == 1:
count += 0.125
count -= len(text.split("\n\n")[-1]) * 0.001
if text.count("\n") == 1:
count += 0.125
count -= (len(text.split("\n")[-1]) - 1) * 0.001
return count
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
output_dir = "outputs/Qwen-0.5B-GRPO"
run_name = "Qwen-0.5B-GRPO-gsm8k"
training_args = GRPOConfig(
output_dir=output_dir,
run_name=run_name,
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type='cosine',
logging_steps=1,
bf16=True,
per_device_train_batch_size=8,
gradient_accumulation_steps=4,
num_generations=8,
max_prompt_length=256,
max_completion_length=200,
num_train_epochs=1,
save_steps=100,
max_grad_norm=0.1,
log_on_each_node=False,
use_vllm=True,
vllm_gpu_memory_utilization=.2,
vllm_device="cuda:0",
report_to="none"
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=None
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
int_reward_func,
correctness_reward_func],
args=training_args,
train_dataset=dataset,
)
trainer.train()
如上面代码所示,主要涉及以下5个奖励函数
4.1. correctness_reward_func(正确性奖励函数)
检查模型的输出是否与参考答案 (answer) 完全匹配,匹配则奖励 2.0,否则 0.0。
4.2. int_reward_func(整数检测奖励函数)
检查模型输出是否是纯数字(整数),是则奖励 0.5,否则 0.0。
4.3. strict_format_reward_func(严格格式奖励函数)
严格格式奖励,必须完全匹配
...
...
,包括其中的换行符,都必须满足格式,如果符合格式的奖励
0.5
,否则
0.0
。
4.4. soft_format_reward_func(宽松格式奖励函数)
允许更灵活的格式,只要包含
...
和
...
,即奖励 0.5,对比严格模式更加宽松
4.5. count_xml,xmlcount_reward_func(XML 结构评分函数)
计算模型输出 XML 结构的完整度,并给予相应奖励。奖励规则:
检查 XML 结构完整度:
每个正确的标签匹配增加 0.125 奖励:
\\n:+0.125
\\n:+0.125
\\n:+0.125
:+0.125
考虑额外文本的惩罚:
如果 后面有多余的内容,则减少奖励 0.001 × 额外字符数
核心参数说明如下:
1.
gradient_accumulation_steps=4:
每进行4次的前向传播和反向传播后,才会执行一次权重更新;
2.
max_completion_length=200: 表示限制模型返回最大长度200;
3.
save_steps=100:
表示每运行100步才保存一次checkpoint;
gsm8k数据集一共接近8000条数据,每4次会更新一次,则需要更新2000次,每100步保存一次,则需要生成20个checkpoint。
通过python train.py > train.log运行代码,通过tail -f train.log进行实时日志查看,最后整体效果如下图所示,最后有效数据1868个,运行时间是2:25:25。
GRPO Trainer会记录很多训练过程中的指标,主要包括在:
其中我们主要关注以下两个奖励指标:
-
准确性奖励:
基于响应的正确性(对应correctness_reward_func)
-
格式奖励: