专栏名称: 学姐带你玩AI
这里有人工智能前沿信息、算法技术交流、机器学习/深度学习经验分享、AI大赛解析、大厂大咖算法面试分享、人工智能论文技巧、AI环境工具库教程等……学姐带你玩转AI!
目录
相关文章推荐
新浪科技  ·  【雷军:#小米SU7Ultra可提前小订# ... ·  22 小时前  
营销之美  ·  DeepSeek使用图鉴:人类和AI谁在玩弄谁? ·  2 天前  
营销之美  ·  DeepSeek使用图鉴:人类和AI谁在玩弄谁? ·  2 天前  
黄建同学  ·  Jim ... ·  2 天前  
AI前线  ·  微软力推新视频游戏 AI 模型,超 10 ... ·  2 天前  
51好读  ›  专栏  ›  学姐带你玩AI

阿里RE2文本匹配实战(附代码)

学姐带你玩AI  · 公众号  · AI 科技媒体  · 2024-07-31 18:21

正文

来源:投稿  作者:175
编辑:学姐

unset unset 引言 unset unset

今天我们来实现RE2进行文本匹配,模型实现参考了官方代码https://github.com/alibaba-edu/simple-effective-text-matching-pytorch。

本文的核心训练代码以及完整代码 文末领取

unset unset 模型实现 unset unset

RE2模型架构如上图所示。它的输入是两个文本片段,所有组件参数除了预测层和对齐层外都是共享的。上图虚线框出来的为一个Block,堆叠了N个block,文本片段之间的block内部通过对齐层进行交互。block之间通过增加的残差层进行连接。

下面我们从底向上依次实现,实现过程中参考了官方实现。

Embedding

嵌入层很简单没有使用字符嵌入,就是简单的单词嵌入。

class Embedding(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, dropout: float) -> None:
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): (batch_size, seq_len)

        Returns:
            Tensor: (batch_size, seq_len, embedding_dim)
        "
""
        return self.dropout(self.embedding(x))

Encoder

GeLU

首先实现GeLU,它是RELU的变种,后来被用到BERT中。其函数图像如下所示:

class GeLU(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return  0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))

Linear

重写了线性层,activations开启GeLU激活。

class Linear(nn.Module):
    def __init__(
        self, in_features: int, out_features: int, activations: bool = True
    ) -> None:
        super().__init__()

        linear = nn.Linear(in_features, out_features)
        modules = [weight_norm(linear)]
        if activations:
            modules.append(GeLU())

        self.model = nn.Sequential(*modules)
        self.reset_parameters(activations)

    def reset_parameters(self, activations: bool) -> None:
        linear = self.model[0]
        nn.init.normal_(
            linear.weight,
            std=math.sqrt((2.0 if activations else 1.0) / linear.in_features),
        )
        nn.init.zeros_(linear.bias)

    def forward(self, x):
        return self.model(x)

nn.Conv1d

我们在比较聚合模型的实现中详细了解了torch.nn.Conv2d的实现以及CNN的一些基础概念。

这里我们通过torch.nn.Conv1d来实现论文中的多层卷积网络,本小结来详细了解Conv1d实现。

torch.nn.Conv1d
    in_channels: 输入的通道数,文本中为嵌入维度
    out_channels: 一个卷积核产生一个输出通道
    kernel_size: 卷积核的大小
    stride: 卷积步长,默认为1
    padding: 填充,默认为0
    bias(bool): 是否添加偏置,默认为True

我们以一个例子来说明它的计算过程,假设对于输入"W B G 是 冠 军",随机得到的嵌入为:

希望今天下午S13 WBG可以战胜T1。

import numpy as np
import torch.nn as nn
import torch

batch_size = 1
seq_len = 6
embed_size = 3


input_tensor = torch.rand(batch_size, seq_len, embed_size)
print(input_tensor)
print(input_tensor.shape)
tensor([[[0.9291, 0.8333, 0.5160],
         [0.0543, 0.8149, 0.5704],
         [0.7831, 0.2263, 0.9279],
         [0.0898, 0.0758, 0.4401],
         [0.4321, 0.2098, 0.6666],
         [0.6183, 0.0609, 0.2330]]])
