专栏名称: 学姐带你玩AI
这里有人工智能前沿信息、算法技术交流、机器学习/深度学习经验分享、AI大赛解析、大厂大咖算法面试分享、人工智能论文技巧、AI环境工具库教程等……学姐带你玩转AI!
目录
相关文章推荐
中国证券报  ·  实探 | 宇树机器人带着机器狗,来了! ·  昨天  
上海证券报  ·  推动旅游业发展,上海最新部署 ·  2 天前  
上海证券报  ·  600863,重组,今起复牌 ·  3 天前  
中国证券报  ·  华为大动作!成立新公司 ·  3 天前  
上海证券报  ·  占比上升!人民币大消息 ·  3 天前  
51好读  ›  专栏  ›  学姐带你玩AI

使用Kaggle Docs数据微调Gemma模型

学姐带你玩AI  · 公众号  ·  · 2024-05-01 12:51

正文

来源:投稿  作者:echo
编辑:学姐

Introduction

This notebook will demonstrate three things:

  • How to fine-tune Gemma model using LoRA
  • Creation of a specialised class to query about Kaggle features
  • Some results of querying about Kaggle Docs

什么是Gemma

Gemma是一个轻量级的源生成人工智能模型集合,主要供开发人员和研究人员使用。Gemma由谷歌DeepMind研究实验室创建,该实验室也开发了Gemini, Gemma有几个版本,具有2B和7B参数,如下所示:

什么是LoRA?

LoRA是Low-Rank Adaptation的缩写。它是一种通过冻结大语言模型的权重并注入可训练的秩分解矩阵来微调大语言模型的方法。因此,微调过程中可训练参数的数量将大大减少。根据LoRA论文,这个数字减少了10,000倍,计算资源大小减少了3倍。

要对LoRA进行微调,我们将遵循以下步骤:

  1. 安装库
  2. 加载并处理数据以进行微调
  3. 初始化Gemma因果语言模型(Gemma causal LM)的代码
  4. 进行微调
  5. 使用用于微调的数据中的问题和其他问题测试微调模型

Install packages

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U keras>=3

Import packages

import os
os.environ["KERAS_BACKEND"] = "jax" # you can also use tensorflow or torch
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00" # avoid memory fragmentation on JAX backend.
os.environ["JAX_PLATFORMS"] = ""
import keras
import keras_nlp

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
tqdm.pandas() # progress bar for pandas

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown

config

class Config:
    seed = 42
    dataset_path = "/kaggle/input/kaggle-docs/questions_answers"
    preset = "gemma_2b_en" # name of pretrained Gemma
    sequence_length = 512 # max size of input sequence for training
    batch_size = 1 # size of the input batch in training, x 2 as two GPUs
    epochs = 15 # number of epochs to train
keras.utils.set_random_seed(Config.seed)

Load data

df = pd.read_csv(f"{Config.dataset_path}/data.csv")
df.head()

为了方便起见,我们为QA创建以下模板

template = "\n\nCategory:\nkaggle-{Category}\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"
# 定义了一个文本模板,包含了一些占位符 `{Category}`, `{Question}`, `{Answer}`,分别代表类别、问题和答案。这些占位符将在后面的步骤中被实际的值替换。

df["prompt"] = df.apply(lambda row: template.format(Category=row.Category,
                                                             Question=row.Question,
                                                             Answer=row.Answer), axis=1)
# 将DataFrame中的每一行应用到一个lambda函数上。lambda函数接受行数据作为参数,并使用`template.format()`方法将该行数据填充到模板中,然后将结果存储在一个新列`prompt`中。

data = df.prompt.tolist()
def colorize_text(text):
    # 使用 zip 函数将要替换的单词和相应的颜色一一对应起来
    for word, color in zip(["Category""Question""Answer"], ["blue""red""green"]):
        # 遍历每个单词和对应的颜色,将文本中匹配到的部分替换为带颜色的 HTML 标记
        text = text.replace(f"\n\n{word}:", f"\n\n**{word}:**")
    return text

Specialized class to query Gemma

We define a specialized class to query Gemma.

Initialize the code for Gemma Causal LM¶

gemma_causal_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_causal_lm.summary()

Define the specialized class

class GemmaQA:
    def __init__(self, max_length=512):
        self.max_length = max_length
        self.prompt = template
        self.gemma_causal_lm = gemma_causal_lm
        
    def query(self, category, question):
        response = self.gemma_causal_lm.generate(
            self.prompt.format(
                Category=category,
                Question=question,
                Answer=""), 
            max_length=self.max_length)
        display(Markdown(colorize_text(response)))
        

Gemma preprocessor

这个Gemmma预处理层接受一批字符串作为输入,并以 (x, y, sample_weight) 的格式返回输出,其中 y 标签是 x 序列中的下一个标记的标识符。

通过下面的代码,我们可以看到,经过预处理之后,数据的形状为 (num_samples, sequence_length)

这段代码的功能是将输入的文本序列切分为固定长度的序列,并将每个序列中的每个标记作为 x,其下一个标记作为 y。最终返回的是一组 x 序列、对应的 y 序列以及样本权重。

x, y, sample_weight = gemma_causal_lm.preprocessor(data[0:2])
print(x, y)

Enable LoRA for the model and set the LoRA rank to 4.

gemma_causal_lm.backbone.enable_lora(rank=4)
gemma_causal_lm.summary()

Run the training sequence¶

gemma_causal_lm.preprocessor.sequence_length = Config.sequence_length 

# Compile the model with loss, optimizer, and metric
gemma_causal_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=8e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train model
gemma_causal_lm.fit(data, epochs=Config.epochs, batch_size=Config.batch_size)

Test the fine-tuned model

gemma_qa = GemmaQA()

sample1

row = df.iloc[0]
gemma_qa.query(row.Category,row.Question)

sample2

row = df.iloc[15]
gemma_qa.query(row.Category,row.Question)
category = "notebook"
question = "How to run a notebook?"
gemma_qa.query(category,question)
category = "competitions"
question = "What is a code competition?"
gemma_qa.query(category,question)

以上演示了如何使用LoRA对Gemma模型进行微调。我们还创建了一个类来运行对Gemma模型的查询,并使用来自现有训练数据的一些示例以及一些新的,未见过的问题对其进行测试。

关注“







请到「今天看啥」查看全文