专栏名称: 学姐带你玩AI
这里有人工智能前沿信息、算法技术交流、机器学习/深度学习经验分享、AI大赛解析、大厂大咖算法面试分享、人工智能论文技巧、AI环境工具库教程等……学姐带你玩转AI!
目录
51好读  ›  专栏  ›  学姐带你玩AI

Transformer从菜鸟到新手(四)

学姐带你玩AI  · 公众号  ·  · 2024-05-08 18:35

正文

来源:投稿  作者:175
编辑:学姐

引言

上篇文章 完成了Transformer剩下组件的编写,因此本文就可以开始训练。

本文主要介绍训练时要做的一些事情,包括定义损失函数、学习率调整、优化器等。

下篇文章会探讨如何在多GPU上进行并行训练,加速训练过程。

数据集简介

从网上找到一份中英翻译wmt数据集,数据格式如下:

[
    ["english sentence""中文语句"], 
    ["english sentence""中文语句"]
]

其中训练、验证、测试集的样本数分别为:176943、25278、50556。

下载地址:https://download.csdn.net/download/yjw123456/88694140 (固定只需要5积分)

def build_dataframe_from_json(
    json_path: str,
    source_tokenizer: spm.SentencePieceProcessor = None,
    target_tokenizer: spm.SentencePieceProcessor = None,
) -> pd.DataFrame:
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    df = pd.DataFrame(data, columns=["source""target"])

    def _source_vectorize(text: str) -> list[str]:
        return source_tokenizer.EncodeAsIds(text, add_bos=True, add_eos=True)

    def _target_vectorize(text: str) -> list[str]:
        return target_tokenizer.EncodeAsIds(text, add_bos=True, add_eos=True)

    tqdm.pandas()

    if source_tokenizer:
        df["source_indices"] = df.source.progress_apply(lambda x: _source_vectorize(x))
    if target_tokenizer:
        df["target_indices"] = df.target.progress_apply(lambda x: _target_vectorize(x))

    return df

传入原文的目的是计算BLEU分数时方便一点,当然也可以对编码后的索引反向解码成原文。

剩下的事情是通过数据加载器来加载数据集,相关代码如下:

assert os.path.exists(
    train_args.src_tokenizer_file
), "should first run train_tokenizer.py to train the tokenizer"
assert os.path.exists(
    train_args.tgt_tokenizer_path
), "should first run train_tokenizer.py to train the tokenizer"
source_tokenizer = spm.SentencePieceProcessor(
    model_file=train_args.src_tokenizer_file
)
target_tokenizer = spm.SentencePieceProcessor(
    model_file=train_args.tgt_tokenizer_path
)

if train_args.only_test:
    train_args.use_wandb = False

if train_args.cuda:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")

print(f"source tokenizer size: {source_tokenizer.vocab_size()}")
print(f"target tokenizer size: {target_tokenizer.vocab_size()}")

set_random_seed(12345)

train_dataframe_path = os.path.join(
    train_args.save_dir, train_args.dataframe_file.format("train")
)
test_dataframe_path = os.path.join(
    train_args.save_dir, train_args.dataframe_file.format("test")
)
valid_dataframe_path = os.path.join(
    train_args.save_dir, train_args.dataframe_file.format("dev")
)

if os.path.exists(train_dataframe_path) and train_args.use_dataframe_cache:
    train_df, test_df, valid_df = (
        pd.read_pickle(train_dataframe_path),
        pd.read_pickle(test_dataframe_path),
        pd.read_pickle(valid_dataframe_path),
    )
    print("Loads cached dataframes.")
else:
    print("Create new dataframes.")

    valid_df = build_dataframe_from_json(
        f"{train_args.dataset_path}/dev.json", source_tokenizer, target_tokenizer
    )
    print("Create valid dataframe")
    test_df = build_dataframe_from_json(
        f"{train_args.dataset_path}/test.json", source_tokenizer, target_tokenizer
    )
    print("Create train dataframe")
    train_df = build_dataframe_from_json(
        f"{train_args.dataset_path}/train.json", source_tokenizer, target_tokenizer
    )
    print("Create test dataframe")

    train_df.to_pickle(train_dataframe_path)
    test_df.to_pickle(test_dataframe_path)
    valid_df.to_pickle(valid_dataframe_path)

pad_idx = model_args.pad_idx