torch.Size([1, 6, 3])

此时每个字符对应一个3维的嵌入向量,分别为:

W — [0.9291, 0.8333, 0.5160]
B — [0.0543, 0.8149, 0.5704]
G — [0.7831, 0.2263, 0.9279]
是 — [0.0898, 0.0758, 0.4401]
冠 — [0.4321, 0.2098, 0.6666]
军 — [0.6183, 0.0609, 0.2330]

但是Conv1d需要in_channels即嵌入维度为仅在batch_size后第一个位置,由[1, 6, 3]变成[1, 3, 6]。

input_tensor = input_tensor.permute(0, 2, 1)
# (batch_size, embed_size, seq_len)

图示如下:

文章还没发,结果被3:0了。

然后我们定义一个一维卷积:

input_channels = embed_size # 等于embed_size
output_channels = 2
kernel_size = 2 # kernel_size

conv1d = nn.Conv1d(in_channels=input_channels, out_channels=output_channels, kernel_size=kernel_size)

我们可以打印出来filter权重矩阵:

print(conv1d.weight)
print(conv1d.weight.shape)
Parameter containing:
tensor([[[ 0.0025,  0.3353],
         [ 0.0620, -0.3916],
         [-0.3458, -0.0610]],

        [[-0.1731, -0.0787],
         [-0.0419, -0.2555],
         [-0.1429,  0.1656]]], requires_grad=True)
torch.Size([2, 3, 2])

filter 权重的大小为 (2,3,2) shape[0]=2是filter个数;shape[1]=3是输入嵌入大小;shape[2]=2是filter大小。

默认是添加了偏置,一个filter一个偏置:

Parameter containing:
tensor([ 0.3760, -0.2881], requires_grad=True)
torch.Size([2])

我们这里有两个filter,所以有两个偏置。因为这里kernel_size=2,且步长stride=1,所以一个filter是如下的方式框住了两个字符嵌入,并且每次向右移动一格:

此时第一个filter的卷积操作计算为:

