专栏名称: GiantPandaCV
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
最高人民法院  ·  Yo!听说代表法要修改啦? ·  4 天前  
51好读  ›  专栏  ›  GiantPandaCV

梳理RWKV 4,5(Eagle),6(Finch)架构的区别以及个人理解和建议

GiantPandaCV  · 公众号  ·  · 2024-04-25 00:00

正文

0x0. 前言

RWKV系列模型的迭代速度比较快,主要是下面两篇paper:

  • RWKV: Reinventing RNNs for the Transformer Era:https://arxiv.org/abs/2305.13048
  • Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence:https://arxiv.org/abs/2404.05892

之前我解析过RWKV-4的结构和代码实现(https://zhuanlan.zhihu.com/p/653327189),这里再把它和RWKV5,RWKV6放在一起进行对比解析一下。

回顾一下,RWKV 4论文中对RWKV名字含义有说明:

  • R: Receptance vector acting as the acceptance of past information. 类似于LSTM的“门控单元”
  • W: Weight is the positional weight decay vector. A trainable model parameter. 可学习的位置权重衰减向量,什么叫“位置权重衰减”看下面的公式(14)
  • K: Key is a vector analogous to K in traditional attention. 与传统自注意力机制
  • V : Value is a vector analogous to V in traditional attention. 与传统自注意力机制相同

如果不想看下面的细节,可以直接跳到结论那一节,我个人有一些尖锐的评价。

0x1 RWKV 模型架构回顾

RWKV模型由一系列RWKV Block模块堆叠而成,RWKV Block的结构如下图所示:

在这里插入图片描述

RWKV Block又主要由Time Mixing和Channel Mixing组成。

Time Mixing模块的公式定义如下:

在这里插入图片描述

这里的 表示当前时刻, 看成当前的token,而 看成前一个token, 的计算与传统Attention机制类似,通过将当前输入token与前一时刻输入token做线性插值,体现了recurrence的特性。然后 的计算则是对应注意力机制的实现,这个实现也是一个过去时刻信息与当前时刻信息的线性插值,注意到这里是指数形式并且当前token和之前的所有token都有一个指数衰减求和的关系,也正是因为这样让 拥有了线性attention的特性。

然后RWKV模型里面除了使用Time Mixing建模这种Token间的关系之外,在Token内对应的隐藏层维度上RWKV也进行了建模,即通过Channel Mixing模块。

在这里插入图片描述

Channel Mixing的意思就是在特征维度上做融合。假设特征向量维度是d,那么每一个维度的元素都要接收其他维度的信息,来更新它自己。特征向量的每个维度就是一个“channel”(通道)。

下图展示了RWKV模型整体的结构:

在这里插入图片描述

这里提到的token shift就是上面对r, k, v计算的时候类似于卷积滑窗的过程。然后我们可以看到当前的token不仅仅可以通过Time Mixing的Token Shift和隐藏状态States(即 )和之前的token建立联系,也可以通过Channel Mixing的Token Shift和之前的token建立联系,类似于拥有了 全局感受野

这里的讲解是以RWKV 4为例的,无论是RWKV的哪个版本,基本架构都是类似的,区别就在于对Time Mixing,Token shift以及Channel Mixing操作的修改。接下来的几节,就重点关注一下这个改动即可把握RWKV系列模型的进展。

0x2. RWKV 4的具体实现

主要关注Time Mixing,Channel Mixing,Token Shift的实现,代码实现见。https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py#L57-L93

0x2.1 RWKV 4 Channel Mixing

@torch.jit.script_method
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
        state[5*i+0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

参考RWKV 4 paper的Channel Mixing的公式来看:

在这里插入图片描述

在channel_mixing函数里面, 对应当前token的词嵌入向量, 表示前一个token的词嵌入向量。剩下的变量都是RWKV的可学习参数。然后代码里面会动态更新state,让 总是当前token的前一个token的词嵌入。

0x2.2 RWKV4 Time mixing函数

@torch.jit.script_method
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k) # 对应下面的公式12的后半部分
        xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v) # 对应下面的公式13的后半部分
        xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r) # 对应下图中的公式11的后半部分
        state[5*i+1] = x
        r = torch.sigmoid(rw @ xr) # 对应下面公式11的前半部分和公式15里的sigmoid
        k = kw @ xk # 对应下面的公式12的前半部分
        v = vw @ xv # 对应下面的公式13的前半部分
        
        aa = state[5*i+2]
        bb = state[5*i+3]
        pp = state[5*i+4]
        ww = time_first + k # 对应下面的RWKV可以写成递归形式图中的{u+k_t}
        qq = torch.maximum(pp, ww) # 对应e^{u+k_t}的数值稳定性维护,维护最大值
        e1 = torch.exp(pp - qq) 
        e2 = torch.exp(ww - qq) # e1和e2分别对应分子分母的a_{t}和b_{t}的稳定性维护
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b # 对应wkv_t的计算
        ww = pp + time_decay 
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq) # 对应下面的RWKV可以写成递归形式图中的a_t计算的前半部分
        e2 = torch.exp(k - qq) # 对应下面的RWKV可以写成递归形式图中的a_t计算的后半部分
        state[5*i+2] = e1 * aa + e2 * v
        state[5*i+3] = e1 * bb + e2
        state[5*i+4] = qq
        return ow @ (r * wkv)

