专栏名称: Hugging Face
The AI community building the future.
目录
相关文章推荐
OSC开源社区  ·  Bun ... ·  昨天  
程序猿  ·  41岁DeepMind天才科学家去世:长期受 ... ·  昨天  
OSC开源社区  ·  2024: 大模型背景下知识图谱的理性回归 ·  4 天前  
OSC开源社区  ·  升级到Svelte ... ·  5 天前  
51好读  ›  专栏  ›  Hugging Face

社区供稿 | Google Gemma 2B 微调实战(IT 科技新闻标题生成)

Hugging Face  · 公众号  ·  · 2024-04-03 21:00

正文

Google Gemma 2B 微调实战(IT科技新闻标题生成)

本文我将使用 Google 的 Gemma-2b 模型来微调一个基于IT科技新闻正文来生成对应标题的模型。并且我将介绍如何使用高度集成的训练框架来进行快速微调。

开始前

为了尽可能简化整个流程,我将使用 linux-cn 数据集 [1] 作为本次训练任务的训练数据。

模型选择使用 Gemma-2b [2] ,在目前这个任务中 2b 级别的参数模型已经完全能满足当前的需求,当然你也可以尝试使用 7b 的模型。

我们在这里将直接使用 LLaMA-Factory [3] 训练框架来直接完成监督微调部分工作。当然该框架不仅支持监督微调(SFT)也支持预训练(PT)、奖励模型(RM)以及 PPO/DPO 的训练。

数据整理

linux-cn 数据集本身已经进行了数据的清洗和格式化,这一步我们只需要把我们需要的字段提取出后来后根据一定格式转换为 LLaMA-Factory 监督微调格式即可。

在本任务中,我们只需要数据集中的“title”和“content”两个字段即可。而 LLaMA-Factory 监督微调格式是如下格式的json文件。

[
  {
    "instruction""What are the three primary colors?",
    "input""",
    "output""The three primary colors are red, blue, and yellow. These colors are called primary because they cannot be created by mixing other colors and all other colors can be made by combining them in various proportions. In the additive color system, used for light, the primary colors are red, green, and blue (RGB).",
  },
...
]

因为我们选择使用的是预训练模型,所以我们还需要指定一个 prompt template。指定 prompt template 的一个好处是你如果希望同时训练多个不同类型的任务,这样可以保证不同任务之间不会相互干扰。

完整代码如下:

import json

result = []

prompt_template = """Generate a title for the article:

{content}

---
Title:
"""

with open('archve.jsonl''r'as f:
    for line in f:
        p = json.loads(line)
        result.append({
            "instruction": prompt_template.replace("{content}", p['content']),
            "input""",
            "output": p['title']
        })

with open('itnews_data.json''w'as f:
    json.dump(result, f,ensure_ascii=False, indent=4)

完成这一步后,我们就可以开始训练我们的模型了。但往往耗费时间最长以及最头疼的也是数据收集和数据整理这一部分。

模型微调

首先你需要保证 LLaMA-Factory 框架已经在你本地已经 ready 了。即你已经下载了该项目并且已经进行了项目的安装。

具体如何安装你可以查看该项目的 README,本文不再过多赘述。

首先我们需要将数据集移动到框架的 data 目录中,然后在 dataset_info.json 中添加我们自定义的数据集。

以下是本文实例所添加的数据集信息:

  "itnews": {
    "file_name""itnews_data.json",
  },

当然不同类型的任务该框架会有不同的数据集格式要求,你可以参考项目中 dataset_info.json README [4]

然后我们只需要执行如下命令就可以开始微调了,本文是在单张A100(80G)上进行的微调。

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train True \
    --model_name_or_path google/gemma-2b \
    --finetuning_type lora \
    --template default \
    --dataset itnews \
    --use_unsloth \
    --cutoff_len 8192 \
    --learning_rate 5e-05 \
    --num_train_epochs 10.0 \
    --max_samples 10000 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --warmup_steps 0 \
    --output_dir saves/Gemma-2B/lora/train_v1 \
    --bf16 True \
    --lora_rank 8 \
    --lora_dropout 0.1 \
    --lora_target q_proj,v_proj \
    --val_size 0.1 \
    --load_best_model_at_end True \
    --plot_loss True \
    --report_to "tensorboard"

在这里我需要对其中的几个参数进行简短的介绍:

--stage 即任务类型,在这里我们本文做的是监督微调所以是 sft,如果是其他任务你需要指定不同的类型。

--dataset 即数据集,这里的名称就是我们在 dataset_info.json 文件中指定的数据集名称。

--use_unsloth 这是一个训练加速器,官方宣称在 Gemma 7b 上拥有 2.4x 的加速,并且节省超一半的显存。在使用这个之前你需要按照 官方文档 [5] 进行安装。

--cutoff_len 文本令牌化后输入到模型的截止长度,因为本文使用的 Gemma 2b 模型,它的最大长度是 8192 ,所以在这里我设置的是 8192。但请记住更长的上下文也需要更多的 GPU 显存!

--max_samples 设置数据集加载的最大条数。本参数主要用作调试目的时非常好用,尤其是在你不确定 cutoff_len batch_size 的时候,你可以加载很小的一部分数据进行测试,然后查看你显存的使用情况。

--learning_rate --num_train_epochs 学习率和训练周期,这是一个经验值,一般通过查看模型的 loss 来调整,当然在 LLM 模型训练中,本参数主要以模型是否符合任务需求而决定,也就是说完美的 loss 可能并不满足需求。

--per_device_train_batch_size --per_device_eval_batch_size --gradient_accumulation_steps 这三个参数需要根据你的显存大小以及是否使用多个GPU等条件进行不同的调整。

--output_dir 模型保存的目录。

更多的参数解释可以查看 项目说明 [6] ,以及 transformers Trainer 说明 [7]

模型使用

在这里我们可以直接使用 transformers 来执行。

from transformers import AutoModelForCausalLM, AutoTokenizer,BitsAndBytesConfig

peft_model_id = "checkpoint-2000"
model = AutoModelForCausalLM.from_pretrained(peft_model_id,device_map="cuda")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

input_text = """
Generate a title for the article:

{content}

---
Title:
"""
 # 固定格式
encoding = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**encoding,max_length=8192,temperature=0.2,do_sample=True)
generated_ids = outputs[:, encoding.input_ids.shape[1]:]
generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_texts[0])

我通过使用我自己的一篇差不多 5000 tokens 关于 微服务的文章 [8]







请到「今天看啥」查看全文