sum([[0.9291, 0.0543],           [[0.0025,  0.3353],
  [0.8333, 0.8149],     *       [0.0620, -0.3916],      +    0.3760(bias)
  [0.5160, 0.5704]]             [-0.3458, -0.0610])

第一个filter权重和这两个嵌入进行逐位置乘法产生一个标量(sum),最后加上第一个filter的偏置。

通过代码实现为:

# 开始计算卷积
# 前两个嵌入与卷积核权重逐元素乘法
result = input_tensor[:,:,:2]*conv1d.weight 
print(result)
# 结果求和再加上偏置
print(torch.sum(result[0]) + conv1d.bias[0])
print(torch.sum(result[1]) + conv1d.bias[1])
tensor([[[ 0.0024,  0.0182],
         [ 0.0517, -0.3191],
         [-0.1784, -0.0348]],

        [[-0.1608, -0.0043],
         [-0.0349, -0.2082],
         [-0.0737,  0.0944]]], grad_fn=)
         
tensor(-0.0841, grad_fn=# 第一个filter的结果
tensor(-0.6756, grad_fn=# 第二个filter的结果

这是第一次卷积的结果,第二次卷积把红框向右移动一格,又会有一个结果。

最终移动到输入的最后一个位置计算完毕:

共需要计算5次,因此最终一个filter会输出5个标量,共有2个filter,批大小为1。

如果用代码实现的话:

output = conv1d(input_tensor)
print(output)
print(output.shape)
tensor([[[-0.0841,  0.3468,  0.0447,  0.2508,  0.3288],
         [-0.6756, -0.3790, -0.5193, -0.3470, -0.4926]]],
       grad_fn=)
torch.Size([1, 2, 5])

可以看到output的形状为[1, 2, 5],第一列的计算结果和我们上面的一致。

shape[0]=1是批次内样本个数;``shape[1]=2是filter个数,也是想要输出的channel数;shape[2]=5`是卷积后的维度。

这里(忽略dilation)卷积后的维度大小由卷积核大小kernel_size、步长stride、填充padding以及输入序列长度seq_len决定:

Conv1d

下面实现RE2的多层卷积网络,首先是一个改写的Conv1d,用weight_norm进行权重归一化,采用GeLU激活函数。

class Conv1d(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, kernel_sizes: list[int]
    ) -> None:
        """

        Args:
            in_channels (int): the embedding_dim
            out_channels (int): number of filters
            kernel_sizes (list[int]): the size of kernel
        "
""
        super().__init__()

        out_channels = out_channels // len(kernel_sizes)

        convs = []
        # L_in is seq_len, L_out is output_dim of conv
        # L_out = (L_in + 2 * padding - kernel_size + 1)
        # and padding=(kernel_size - 1) // 2
        # L_out = (L_in + kernel_size - 1 - kernel_size + 1) = L_in
        for kernel_size in kernel_sizes:
            conv = nn.Conv1d(
                in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2
            )
            convs.append(nn.Sequential(weight_norm(conv), GeLU()))
        # output shape of each conv is (batch_size, out_channels(new), seq_len)

        self.model = nn.ModuleList(convs)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        for seq in self.model:
            conv = seq[0]
            nn.init.normal_(
                conv.weight,
                std=math.sqrt(2.0 / (conv.in_channels * conv.kernel_size[0])),
            )
            nn.init.zeros_(conv.bias)

    def forward(self, x: Tensor) -> Tensor:
        """

        Args:
            x (Tensor): shape (batch_size, embedding_dim, seq_len)

        Returns:
            Tensor:
        "
""
        # back to (batch_size, out_channels, seq_len)
        return torch.cat([encoder(x) for encoder in self.model], dim=1)

Encoder实现

class Encoder(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        kernel_sizes: list[int],
        encoder_layers: int,
        dropout: float,
    ) -> None:
        """_summary_

        Args:
            input_size (int): embedding_dim or embedding_dim + hidden_size
            hidden_size (int): hidden size
            kernel_sizes (list[int]): the size of kernels
            encoder_layers (int): number of conv layers
            dropout (float): dropout ratio
        "
""
        super().__init__()

        self.encoders = nn.ModuleList(
            [
                Conv1d(
                    in_channels=input_size if i == 0 else hidden_size,
                    out_channels=hidden_size,
                    kernel_sizes=kernel_sizes,
                )
                for i in range(encoder_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """forward in encoder

        Args:
            x (Tensor): (batch_size, seq_len, input_size)
            mask (Tensor): (batch_size, seq_len, 1)

        Returns:
            Tensor: _description_
        "
""
        # x (batch_size, input_size, seq_len)
        x = x.transpose(1, 2)
        # mask (batch_size, 1, seq_len)
        mask = mask.transpose(1, 2)

        for i, encoder in enumerate(self.encoders):
            # fills elements of x with 0.0 where mask is False
            x.masked_fill_(~mask, 0.0)
            # using dropout
            if i > 0:
                x = self.dropout(x)
            # returned x (batch_size, hidden_size, seq_len)
            x = encoder(x)

        # apply dropout
        x = self.dropout(x)
        # (batch_size, seq_len, hidden_size)
        return x.transpose(1, 2)

这里用多层Conv1d作为编码器,要注意第0层和其他层的区别,第0层的嵌入维度是input_size即``embedding_size,经过第0层的Conv1d后维度变成两hidden_size,所以后续层参数in_channels为hidden_size`。

这里用x.masked_fill_(~mask, 0.0)设置mask矩阵中的填充位为0。

不采用RNN作为编码器,作者认为RNN速度慢且没有带来性能上的提升。

Alignment

然后实现对齐层,所谓的对齐就是让两个序列进行交互,这里采用基于注意力交互的方式。

class Alignment(nn.Module):
    def __init__(
        self, input_size: int, hidden_size: int, dropout: float, project_func: str
    ) -> None:
        """

        Args:
            input_size (int): embedding_dim  + hidden_size  or embedding_dim  + hidden_size * 2
            hidden_size (int): hidden size
            dropout (float): dropout ratio
            project_func (str): identity or linear
        "
""
        super().__init__()

        self.temperature = nn.Parameter(torch.tensor(1 / math.sqrt(hidden_size)))

        if project_func != "identity":
            self.projection = nn.Sequential(
                nn.Dropout(dropout), Linear(input_size, hidden_size)
            )
        else:
            self.projection = nn.Identity()

    def forward(self, a: Tensor, b: Tensor, mask_a: Tensor, mask_b: Tensor) -> Tensor:
        """

        Args:
            a (Tensor): (batch_size, seq_len, input_size)
            b (Tensor): (batch_size, seq_len, input_size)
            mask_a (Tensor):  (batch_size, seq_len, 1)
            mask_b (Tensor):  (batch_size, seq_len, 1)

        Returns:
            Tensor: _description_
        "
""
        # if projection == 'linear' : self.projection(*) -> (batch_size, seq_len,  hidden_size) -> transpose(*) -> (batch_size, hidden_size,  seq_len)
        # if projection == 'identity' : self.projection(*) -> (batch_size, seq_len, input_size) -> transpose(*) -> (batch_size, input_size,  seq_len)
        # attn (batch_size, seq_len_a,  seq_len_b)
        attn = (
            torch.matmul(self.projection(a), self.projection(b).transpose(1, 2))
            * self.temperature
        )
        # mask (batch_size, seq_len_a, seq_len_b)
        mask = torch.matmul(mask_a.float(), mask_b.transpose(1, 2).float())
        mask = mask.bool()
        # fills elements of x with 0.0(after exp) where mask is False
        attn.masked_fill_(~mask, -1e7)
        # attn_a (batch_size, seq_len_a,  seq_len_b)
        attn_a = F.softmax(attn, dim=1)
        # attn_b (batch_size, seq_len_a,  seq_len_b)
        attn_b = F.softmax(attn, dim=2)
        # feature_b  (batch_size, seq_len_b,  seq_len_a) x (batch_size, seq_len_a, input_size)
        # -> (batch_size, seq_len_b,  input_size)
        feature_b = torch.matmul(attn_a.transpose(1, 2), a)
        # feature_a  (batch_size, seq_len_a,  seq_len_b) x (batch_size, seq_len_b, input_size)
        # -> (batch_size, seq_len_a,  input_size)
        feature_a = torch.matmul(attn_b, b)

        return  feature_a, feature_b

增强残差连接


class AugmentedResidualConnection(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, res: Tensor, i: int) -> Tensor:
        """

        Args:
            x (Tensor): the output of pre block (batch_size, seq_len, hidden_size)
            res (Tensor): (batch_size, seq_len, embedding_size) or (batch_size, seq_len, embedding_size + hidden_size)
                res[:,:,hidden_size:] is the output of Embedding layer
                res[:,:,:hidden_size] is the output of previous two block
            i (int): layer index

        Returns:
            Tensor: (batch_size, seq_len,  hidden_size  + embedding_size)
        "
""
        if i == 1:
            # (batch_size, seq_len,  hidden_size  + embedding_size)
            return torch.cat([x, res], dim=-1)
        hidden_size = x.size(-1)
        # (res[:, :, :hidden_size] + x) is the summation of the output of previous two blocks
        # x (batch_size, seq_len, hidden_size)
        x = (res[:, :, :hidden_size] + x) * math.sqrt(0.5)
        # (batch_size, seq_len,  hidden_size  + embedding_size)
        return torch.cat([x, res[:, :, hidden_size:]], dim=-1)

融合层

class Fusion(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, dropout: float) -> None:
        """

        Args:
            input_size (int): embedding_dim  + hidden_size  or embedding_dim  + hidden_size * 2
            hidden_size (int): hidden size
            dropout (float): dropout ratio
        "
""
        super().__init__()

        self.dropout = nn.Dropout(dropout)
        self.fusion1 = Linear(input_size * 2, hidden_size, activations=True)
        self.fusion2 = Linear(input_size * 2, hidden_size, activations=True)
        self.fusion3 = Linear(input_size * 2, hidden_size, activations=True)
        self.fusion = Linear(hidden_size * 3, hidden_size, activations=True)

    def forward(self, x: Tensor, align: Tensor) -> Tensor:
        """

        Args:
            x (Tensor): input (batch_size, seq_len, input_size)
            align (Tensor): output of Alignment (batch_size, seq_len,  input_size)

        Returns:
            Tensor: (batch_size, seq_len, hidden_size)
        "
""
        # x1 (batch_size, seq_len, hidden_size)
        x1 = self.fusion1(torch.cat([x, align], dim=-1))
        # x2 (batch_size, seq_len, hidden_size)
        x2 = self.fusion1(torch.cat([x, x - align], dim=-1))
        # x3 (batch_size, seq_len, hidden_size)
        x3 = self.fusion1(torch.cat([x, x * align], dim=-1))
        # x (batch_size, seq_len, hidden_size * 3)
        x = torch.cat([x1, x2, x3], dim=-1)
        x = self.dropout(x)
        # (batch_size, seq_len, hidden_size)
        return self.fusion(x)

池化层

class Pooling(nn.Module):
    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """

        Args:
            x (Tensor): (batch_size, seq_len, hidden_size)
            mask (Tensor): (batch_size, seq_len, 1)

        Returns:
            Tensor: (batch_size, hidden_size)
        "
""
        # max returns a namedtuple (values, indices), we only need values
        return x.masked_fill(~mask, -float("inf")).max(dim=1)[0]

池化层取时间步维度上的最大值。

预测层

class Prediction(nn.Module):
    def __init__(self, hidden_size: int, num_classes: int, dropout: float) -> None:
        super().__init__()
        self.dense = nn.Sequential(
            nn.Dropout(dropout),
            Linear(hidden_size * 4, hidden_size, activations=True),
            nn.Dropout(dropout),
            Linear(hidden_size, num_classes),
        )

    def forward(self, a: Tensor, b: Tensor) -> Tensor:
        """

        Args:
            a (Tensor): (batch_size, hidden_size)
            b (Tensor): (batch_size, hidden_size)

        Returns:
            Tensor: (batch_size, num_classes)
        "
""
        return self.dense(torch.cat([a, b, a - b, a * b], dim=-1))

预测层比较简单,再次对输入向量进行了一个融合:

RE2实现

RE2的实现时上述层的堆叠:

class RE2(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()

        self.embedding = Embedding(args.vocab_size, args.embedding_dim, args.dropout)

        self.connection = AugmentedResidualConnection()

        self.blocks = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "encoder": Encoder(
                            args.embedding_dim
                            if i == 0
                            else args.embedding_dim + args.hidden_size,
                            args.hidden_size,
                            args.kernel_sizes,
                            args.encoder_layers,
                            args.dropout,
                        ),
                        "alignment": Alignment(
                            args.embedding_dim + args.hidden_size
                            if i == 0
                            else args.embedding_dim + args.hidden_size * 2,
                            args.hidden_size,
                            args.dropout,
                            args.project_func,
                        ),
                        "fusion": Fusion(
                            args.embedding_dim + args.hidden_size
                            if i == 0
                            else args.embedding_dim + args.hidden_size * 2,
                            args.hidden_size,
                            args.dropout,
                        ),
                    }
                )
                for i in range(args.num_blocks)
            ]
        )

        self.pooling = Pooling()
        self.prediction = Prediction(args.hidden_size, args.num_classes, args.dropout)

    def forward(self, a: Tensor, b: Tensor, mask_a: Tensor, mask_b: Tensor) -> Tensor:
        """
        Args:
            a (Tensor): (batch_size, seq_len)
            b (Tensor): (batch_size, seq_len)
            mask_a (Tensor): (batch_size, seq_len, 1)
            mask_b (Tensor): (batch_size, seq_len, 1)

        Returns:
            Tensor: (batch_size, num_classes)
        "
""
        # a (batch_size, seq_len, embedding_dim)
        a = self.embedding(a)
        # b (batch_size, seq_len, embedding_dim)
        b = self.embedding(b)

        res_a, res_b = a, b

        for i, block in enumerate(self.blocks):
            if i > 0:
                # a (batch_size, seq_len, embedding_dim + hidden_size)
                a = self.connection(a, res_a, i)
                # b (batch_size, seq_len, embedding_dim + hidden_size)
                b = self.connection(b, res_b, i)
                # now embeddings saved to res_a[:,:,hidden_size:]
                res_a, res_b = a, b
            # a_enc (batch_size, seq_len, hidden_size)
            a_enc = block["encoder"](a, mask_a)
            # b_enc (batch_size, seq_len, hidden_size)
            b_enc = block["encoder"](b, mask_b)
            # concating the input and output of encoder
            # a (batch_size, seq_len, embedding_dim + hidden_size or embedding_dim + hidden_size * 2)
            a = torch.cat([a, a_enc], dim=-1)
            # b (batch_size, seq_len, embedding_dim + hidden_size or embedding_dim + hidden_size * 2)
            b = torch.cat([b, b_enc], dim=-1)
            # align_a (batch_size, seq_len,  embedding_dim + hidden_size or embedding_dim + hidden_size * 2)
            # align_b (batch_size, seq_len,  embedding_dim + hidden_size or embedding_dim + hidden_size * 2)
            align_a, align_b = block["alignment"](a, b, mask_a, mask_b)
            # a (batch_size, seq_len,  hidden_size)
            a = block["fusion"](a, align_a)
            # b (batch_size, seq_len,  hidden_size)
            b = block["fusion"](b, align_b)
        # a (batch_size, hidden_size)
        a = self.pooling(a, mask_a)
        # b (batch_size, hidden_size)
        b = self.pooling(b, mask_b)
        # (batch_size, num_classes)
        return self.prediction(a, b)

注意不同块之间输入维度的区别。

unset unset 数据准备 unset unset

在→文章←中数据准备这部分内容有详细的解释。

from collections import defaultdict
from tqdm import tqdm
import numpy as np
import json
from torch.utils.data import Dataset
import pandas as pd
from typing import Tuple

UNK_TOKEN = ""
PAD_TOKEN = ""


class Vocabulary:
    """Class to process text and extract vocabulary for mapping"""

    def __init__(self, token_to_idx: dict = None, tokens: list[str] = None) -> None:
        """
        Args:
            token_to_idx (dict, optional): a pre-existing map of tokens to indices. Defaults to None.
            tokens (list[str], optional): a list of unique tokens with no duplicates. Defaults to None.
        "
""

        assert any(
            [tokens, token_to_idx]
        ), "At least one of these parameters should be set as not None."
        if token_to_idx:
            self._token_to_idx = token_to_idx
        else:
            self._token_to_idx = {}
            if PAD_TOKEN not in tokens:
                tokens = [PAD_TOKEN] + tokens

            for idx, token in enumerate(tokens):
                self._token_to_idx[token] = idx

        self._idx_to_token = {idx: token for token, idx in self._token_to_idx.items()}

        self.unk_index = self._token_to_idx[UNK_TOKEN]
        self.pad_index = self._token_to_idx[PAD_TOKEN]

    @classmethod
    def build(
        cls,
        sentences: list[list[str]],
        min_freq: int = 2,
        reserved_tokens: list[str] = None,
    ) -> "Vocabulary":
        """Construct the Vocabulary from sentences

        Args:
            sentences (list[list[str]]): a list of tokenized sequences
            min_freq (int, optional): the minimum word frequency to be saved. Defaults to 2.
            reserved_tokens (list[str], optional): the reserved tokens to add into the Vocabulary. Defaults to None.

        Returns:
            Vocabulary: a Vocubulary instane
        "
""

        token_freqs = defaultdict(int)
        for sentence in tqdm(sentences):
            for token in sentence:
                token_freqs[token] += 1

        unique_tokens = (reserved_tokens if reserved_tokens else []) + [UNK_TOKEN]
        unique_tokens += [
            token
            for token, freq in token_freqs.items()
            if freq >= min_freq and token != UNK_TOKEN
        ]
        return cls(tokens=unique_tokens)

    def __len__(self) -> int:
        return len(self._idx_to_token)

    def __getitem__(self, tokens: list[str] | str) -> list[int] | int:
        """Retrieve the indices associated with the tokens or the index with the single token

        Args:
            tokens (list[str] | str): a list of tokens or single token

        Returns:
            list[int] | int: the indices or the single index
        "
""
        if






请到「今天看啥」查看全文