if labels is not None:
    # move labels to correct device to enable model parallelism
    labels = labels.to(lm_logits.device)
    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
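In other words, the logits at position i are scored against the label at position i + 1, so the labels can simply be a copy of the input IDs. A toy illustration of the shapes involved (made-up tensors, not real model output):

import torch
from torch.nn import CrossEntropyLoss

# Toy batch: batch_size=1, seq_len=4, vocab_size=6
lm_logits = torch.randn(1, 4, 6)
labels = torch.tensor([[2, 5, 1, 3]])  # for causal LM, labels are usually the input_ids themselves

# Positions 0..2 of the logits are matched against positions 1..3 of the labels
shift_logits = lm_logits[..., :-1, :].contiguous()  # shape (1, 3, 6)
shift_labels = labels[..., 1:].contiguous()         # shape (1, 3)

loss = CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(shift_logits.shape, shift_labels.shape, loss.item())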
if __name__ == "__main__": # The vocab.txt was downlowned from https://huggingface.co/google-bert/bert-base-chinese/blob/main/vocab.txt . save_tokenizer( "./vocab.txt", model_name=train_args.tokenizer_name, bos_token=train_args.bos_token, eos_token=train_args.eos_token, bot_token=train_args.bot_token, user_token=train_args.user_token, )
from dataclasses import dataclass


@dataclass
class TrainArguments:
    dataset_name: str = "chichat_dataset"
    bos_token: str = ""
    eos_token: str = ""
    bot_token: str = ""
    user_token: str = ""
    bos_token_id: int = 21128
    eos_token_id: int = 21129
    bot_token_id: int = 21130
    user_token_id: int = 21131
    ignore_index: int = -100
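The hard-coded IDs follow from the fact that bert-base-chinese ships with 21128 vocabulary entries (IDs 0–21127), so the four added tokens land at 21128–21131. A quick sanity-check sketch against the tokenizer saved above, assuming tokenizer_name points at the directory written by save_tokenizer:

from transformers import AutoTokenizer
from config import train_args

tokenizer = AutoTokenizer.from_pretrained(train_args.tokenizer_name)
assert len(tokenizer) == 21132  # 21128 original entries + 4 added special tokens
assert tokenizer.convert_tokens_to_ids(train_args.user_token) == train_args.user_token_id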
Building the dataset
from transformers import BertTokenizerFast, AutoTokenizer
from datasets import Dataset, DatasetDict, load_dataset
import pickle
import os
import re

from tqdm import tqdm

# The configuration class shown above
from config import train_args
from log import logger
def get_dataset(source_dataset, tokenizer, args):
    """
    The format we need is:
    `bos_token user_token utterance1 eos_token bot_token utterance2 eos_token user_token utterance3 eos_token bot_token utterance4 eos_token`
    """
    dialogues = []
for example in tqdm(source_dataset["train"]):
    record = example["instruction"] + example["output"]
    utterances = re.split(r"(Human:|Assistant:)", record)
    utterances = [
        x.strip() for x in utterances if x.strip() not in ["Human:", "Assistant:", ""]
    ]
    dialogues.append(utterances)
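Because the split pattern is wrapped in a capturing group, re.split keeps the "Human:"/"Assistant:" markers in the result; the list comprehension then drops them along with any empty strings. For example, on a made-up record:

import re

record = "Human: 你好 Assistant: 你好,请问有什么可以帮你?"
parts = re.split(r"(Human:|Assistant:)", record)
# ['', 'Human:', ' 你好 ', 'Assistant:', ' 你好,请问有什么可以帮你?']
utterances = [x.strip() for x in parts if x.strip() not in ["Human:", "Assistant:", ""]]
# ['你好', '你好,请问有什么可以帮你?']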
logger.info(f"There are {len(dialogues)} dialogues.")
print(dialogues[0])
conversation_list = []
for utterances in tqdm(dialogues):
    # Each dialogue starts with BOS
    input_ids = [args.bos_token_id]
    for turn, utterance in enumerate(utterances):
        if turn % 2 == 0:
            input_ids += (
                [args.user_token_id]
                + tokenizer.encode(utterance, add_special_tokens=False)
                + [args.eos_token_id]
            )
        else:
            input_ids += (
                [args.bot_token_id]
                + tokenizer.encode(utterance, add_special_tokens=False)
                + [args.eos_token_id]
            )
    # Must not exceed model_max_length
    if len(input_ids) <= tokenizer.model_max_length:
        conversation_list.append(input_ids)
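The remaining steps of get_dataset are not shown above, but judging from the log below, the token ID lists are wrapped in a Dataset, split roughly 95/5 into train/valid, and saved to disk. A sketch of that tail end; the split ratio and seed are assumptions:

# Wrap the encoded conversations in a Dataset, carve out ~5% for validation,
# and persist the result (this matches the DatasetDict printed in the log below)
dataset = Dataset.from_dict({"input_ids": conversation_list})
splits = dataset.train_test_split(test_size=0.05, seed=42)
dataset_dict = DatasetDict({"train": splits["train"], "valid": splits["test"]})
dataset_dict.save_to_disk(train_args.dataset_name)
print(dataset_dict)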
$ python .\data_process.py
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100%|██████████| 831036/831036 [00:40<00:00, 20539.42it/s]
2024-03-14 19:41:39 - INFO - root - There are 831036 dialogues.
['你好,你能帮我解答一个问题吗?', '当然,请问有什么问题?', '我想了解人工智能的未来发展方向,你有什么想法吗?', '人工智能在未来的发展方向可能包括更强大的机器学习算法,更先进的自然语言处理技术,以及更加智能的机器人。此外,人工智能还可以帮助解决许多现实世界的问题,例如自动化和改善医疗保健等领域。', '听起来很不错。人工智能可能在哪些方面面临挑战呢?', '人工智能面临的挑战包括数据隐私、安全和道德方面的问题,以及影响就业机会的自动化等问题。此外,人工智能可能会带来不平等和歧视风险,这也是需要关注的问题。']
  1%|▏         | 8484/831036 [00:07<12:14, 1119.32it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1042 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 831036/831036 [13:14<00:00, 1045.42it/s]
Saving the dataset (3/3 shards): 100%|██████████| 787723/787723 [00:05<00:00, 148342.40 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 41460/41460 [00:00<00:00, 296245.05 examples/s]
DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 787723
    })
    valid: Dataset({
        features: ['input_ids'],
        num_rows: 41460
    })
})
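Later steps (training, not covered in this section) can reload the processed data with datasets.load_from_disk, for example:

from datasets import load_from_disk
from config import train_args

dataset_dict = load_from_disk(train_args.dataset_name)
print(dataset_dict["train"][0]["input_ids"][:10])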