专栏名称: 机器学习算法与自然语言处理
一个有情怀的公众号。机器学习、自然语言处理、算法等知识集中营、期待与你相遇~
51好读  ›  专栏  ›  机器学习算法与自然语言处理

Transformer变体层出不穷,它们都长什么样?

机器学习算法与自然语言处理  · 公众号  ·  · 2021-02-27 00:00

正文

公众号关注 “ ML_NLP
设为 “ 星标 ”,重磅干货,第一时间送达!

转载自|PaperWeekly

©PaperWeekly 原创 · 作者|上杉翔二

单位|悠闲会

研究方向|信息检索


不知不觉 Transformer 已经逐步渗透到了各个领域,就其本身也产生了相当多的变体,如上图。本篇文章想大致按照这个图,选一些比较精彩的变体整理,话不多说直接开始。




Transformer-XL
论文标题:
Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context


收录会议:
ACL 2019


论文链接:
https://arxiv.org/abs/1901.02860


代码链接:
https://github.com/kimiyoung/transformer-xl


上图上标的是“Recurrence”,首先看看这篇文章聚焦的 2 个问题:

  • 虽然 Transformer 可以学习到输入文本的长距离依赖关系和全局特性,但是!需要事先设定输入长度,这导致了其对于长程关系的捕捉有了一定限制。

  • 出于效率的考虑,需要对输入的整个文档进行分割(固定的),那么每个序列的计算相互独立,所以只能够学习到同个序列内的语义联系,整体上看,这将会导致文档语意上下文的碎片化(context fragmentation)。

那么如何学习更长语义联系?


segment-level Recurrence
segment-level 循环机制。如上图左边为原始 Transformer,右边为 Transformer-XL,Transformer-XL 模型的计算当中加入绿色连线,使得当层的输入取决于本序列和上一个序列前一层的输出。这样每个序列计算后的隐状态会参与到下一个序列的计算当中,使得模型能够学习到跨序列的语义联系(看动图可能更好理解)。



是第 个 segment 的第 n 层隐向量,那么第 r+1 个的第 n 层的隐向量的计算,就是上面这套公式。
  • 其中 SG 是是 stop-gradient,不再对 的隐向量做反向传播(这样虽然在计算中运用了前一个序列的计算结果,但是在反向传播中并不对其进行梯度的更新,毕竟前一个梯度肯定不受影响)。
  • 是对两个隐向量序列沿长度 L 方向的拼接 。3 个 W 分别对应 query,key 和 value 的转化矩阵,需要注意的是!k 和 v 的 W 用的是 ,而 q 是用的 ,即 kv 是用的拼接之后的 h,而 q 用的是原始序列的信息。感觉可以理解为以原始序列查拼接序列,这样可以得到一些前一个序列的部分信息以实现跨语义。
  • 最后的公式是标准的 Transformer。
还有一点设计是,在 评估预测模型 的时候它是会连续计算前 L 个长度的隐向量的( 训练 的时候只有前一个,缓存在内存中)。
即每一个位置的隐向量,除了自己的位置,都跟下一层中前(L-1)个位置的 token 存在依赖关系,而且每往下走一层,依赖关系长度会增加(L-1),这样能使跨语义更加的深入。



只看看 XL 多头注意力的 forward 的不同地方吧。

