专栏名称: Coggle数据科学
Coggle全称Communication For Kaggle,专注数据科学领域竞赛相关资讯分享。
目录
相关文章推荐
51好读  ›  专栏  ›  Coggle数据科学

小白学NLP:T5模型加载与微调

Coggle数据科学  · 公众号  ·  · 2024-05-13 18:28

正文

请到「今天看啥」查看全文


unset unset T5模型介绍 unset unset

T5(Text-to-Text Transfer Transformer)是谷歌提出的一种通用的预训练语言模型,旨在统一自然语言处理任务的输入和输出。

T5模型特点

相比于以往的预训练语言模型,T5的一个显著特点是不需要添加非线性层,也不需要对模型进行额外的改动,只需在输入数据前加上任务声明前缀即可。这意味着在处理各种自然语言处理任务时,可以大大简化模型的使用和微调过程。

T5将所有任务都转化为文本到文本的形式,并使用一个统一的模型来解决。其核心理念是使用前缀任务声明及文本答案生成,这样在微调过程中就无需对模型进行改动,只需要提供相应任务的微调数据。

unset unset T5使用方法 unset unset

在T5中,输入是一个带有任务前缀声明的文本序列,这个前缀声明指定了模型应该执行的任务。输出则是相应任务的结果,以文本序列的形式呈现。这种一致的输入输出格式使得T5在处理各种自然语言处理任务时更加方便和统一。

T5模型加载与预测

  • 模型加载
import torch
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration

# load tokenizer and model 
pretrained_model = "IDEA-CCNL/Randeng-T5-784M-MultiTask-Chinese"

special_tokens = ["".format(i) for i in range(100)]
tokenizer = T5Tokenizer.from_pretrained(
    pretrained_model,
    do_lower_case=True,
    max_length=512,
    truncation=True,
    additional_special_tokens=special_tokens,
)
config = T5Config.from_pretrained(pretrained_model)
model = T5ForConditionalGeneration.from_pretrained(pretrained_model, config=config)
model.resize_token_embeddings(len(tokenizer))
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# device = 'cpu'
model.to(device)
  • 模型预测(意图识别任务为例)
text = "意图识别任务:还有双鸭山到淮阴的汽车票吗13号的 这篇文章的类别是什么?Travel-Query/Music-Play/FilmTele-Play/Video-Play/Radio-Listen/HomeAppliance-Control/Weather-Query/Alarm-Update/Calendar-Query/TVProgram-Play/Audio-Play/Other"
encode_dict = tokenizer(text, max_length=512, padding='max_length',truncation=True)

inputs = {
  "input_ids": torch.tensor([encode_dict['input_ids']]).long().to(device),
  "attention_mask": torch.tensor([encode_dict['attention_mask']]).long().to(device),
}

# generate answer
logits = model.generate(
  input_ids = inputs['input_ids'],
  max_length=100, 
  do_sample= True
)

logits=logits[:,1:]
predict_label = [tokenizer.decode(i,skip_special_tokens=True) for i in logits]

T5模型微调

  • 处理训练集
max_input_length = 60
max_target_length = 20

def preprocess_function(examples):
    model_inputs = tokenizer(head_prefix + examples["document"], max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    labels = tokenizer(text_target=examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
    
train_tokenized_id = train_ds.map(preprocess_function, remove_columns=train_ds.column_names)
eval_tokenized_id = eval_ds.map(preprocess_function, remove_columns=train_ds.column_names)
  • 定义模型训练参数
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

batch_size = 4
args = Seq2SeqTrainingArguments(
    "t5-finetuned",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    gradient_accumulation_steps=10,
    do_eval=True,
    evaluation_strategy="steps",
    eval_steps=50,
    num_train_epochs=5,
    save_steps=50,
    save_on_each_node=True,
    gradient_checkpointing=True,
    load_best_model_at_end=True
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tokenized_id,
    eval_dataset=eval_tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
)
trainer.train()

完整代码链接:https://github.com/coggle-club/notebooks

# 竞赛交流群 邀请函 #

△长按添加竞赛小助手

每天大模型、算法竞赛、干货资讯

40000+ 来自竞赛爱好者一起交流~








请到「今天看啥」查看全文