train_dataset = NMTDataset(train_df, pad_idx)
valid_dataset = NMTDataset(valid_df, pad_idx)
test_dataset = NMTDataset(test_df, pad_idx)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=train_args.batch_size,
    collate_fn=train_dataset.collate_fn,
)
valid_dataloader = DataLoader(
    valid_dataset,
    shuffle=False,
    batch_size=train_args.batch_size,
    collate_fn=valid_dataset.collate_fn,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=train_args.batch_size,
    collate_fn=test_dataset.collate_fn,
)

数据处理好之后我们就可以开始训练了。

模型训练

标签平滑

Transformer的训练过程中用到了标签平滑(label smoothing)技术,目的是防止模型训练时过于自信地预测标签,改善泛化能力不足的问题。

简单来说就是降低原来one-hot形式中目标类别(对应1,即100%)的概率,拿出来分给其他类别。

以下内容摘自参考8的论文,不感兴趣可以直接跳过。

因此需要一种机制让模型不那么自信,虽然与最大化训练标签的对数似然有点相违背,但这确实对模型进行正则化使其更具适应性。

这样,LSR可以看成是将单个交叉熵损失H ( q , p )替换为H ( q , p )和H ( u , p )的两个损失的加权和。

在训练时,如果模型非常确信的预测出真实标签分布,即H ( q , p )接近0,但H ( u , p )会急剧增大,因此基于标签平滑,我们可以防止模型预测地太过自信。

第二项损失惩罚了预测标签分布p 和先验分布u 之间的偏差,注意,这种偏差可以等价地通过KL散度来捕捉。为什么这么说?

而分布u 的熵H ( u ) 是固定的,所以H ( u , p ) 只有KL散度有关。

当u 是均匀分布时,H ( u , p ) 衡量了预测分布p 与均匀分布的不相似程度,这也可以通过负熵− H ( p ) 来衡量(但并非等价)。

PyTorch在1.10之后就支持标签平滑:

nn.CrossEntropyLoss(ignore_index = pad_idx, reduction="sum", label_smoothing = 0.1 )

通过传入 ignore_index pad index、reduction='sum' 和设置 label_smoothing 值来使用。

但是光这还不够,当我们使用 CrossEntropyLoss 时,我们需要拉平模型的输出和标签标记索引,所以我们定义如下损失类来包装 CrossEntropyLoss






    
class LabelSmoothingLoss(nn.Module):
    def __init__(self, laabel_smoothing: float =0.0, pad_idx: int = 0) -> None:
        super().__init__()
        self.loss_func = nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=label_smoothing)
    
    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        vocab_size = logits.shape[-1]
        logits = logits.reshape(-1, vocab_size)
        labels = labels.reshape(-1).long()
        return self.loss_func(logits, labels)

注意,实际上本文用到的数据集使用标签平滑效果反而不好。因此训练过程中并未使用。

学习率&优化器

使用Adam优化器, 。我们可以这样实现:

from torch.optim import Adam

optimizer = Adam(model.parameters(),
                 betas = (0.9, 0.98),
                 eps = 1e-9)

并使用warmup策略调整学习率:

使用固定步数warmup_steps \text{warmup_steps}warmup_steps 先使学习率线性增长(预热) ,而后随着step_num \text{step_num}step_num的增加以step_num \text{step_num}step_num的平方根成比例 逐渐减小学习率

我们可以封装Adam优化器,并支持预热和学习率衰减。