def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
             #w是上一层的输出,r是相对位置嵌入(在下一节),r_w_bias是u,r_r_bias是v向量
            qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

            if mems is not None#mems就是前一些序列的向量,不为空
                cat = torch.cat([mems, w], 0#就拼起来
                if self.pre_lnorm: #如果有正则化
                    w_heads = self.qkv_net(self.layer_norm(cat)) #这个net是nn.Linear,即qkv的变换矩阵W参数
                else:
                    w_heads = self.qkv_net(cat)#没有正则就直接投影一下
                r_head_k = self.r_net(r)#也是nn.Linear

                w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1#复制3份
                w_head_q = w_head_q[-qlen:] #q的W不要拼接的mems
            else:#没有mems,就正常的计算
                if self.pre_lnorm:
                    w_heads = self.qkv_net(self.layer_norm(w))
                else:
                    w_heads = self.qkv_net(w)
                r_head_k = self.r_net(r)

                w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3 , dim=-1)

            klen = w_head_k.size(0)
            #qlen是序列长度,bsz是batch size,n_head是注意力头数,d_head是每个头的隐层维度
            w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
            w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
            w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head

            r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # qlen x n_head x d_head

            ####计算注意力的四个部分
            #AC是指相对位置的公式里的a和c两个部分,相对位置在下一节做笔记
            rw_head_q = w_head_q + r_w_bias                                         # qlen x bsz x n_head x d_head

            #爱因斯坦简记法求和sum,统一的方式表示各种各样的张量运算
            AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head

            #BD是指相对位置的公式里的b和d两个部分
            rr_head_q = w_head_q + r_r_bias

            BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))              # qlen x klen x bsz x n_head
            BD = self._rel_shift(BD)

            # [qlen x klen x bsz x n_head]
            attn_score = AC + BD #最后的结果
            attn_score.mul_(self.scale)#进行放缩


Relative Position Encodings

相对位置编码。原始 Transformer 采用了正弦/余弦函数来编码绝对位置信息。然而因为 Transformer-XL 会有多个句子,所以还是绝对位置,那么两个句子的相同位置是同样的编码。
比如 [0, 1, 2, 3] 在两个句子 concat 之后就变成了 [0, 1, 2, 3, 0, 1, 2, 3],句子不连续,而且每次拼的句子会不一样,也不能找到适合的绝对位置编码。所以这里使用相对位置编码。

上图是原始 Transformer 和 Transformer-XL 的比较,其中 E 表示词的 Embedding,而 U 表示绝对位置编码。这大一堆看起来奇奇怪怪,实际上 Transformer 的注意力计算是 的分解,即先编码 Q(当前词 i)和 K(其他的词 j)然后算内积,位置编码是直接 add 在词嵌入上面的。

而 Transformer-XL 的改变是:

  • 把 j 的绝对位置 U 换成了相对位置 R,该相对位置表示也是一个正弦函数表示(i 和 j 的相对位置向量,j 是之前的序列,所以相减一定是正数)。R 不是通过学习得到的,好处是预测时,可以使用比训练距离更长的位置向量。
  • 使用两个可学习参数 u 和 v 替代了中的 query i 的位置映射。这里是由于每次计算 query 向量是固定的,不需要编码。
  • 每一层的 Attention 计算都要相对位置编码。Transformer 里面只有 input 的时候会加,而 XL 需要每层。

细细思考,这 attention 的四个部分各有玄机:
  • a. 基于内容的“寻址”,即没有添加原始位置编码的原始向量,
  • b. 基于内容的位置偏置,即相对于当前内容的位置偏差,
  • c. 全局的内容偏置,用于衡量 key 的重要性,query 固定查
  • d. 全局的位置偏置,根据 query 和 key 之间的距离调整重要性,query 固定查


相对位置编码的代码为:


class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb #编码维度
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) #间隔频率

    def forward(self, pos_seq):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq) #序列的位置向量 operation 间隔
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1#正弦余弦
        return pos_emb[:,None,:] #直接返回R,非学习矩阵R


简单把编码维度设置为 10,查询向量也是 10 个,存储之前的序列也是 10,有以下结果:


>>> import torch
>>> inv_freq = 1 / (10000 ** (torch.arange(0.0102.0) / 10))
>>> inv_freq
tensor([1.0000e+001.5849e-012.5119e-023.9811e-036.3096e-04])
>>> pos_seq=torch.arange(20-1-1-1.0#qlen+mlen,即10+10的维度然后逆序
>>> pos_seq
tensor([19.18.17.16. 15.14.13.12.11.10.,  9.,  8.,  7.,  6.,
         5.,  4.,  3.,  2.,  1.,  0.])
>>> sinusoid_inp = torch.ger(pos_seq,inv_freq)
>>> sinusoid_inp
tensor([[1.9000e+013.0113e+004.7726e-017.5640e-021.1988e-02],
        [1.8000e+012.8528e+004.5214e-017.1659e-021.1357e-02],
        [1.7000e+012.6943e+004.2702e-016.7678e-021.0726e-02],
        [1.6000e+012.5358e+004.0190e-016.3697e-021.0095e-02],
        [1.5000e+012.3773e+003.7678e-015.9716e-029.4644e-03],
        [1.4000e+012.2189e+003.5166e-015.5735e-028.8334e-03],
        [1.3000e+012.0604e+003.2655e-015.1754e-028.2024e-03],
        [1.2000e+011.9019e+003.0143e-014.7773e-027.5715e-03],
        [1.1000e+011.7434e+002.7631e-014.3792e-026.9405e-03],
        [1.0000e+011.5849e+002.5119e-013.9811e-026.3096e-03],
        [9.0000e+001.4264e+002.2607e-013.5830e-025.6786e-03],
        [8.0000e+001.2679e+002.0095e-013.1849e-025.0477e-03],
        [7.0000e+001.1094e+001.7583e-012.7867e-024.4167e-03],
        [6.0000e+009.5094e-011.5071e-012.3886e-023.7857e-03],
        [5.0000e+007.9245e-011.2559e-011.9905e-023.1548e-03],
        [4.0000e+006.3396e-011.0048e-011.5924e-022.5238e-03],
        [3.0000e+004.7547e-017.5357e-021.1943e-021.8929e-03],
        [2.0000e+003.1698e-015.0238e-027.9621e-031.2619e-03],
        [1.0000e+001.5849e-012.5119e-023.9811e-036.3096e-04],
        [0.0000e+000.0000e+000.0000e+000.0000e+000.0000e+00]])
>>> sinusoid_inp.sin()
tensor([[ 1.4988e-01,  1.2993e-01,  4.5935e-01,  7.5568e-02,  1.1988e-02],
        [-7.5099e-01,  2.8479e-01,  4.3689e-01,  7.1598e-02,  1.1357e-02],
        [-9.6140e-01,  4.3251e-01,  4.1416e-01,  6.7627e-02,  1.0726e-02],
        [-2.8790e-01,  5.6939e-01,  3.9117e-01,  6.3654e-02,  1.0095e-02],
        [ 6.5029e-01,  6.9200e-01,  3.6793e-01,  5.9681e-02,  9.4642e-03],
        [ 9.9061e-01,  7.9726e-01,  3.4446e-01,  5.5706e-02,  8.8333e-03],
        [ 4.2017e-01,  8.8254e-01,  3.2077e-01,  5.1731e-02,  8.2024e-03],
        [-5.3657e-01,  9.4569e-01,  2.9688e-01,  4.7755e-02,  7.5714e-03],
        [-9.9999e-01,  9.8514e-01,  2.7281e-01,  4.3778e-02,  6.9405e-03],
        [-5.4402e-01,  9.9990e-01,  2.4856e-01,  3.9800e-02,  6.3095e-03],
        [ 4.1212e-01,  9.8959e-01,  2.2415e-01,  3.5822e-02,  5.6786e-03],
        [ 9.8936e-01,  9.5448e-01,  1.9960e-01,  3.1843e-02,  5.0476e-03],
        [ 6.5699e-01,  8.9544e-01,  1.7493e-01,  2.7864e-02,  4.4167e-03],
        [-2.7942e-01,  8.1396e-01,  1.5014e-01,  2.3884e-02,  3.7857e-03],
        [-9.5892e-01,  7.1207e-01,  1.2526e-01,  1.9904e-02,  3.1548e-03],
        [-7.5680e-01,  5.9234e-01,  1.0031e-01,  1.5924e-02,  2.5238e-03],
        [ 1.4112e-01,  4.5775e-01






请到「今天看啥」查看全文