仍然是要对照公式来看:

在这里插入图片描述

然后这里有一个trick,就是对 的计算可以写成RNN的递归形式:

这样上面的公式就很清晰了,还需要注意的是在实现的时候由于有exp的存在,为了保证数值稳定性实现的时候减去了每个公式涉及到的e的指数部分的Max。

关于RWKV 的attention部分( )计算如果你有细节不清楚,建议观看一下这个视频:解密RWKV线性注意力的进化过程(https://www.bilibili.com/video/BV1zW4y1D7Qg/?spm_id_from=333.337.search-card.all.click&vd_source=4dffb0fbabed4311f4318e8c6d253a10) 。

仔细理解上面的代码之后对照RWKV 5/6的公式理解后面的代码实现就不是很难了。

0x3. RWKV Eagle (RWKV 5)的具体实现

代码见:https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_v5_demo.py#L159-L195

0x3.1 RWKV 5 Channel Mixing

这个是RWKV 5的Channel Mixing的代码实现,可以对比一下RWKV 4的实现。

@MyFunction
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        i0 = (2+self.head_size)*i+0
        xk = x * time_mix_k + state[i0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[i0] * (1 - time_mix_r)
        state[i0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

RWKV 4的Channel Mixing的代码实现为:

@torch.jit.script_method
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
        state[5*i+0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

这里的 表示的是RWKV有多少层,在RWKV4的每一层中Channel Mixing记录一个状态,而每一个Time Mixing则记录4个状态,所以一共是5个状态。而RWKV 5中每一层现在记录了 2+self.head_size 个状态,Channel Mixing记录的状态以及计算过程和RWKV 4是完全一样的。

0x3.2 RWKV 5 Time Mixing

@MyFunction
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
        H = self.n_head
        S = self.head_size

        i1 = (2+S)*i+1
        xk = x * time_mix_k + state[i1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[i1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[i1] * (1 - time_mix_r)
        xg = x * time_mix_g + state[i1] * (1 - time_mix_g)
        state[i1] = x

        r = (rw @ xr).view(H, 1, S)
        k = (kw @ xk).view(H, S, 1)
        v = (vw @ xv).view(H, 1, S)
        g = F.silu(gw @ xg)

        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)

        x = torch.zeros(H, S)
        a = k @ v
        x = r @ (time_first * a + s)
        s = a + time_decay * s
    
        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
        x = x.flatten()

        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
        return ow @ x

RWKV 5 Time Mixing的改动主要就在这个Time Mixing模块了,对应paper里面下面这一页:

在这里插入图片描述

这里的最大的改进应该是现在的计算是分成了 H = self.n_head 个头,然后每个头的计算结果都被存到了state里。相比于RWKV-4,这种改进可以类比于Transformer的单头自注意力机制改到多头注意力机制。

0x4. RWKV Finch (RWKV 6)的具体实现

代码见:https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_v6_demo.py#L157-L199

首先RWKV 6相比于RWKV 5在Token Shift上进行了改进,具体看下面的中间底部和右下角的图,分别是RWKV 4/5的Token Shift方式和RWKV 6的Token Shift方式。

Paper里面对RWKV 6的Token Shit也有详细描述:

翻译一下:在Finch Token Shift中使用的 之间依赖数据的线性插值(ddlerp)定义如下:

------------------------------------------ (14)

-------------------------------------- (15)

其中, 和每个 引入了一个维度为D的可训练向量,每个 引入了新的可训练权重矩阵。对于特殊情况 ,我们引入了双倍大小的可训练权重矩阵







请到「今天看啥」查看全文