class WarmupScheduler(_LRScheduler):
    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int,
        d_model: int,
        factor: float = 1.0,
        last_epoch: int = -1,
        verbose: bool = False,
    ) -> None:
        """

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_steps (int): warmup steps
            d_model (int): dimension of embeddings.
            last_epoch (int, optional): the index of last epoch. Defaults to -1.
            verbose (bool, optional): if True, prints a message to stdout for each update. Defaults to False.

        "
""
        self.warmup_steps = warmup_steps
        self.d_model = d_model
        self.num_parm_groups = len(optimizer.param_groups)
        self.factor = factor
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self) -> list[float]:
        lr = (
            self.factor
            * self.d_model**-0.5
            * min(
                self._step_count**-0.5, self._step_count * self.warmup_steps**-1.5
            )
        )
        return [lr] * self.num_parm_groups

这里通过继承 LRScheduler 来实现,并且通过 factor 参数控制学习率的大小,小数据集可以尝试设置成0.5。

我们可以画出学习率变化的趋势图:

关注上图的橙线,可以看到,学习率确实是从0开始逐渐增加,直到4000步后,开始逐渐下降。

为什么这个公式可以达到这个效果?好像其中包含了一个IF-ELSE似的。为了直观的理解,我们把这个公式重写成:

这样是不是就大概能看出来了:

warmup_step=4000 时, warmup_steps ** 1.5=252982.2128 。当训练步数 step_num 小于热身步数时,函数内右项一直小于左项,但随着训练步数的增加而线性增加;

当训练步数到达热身步数 warmup_steps 时, min 函数内的两项相等;

当训练步数大于热身步数,函数内左项小于右项,并且随着训练步数的增加而(非线性)减少;

这样就实现了我们上图看到的效果。从公式还以看到一点,就是 模型的嵌入大小 d_model 越大,或者 warmup_steps 越大,学习率的峰值就越小 ,而且warmup_steps越大,学习率开始增加的越缓慢。

训练分词器

正如上文所述,我们使用sentencepiece工具包进行分词,首先将中英文语句分别读入内存。

def get_mt_pairs(data_dir: str, splits=["train""dev""test"]):
    english_sentences = []
    chinese_sentences = []

    """
    json content:
    [["
english sentence", "中文语句"], ["english sentence", "中文语句"]]
    "
""
    for split in splits:
        with open(f"{data_dir}/{split}.json""r", encoding="utf-8") as f:
            data = json.load(f)
            for pair in data:
                english_sentences.append(pair[0] + "\n")
                chinese_sentences.append(pair[1] + "\n")

    assert len(chinese_sentences) == len(english_sentences)

    print(f"the total number of sentences: {len(chinese_sentences)}")

    return chinese_sentences, english_sentences

接着定义一个训练函数,这里用多进程同时训练:

def train_tokenizer(
    source_corpus_path: str,
    target_corpus_path: str,
    source_vocab_size: int,
    target_vocab_size: int,
    source_character_coverage: float = 1.0,
    target_character_coverage: float = 0.9995,
) -> None:
    with ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(
                train_sentencepice_bpe,
                source_corpus_path,
                "model_storage/source",
                source_vocab_size,
                source_character_coverage,
            ),
            executor.submit(
                train_sentencepice_bpe,
                target_corpus_path,
                "model_storage/target",
                target_vocab_size,
                target_character_coverage,
            ),
        ]

        for future in futures:
            future.result()

    sp = spm.SentencePieceProcessor()

    source_text = """
        Tesla is recalling nearly all 2 million of its cars on US roads to limit the use of its 
        Autopilot feature following a two-year probe by US safety regulators of roughly 1,000 crashes 
        in which the feature was engaged. The limitations on Autopilot serve as a blow to Tesla’s efforts 
        to market its vehicles to buyers willing to pay extra to have their cars do the driving for them.
        "
""

    sp.load("model_storage/source.model")
    print(sp.encode_as_pieces(source_text))
    ids = sp.encode_as_ids(source_text)
    print(ids)
    print(sp.decode_ids(ids))

    target_text = """
        新华社北京1月2日电(记者丁雅雯、李唐宁)2024年元旦假期,旅游消费十分火爆。旅游平台数据显示,旅游相关产品订单量大幅增长,“异地跨年”“南北互跨”成关键词。
        业内人士认为,元旦假期旅游“开门红”彰显消费潜力,预计2024年旅游消费有望保持上升势头。
    "
""

    sp.load("model_storage/target.model")
    print(sp.encode_as_pieces(target_text))
    ids = sp.encode_as_ids(target_text)
    print(ids)
    print(sp.decode_ids(ids))

最后执行训练代码:

if __name__ == "__main__":
    make_dirs(train_args.save_dir)

    chinese_sentences, english_sentences = get_mt_pairs(
        data_dir=train_args.dataset_path, splits=["train""dev""test"]
    )

    with open(f"{train_args.dataset_path}/corpus.ch""w", encoding="utf-8") as f:
        f.writelines(chinese_sentences)

    with open(f"{train_args.dataset_path}/corpus.en""w", encoding="utf-8") as f:
        f.writelines(english_sentences)

    train_tokenizer(
        f"{train_args.dataset_path}/corpus.en",
        f"{train_args.dataset_path}/corpus.ch",
        source_vocab_size=model_args.source_vocab_size,
        target_vocab_size=model_args.target_vocab_size,
    )

['▁Tesla''▁is''▁recalling''▁nearly''▁all''▁2''▁million''▁of''▁its''▁cars''▁on''▁US''▁roads''▁to''▁limit''▁the''▁use''▁of''▁its''▁Aut''op''ilot''▁feature''▁following''▁a'
'▁two''-''year''▁probe''▁by''▁US''▁safety''▁regulators''▁of''▁roughly''▁1,000''▁crashes''▁in''▁which''▁the''▁feature''▁was''▁engaged''.''▁The''▁limitations''▁on''▁Aut''op'
'ilot''▁serve''▁as''▁a''▁blow''▁to''▁Tesla''’''s''▁efforts''▁to''▁market''▁its''▁vehicles''▁to''▁buyers''▁willing''▁to''▁pay''▁extra''▁to''▁have''▁their''▁cars''▁do''▁the''▁driving''▁for''▁them''.']
[22941, 59, 20252, 2225, 255, 216, 1132, 34, 192, 5944, 81, 247, 6980, 31, 3086, 10, 894, 34, 192, 5296, 177, 31299, 6959, 2425, 6, 600, 31847, 2541, 22423, 144, 247, 3474, 4270, 34, 2665, 8980, 23659, 26, 257, 10, 6959, 219, 5037, 31843, 99, 10725, 81, 5296, 177, 31299, 3343, 98, 6, 6296, 31, 22941, 31849, 31827, 1369, 31, 404, 192, 6287, 31, 10106, 2207, 31, 1129, 2904, 31, 147, 193, 5944, 295, 10, 4253, 75, 437, 31843]
Tesla is recalling nearly all 2 million of its cars on US roads to limit the use of its Autopilot feature following a two-year probe by US safety regulators of roughly 1,000 crashes in which the feature was engaged. The limitations on Autopilot serve as a blow to Tesla’s efforts to market its vehicles to buyers willing to pay extra to have their cars do the driving for them.
['▁新''华''社''北京''1''月''2''日''电''(''记者''丁''雅''雯''、''李''唐''宁' ')''20''24''年''元''旦''假期'',''旅游''消费''十分''火''爆''。''旅游''平台''
数据显示'
',''旅游''相关''产品''订单''量''大幅增长'',“''异''地''跨''年''”''“''南北''互''跨''”''成''关键''词''。''▁''业''内''人士''认为'',''元''旦''假期''旅 
游'
'“''开''门''红''”''彰显''消费''潜力'',''预计''20''24''年''旅游''消费''有望''保持''上升''势头''。']
[1460, 29568, 28980, 2200, 28770, 29048, 28779, 28930, 29275, 28786, 2539, 29953, 30003, 1, 28758, 30345, 30229, 30365, 28787, 10, 3137, 28747, 28934, 29697, 18645, 28723, 4054, 266, 651, 29672, 29541, 28724, 4054, 2269, 12883, 28723, 4054, 521, 640, 25619, 28937, 22184, 710, 29596, 28765, 29649, 28747, 28811, 28809, 9356, 29410, 29649, 28811, 28762, 318, 29859, 28724, 28722, 28825, 28922, 1196, 64, 28723, 28934, 29697, 18645, 4054, 28809, 28889, 29208, 30060, 28811, 9466, 266, 1899, 28723, 1321, 10, 3137, 28747, 4054, 266, 4485, 398, 543, 4315, 28724]
新华社北京1月2日电(记者丁雅 ⁇ 、李唐宁)2024年元旦假期,旅游消费十分火爆。旅游平台数据显示,旅游相关产品订单量大幅增长,“异地跨年”“南北互跨”成关键词。 业内人士认为,元旦假期旅游“开门红”彰显消费潜力,预计2024年旅游消费有望保持上升势头。

这里可以看到,它无法正确识别 字,因为我们的语料库中没有,所以在一个充分大的语料上训练分词器是非常有必要的。但我们可以先忽略这个问题。

整个训练过程只需要几分钟。每个分词器会生成两个文件,一个模型文件和一个词表文件。比如中文的词表.vocab文件内容如下:

 0
 0
 0
 0
—— -0
经济 -1
国家 -2
美国 -3
▁但 -4
一个 -5
20 -6
我们 -7
政府 -8
中国 -9
可能 -10
他们 -11
欧洲 -12
问题 -13
...

这样我们有了训练好的BPE分词器,常用的操作如下:

sp.load("model_storage/source.model"# 加载分词器
print(sp.encode_as_pieces(source_text)) # 对文本分词
ids = sp.encode_as_ids(source_text) # 分词并编码成ID序列
print(sp.decode_ids(ids)) # ID序列还原成文本

定义数据加载器

@dataclass
class Batch:
    source: Tensor
    target: Tensor
    labels: Tensor
    num_tokens: int
    src_text: str = None
    tgt_text: str = None


class NMTDataset(Dataset):
    """Dataset for translation"""

    def __init__(self, text_df: pd.DataFrame, pad_idx: int = 0) -> None:
        """

        Args:
            text_df (pd.DataFrame): a DataFrame which contains the processed source and target sentences
        "
""
        # sorted by target length
        # text_df = text_df.iloc[text_df["target"].apply(len).sort_values().index]
        self.text_df = text_df

        self.padding_index = pad_idx

    def __getitem__(
        self, index: int
    ) -> Tuple[list[int], list[int], list[str], list[str]]:
        row = self.text_df.iloc[index]

        return (row.source_indices, row.target_indices, row.source, row.target)

    def collate_fn(
        self, batch: list[Tuple[list[int], list[int], list[str]]]
    ) -> Tuple[LongTensor, LongTensor, LongTensor]:
        source_indices = [x[0] for x in batch]
        target_indices = [x[1] for x in batch]
        source_text = [x[2] for x in batch]
        target_text = [x[3] for x in batch]

        source_indices = [torch.LongTensor(indices) for indices in source_indices]
        target_indices = [torch.LongTensor(indices) for indices in target_indices]

        # The  was added before the  token to ensure that the model can correctly identify the end of a sentence.
        source = pad_sequence(
            source_indices, padding_value=self.padding_index, batch_first=True
        )

        target = pad_sequence(
            target_indices, padding_value=self.padding_index, batch_first=True
        )

        labels = target[:, 1:]
        target = target[:, :-1]

        num_tokens = (labels != self.padding_index).data.sum()

        return Batch(source, target, labels, num_tokens, source_text, target_text)

    def __len__(self) -> int:
        return len(self.text_df)

首先定义数据集类,将数据转换成DataFrame操作比较方便,这里假设传入的内容已经经过分词器的向量化。

我们还需要自己实现 collate_fn ,把数据转换成我们需要的格式。

具体地,先将源和目标索引序列转换Tensor;然后按批次内最大长度进行填充,即每个批次最大长度是不同的。假设一个批大小为2的批次内数据为:

[[2, 12342, 123, 323, 3, 0, 0, 0],
 [2, 222, 23, 12, 123, 22, 22, 3]]

这里的2和3分别对应bos和eos的ID,0对应填充ID。可以看到eos id(3)是在pad id(0)之前,这样模型能正确区分句子的结束位置。

填充完之后就得到 (batch_size, seq_len) 形状的数据,这里 seq_len 是批次内最大长度。

其中 source 可以直接输入给编码器,但是解码器的输入以及预测的目标要注意。

举个例子,假设要翻译的一句话为:

['''我''喜''欢''打''篮''球''。''''']

注意后面有一个 填充标记,解码器的输入target会移除这句话的最后一个标记,这里是 ,得到:

target = ['''我''喜''欢''打''篮''球''。''']

我们要预测的标签 labels 会移除这句话的第一个标记,都是

labels = ['我''喜''欢''打''篮''球''。''''']

即解码器在输入 和编码器的编码后,要预测出' ';(结合mask)在输入 [ ,'我'] 之后要预测出' ';…;在输入 [' ', '我', '喜', '欢', '打', '篮', '球', '。'] 之后要预测出句子结束标记

有了这个类定义数据加载器就简单了:

DataLoader(
    dataset, # 数据集类的实例
    shuffle=True,
    batch_size=32,
    collate_fn=dataset.collate_fn,
)

定义训练函数

定义训练和评估函数:

def train(
    model: nn.Module,
    data_loader: DataLoader,
    criterion: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    clip: float,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
) -> float:
    model.train()  # train mode

    total_loss = 0.0

    tqdm_iter = tqdm(data_loader)

    for source, target, labels, _ in tqdm_iter:
        source = source.to(device)
        target = target.to(device)
        labels = labels.to(device)

        logits = model(source, target)

        # loss calculation
        loss = criterion(logits, labels)

        loss.backward()

        if clip:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()
        scheduler.step()

        optimizer.zero_grad()

        total_loss += loss.item()

        description = f" TRAIN  loss={loss.item():.6f}, learning rate={scheduler.get_last_lr()[0]:.7f}"

        del loss

        tqdm_iter.set_description(description)

    # average training loss
    avg_loss = total_loss / len(data_loader)

    return avg_loss


@torch.no_grad()
def evaluate(
    model: nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    criterion: torch.nn.Module,
) -> float:
    model.eval()  # eval mode

    total_loss = 0

    for source, target, labels, _ in tqdm(data_loader):
        source = source.to(device)
        target = target.to(device)
        labels = labels.to(device)

        # feed forward
        logits = model(source, target)
        # loss calculation
        loss = criterion(logits, labels)

        total_loss += loss.item()

        del loss

    # average validation loss
    avg_loss = total_loss / len(data_loader)
    return avg_loss


贪心搜索

贪心搜索或者说贪心解码,就是每次在预测下一个标记时都选取概率最大的那个。贪心搜索比较好实现,但是我们需要支持批操作,因为我们想在每个训练epoch结束后在验证集上计算一次BLEU分数。

 def _greedy_search(
        self, src: Tensor, src_mask: Tensor, max_gen_len: int, keep_attentions: bool
    ):
        memory = self.transformer.encode(src, src_mask)

        batch_size = src.shape[0]

        device = src.device

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)

        decoder_inputs = torch.LongTensor(batch_size, 1).fill_(self.bos_idx).to(device)

        eos_idx_tensor = torch.tensor([self.eos_idx]).to(device)

        finished = False

        while True:
            tgt_mask = self.generate_subsequent_mask(decoder_inputs.size(1), device)

            logits = self.lm_head(
                self.transformer.decode(
                    decoder_inputs,
                    memory,
                    tgt_mask=tgt_mask,
                    memory_mask=src_mask,
                    keep_attentions=keep_attentions,
                )
            )

            next_tokens = torch.argmax(logits[:, -1, :], dim=-1)

            # finished sentences should have their next token be a pad token
            next_tokens = next_tokens * unfinished_sequences + self.pad_idx * (
                1 - unfinished_sequences
            )

            decoder_inputs = torch.cat([decoder_inputs, next_tokens[:, None]], dim=-1)

            # set sentence to finished if eos_idx was found
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_idx_tensor.shape[0], 1)
                .ne(eos_idx_tensor.unsqueeze(1))
                .prod(dim=0)
            )

            # all sentences have eos_idx
            if unfinished_sequences.max() == 0:
                finished = True

            if decoder_inputs.shape[-1] >= max_gen_len:
                finished = True

            if finished:
                break

        return decoder_inputs

开始训练

定义训练参数:

import os
from dataclasses import dataclass
from typing import Tuple


@dataclass
class TrainArugment:
    """
    Create a 'data' directory and store the dataset under it
    "
""

    dataset_path: str = f"{os.path.dirname(__file__)}/data/wmt"
    save_dir = f"{os.path.dirname(__file__)}/model_storage"

    src_tokenizer_file: str = f"{save_dir}/source.model"
    tgt_tokenizer_path: str = f"{save_dir}/target.model"
    model_save_path: str = f"{save_dir}/best_transformer.pt"

    dataframe_file: str = "dataframe.{}.pkl"
    use_dataframe_cache: bool = True
    cuda: bool = True
    num_epochs: int = 40
    batch_size: int = 32
    gradient_accumulation_steps: int = 1
    grad_clipping: int = 0  # 0 dont use grad clip
    betas: Tuple[floatfloat] = (0.9, 0.997)
    eps: float = 1e-6
    label_smoothing: float = 0
    warmup_steps: int = 6000
    warmup_factor: float = 0.5
    only_test: bool = False
    max_gen_len: int = 60
    use_wandb: bool = True
    patient: int = 5
    gpus = [1, 2, 3]
    seed = 12345
    calc_bleu_during_train: bool = True


@dataclass
class ModelArugment:
    d_model: int = 512  # dimension of embeddings
    n_heads: int = 8  # numer of self attention heads
    num_encoder_layers: int = 6  # number of encoder layers
    num_decoder_layers: int = 6  # number of decoder layers
    d_ff: int = d_model * 4  # dimension of feed-forward network
    dropout: float = 0.1  # dropout ratio in the whole network






请到「今天看啥」查看全文