专栏名称: 计算机视觉深度学习和自动驾驶
讨论计算机视觉、深度学习和自动驾驶的技术发展和挑战
目录
相关文章推荐
Python爱好者社区  ·  DeepSeek 被放弃了,阿里牛逼! ·  昨天  
Python爱好者社区  ·  付费上班终于成为了现实。 ·  昨天  
Python爱好者社区  ·  刚刚,DeepSeek放出重磅论文!梁文锋亲 ... ·  3 天前  
Python爱好者社区  ·  吴恩达,yyds ·  3 天前  
Python开发者  ·  马斯克 20 万 GPU ... ·  3 天前  
51好读  ›  专栏  ›  计算机视觉深度学习和自动驾驶

动手实现混合专家网络MoE

计算机视觉深度学习和自动驾驶  · 公众号  ·  · 2024-06-21 00:10

正文

前言

从大模型爆火至今,我们所熟知的大多数模型基本上都遵循了GPT的模型结构,即Decoder-only的结构,更准确的说是Casual Decoder,例如LLaMA[1],OPT[2]以及BLOOM[3]等。这类模型结构由于单向注意力机制的限制,因此训练出来的模型只能基于前文以next token prediction的方式预测后续文本从而生成文本序列。由于类似于Encoder-only的结构(例如 BERT[4]) 以及Encoder-Decoder的结构(例如 T5[5] )通过双向注意力机制可以更好的理解文本,于是有些研究人员将双向注意力机制融入到 Decoder-only的结构中,便产生了Prefix Decoder的结构,相关的优秀工作包括GLM[6],U-PaLM[7]等。然而,近段时间,另一种模型结构逐渐展示出了优越性能,就是MoE( Mixture of Experts ),例如Mistral [8] 等。其实,MoE这个概念很早之前就已经被提出,只不过最近才逐渐又火起来。准确的说, MoE是一种结构,它可以被用在 Decoder-only结构中,也可以用在 Encoder-only和Encoder-Decoder结构中。话不多说,下面将以代码的形式详细讲解 MoE结构,内容稍微有点多,希望大家耐心阅读。


MOE

MoE,全称为Mixture of Experts,翻译过来就是混合专家模型。MoE的一个显著优势是能够在远少于Dense模型所需的计算资源下进行有效的预训练。这意味着在相同的计算预算条件下,可以显著扩大模型或数据集的规模。特别是在预训练阶段,与Dense模型相比,混合专家模型通常能够更快地达到相同的质量水平。在最近使用 MoE结构的 混合专家语言模型中,大部分组件都与传统的Transformer相同。

MoE基于Transformer结构,主要由两部分组成:

  • MoE层:这些层代替了传统Transformer模型中的前馈网络 (FFN) 层。MoE层包含若干“专家”,每个专家本身是一个独立的神经网络。在实际应用中,这些专家通常是前馈网络 (FFN),但它们也可以是更复杂的网络结构。

  • 门控网络或路由: 这个部分用于决定哪些token被发送到哪个专家。例如,在下图中,“More”这个token可能被发送到第二个专家,而“Parameters”这个token被发送到第一个专家。有时,一个token甚至可以被发送到多个专家。token的路由方式是MoE使用中的一个关键点,因为路由器由学习的参数组成,并且与网络的其他部分一同进行预训练。

总结来说,在混合专家模型 (MoE) 中,将传统Transformer模型中的每个前馈网络 (FFN) 层替换为MoE层,其中MoE层由两个核心部分组成:一个路由器(或者叫门控网络)和若干数量的专家。

MoE的优点如下:

  • 训练速度更快,效果更好;

  • 相同参数,推理成本低;

  • 扩展性好,允许模型在保持计算成本不变的情况下增加参数数量,这使得它能够扩展到非常大的模型规模,如万亿参数模型;

  • 多任务学习能力,MoE在多任务学习中具备很好的新能。


MoE的缺点如下:

  • 训练稳定性,MoE在训练过程中可能会遇到稳定性问题;

  • 通信成本,在分布式训练环境中,MoE的专家路由机制可能会增加通信成本,尤其是在模型规模较大时;

  • 模型复杂性,MoE的设计相对复杂,可能需要更多的工程努力来实现和优化;

  • 下游任务性能,MoE由于其稀疏性,使得在Fine-Tuning过程中容易出现过拟合。


下面主要介绍几个模块的实现:

  • Self-Attention以及Multi-Head Self-Attention模块的实现;

  • 稀疏混合专家代替单独的前馈神经网络;

  • Top-k门控和有噪声的Top-k门控。


具体实现代码参考了Github开源项目 makeMoE ,代码链接在 https://github.com/AviSoori1x/makeMoE,本文在此基础上加上了文字解读,方便读者理解。

1. Self-Attention

常规的Self-Attention实现方式使用的缩放点积自注意力,查询矩阵、键矩阵和值矩阵都来自相同的输入序列,同时为了确保自回归语言生成过程的完整性,特别是在纯解码器模型中,使用了一种因果自注意力,也叫因果掩码。它可以掩盖当前token所处位置之后的任何信息,从而引导模型只关注序列的前面部分。值得注意的是,稀疏混合专家模型并不局限于仅有解码器的Transformer结构。事实上,这一领域的许多重要的成果都是围绕 T5结构展开的,T5 结构 也包含了Transformer模型中的编码器和解码器组件。

# 首先实现 self-attention 模块# 创建一个 [batch_size, seq_len, hidden_dim] 的张量,命名为 ximport torchimport torch.nn as nnimport torch.nn.functional as F
torch.manual_seed(42)batch_size, seq_len, n_embed = 4, 8, 32x = torch.randn(batch_size, seq_len, n_embed)
# 接下来实现一个单头的 self-attention 模块head_size = 16 # 每个头的维度key = nn.Linear(n_embed, head_size, bias=False)query = nn.Linear(n_embed, head_size, bias=False)value = nn.Linear(n_embed, head_size, bias=False)k = key(x) # (4, 8, 16)q = query(x) # (4, 8, 16)weight = q @ k.transpose(-2, -1) # (4, 8, 16) @ (4, 16, 8) ---> (4, 8, 8)
# 为了保证每个 token 只能关注到其自身以及前面几个 token,需要对 weight 进行 mask,即使用一个 [seq_len, seq_len] 的因果掩码tril = torch.tril(torch.ones(seq_len, seq_len))weight = weight.masked_fill(tril == 0, float('-inf'))weight = F.softmax(weight, dim=-1) # (4, 8, 8)
v = value(x) # (4, 8, 16)out = tril @ v # (4, 8, 8) @ (4, 8, 16) -> (4, 8, 16)# 将因果掩码 tril 打印出来即为下面的矩阵tril
triltensor([[1., 0., 0., 0., 0., 0., 0., 0.],        [1., 1., 0., 0., 0., 0., 0., 0.],        [1., 1., 1., 0., 0., 0., 0., 0.],        [1., 1., 1., 1., 0., 0., 0., 0.],        [1., 1., 1., 1., 1., 0., 0., 0.],        [1., 1., 1., 1., 1., 1., 0., 0.],        [1., 1., 1., 1., 1., 1., 1., 0.],        [1., 1., 1., 1., 1., 1., 1., 1.]])
# 接来下将相关模块整合成一个函数,方便调用
# 单头的自注意力机制class OneHead(nn.Module): """ one head of self-attention """
def __init__(self, n_embed, head_size, seq_len, dropout=0.1): super().__init__() self.key = nn.Linear(n_embed, head_size, bias=False) self.query = nn.Linear(n_embed, head_size, bias=False) self.value = nn.Linear(n_embed, head_size, bias=False) self.register_buffer('tril', torch.tril(torch.ones(seq_len, seq_len)))
self.dropout = nn.Dropout(dropout)
def forward(self, x): batch_size, seq_len, n_embed = x.shape k = self.key(x) # (batch_size, seq_len, head_size) q = self.query(x) # (batch_size, seq_len, head_size) # compute attention scores ("affinities") weight = q @ k.transpose(-2,-1) * n_embed ** -0.5 # (batch_size, seq_len, head_size) @ (batch_size, head_size, seq_len) -> (batch_size, seq_len, seq_len) weight = weight.masked_fill(self.tril[:seq_len, :seq_len] == 0, float('-inf')) # (batch_size, seq_len, seq_len) weight = F.softmax(weight, dim=-1) # (batch_size, seq_len, seq_len) weight = self.dropout(weight) # perform the weighted aggregation of the values v = self.value(x) # (batch_size, seq_len, head_size) out = weight @ v # (batch_size, seq_len, seq_len) @ (batch_size, seq_len, head_size) -> (batch_size, seq_len, head_size) return out
# 基于单头的自注意力机制实现多头的自注意力机制class MultiHeadAttention(nn.Module): """ multiple heads of self-attention in parallel """
def __init__(self, n_embed, num_heads, head_size, dropout): super().__init__() self.heads = nn.ModuleList([OneHead(head_size) for _ in range(num_heads)]) self.proj = nn.Linear(n_embed, n_embed) self.dropout = nn.Dropout(dropout)
def forward(self, x): out = torch.cat([h(x) for h in self.heads], dim=-1) out = self.dropout(self.proj(out)) return out


2. 专家模块—多层感知器

在稀疏混合专家架构中,每个Transformer区块内的自注意力机制保持不变。不过,每个区块的结构发生了巨大的变化:标准的前馈神经网络被多个稀疏激活的前馈网络(即专家网络)所取代。所谓「稀疏激活」,是指序列中的每个token只被分配给有限数量的专家(通常是一个或两个)。

这有助于提高训练和推理速度,因为每次前向传递都会激活少数专家。不过,所有专家都必须存在GPU内存中,因此当参数总数达到数千亿甚至数万亿时,就会产生部署方面的问题。

# 专家模块的实现和FFN模块的实现相同class Expert(nn.Module):    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """
def __init__(self, n_embed, dropout): super().__init__() self.net = nn.Sequential( nn.Linear(n_embed, 4 * n_embed), nn.ReLU(), nn.Linear(4 * n_embed, n_embed), nn.Dropout(dropout), )
def forward(self, x): return self.net(x)


3. Top-k门控

门控网络,也称为路由,确定哪个专家网络接收来自多头注意力的token的输出。假设有4个专家,token需要被路由到前2个专家中。首先需要通过线性层将token输入到门控网络中。 该层将对应于(batch_size, seq_len, n_embed)的输入张量从(4, 8, 32)维度,投影到对应于(batch_size, seq_len, num_expert)的新形状:(4, 8, 4)。其中n_embed是输入的通道维度,num_experts是专家网络的计数。然后,沿最后一个维度,找出最大的前两个值及其相应的索引。

# 通过一个简单的例子来理解 top-k 门控机制num_experts = 4top_k = 2batch_size = 4seq_len = 8n_embed = 32
# 假如经过 multi-head self-attention 之后,得到一个 (4, 8, 32) 的张量mh_output = torch.randn(batch_size, seq_len, n_embed)
topkgate_linear = nn.Linear(n_embed, num_experts) # nn.Linear(32, 4)
logits = topkgate_linear(mh_output)top_k_logits, top_k_indices = logits.topk(top_k, dim=-1) # Get top-k experts# 打印 top-k 门控的值以及索引top_k_logits, top_k_indices
(tensor([[[ 1.0700,  0.7206],          [ 0.2494, -0.0838],          [ 1.2022,  0.5802],          [ 0.8623,  0.6392],          [ 0.3154,  0.0610],          [ 0.8664,  0.6319],          [ 0.5692,  0.0469],          [ 1.3120, -0.3133]],
[[ 1.2228, -0.0321], [ 1.1416, 0.3027], [ 0.5253, 0.4374], [ 0.1580, -0.0446], [ 0.3139, 0.2930], [ 0.3529, 0.2312], [ 1.4150, 0.2912], [ 0.5945, 0.1327]],
[[ 0.5750, -0.0629], [ 0.6928, 0.2333], [ 0.6365, 0.2649], [ 0.4032, 0.1236], [ 0.8245, -0.1826], [ 1.3292, 0.2458], [-0.0589, -0.0794],... [1, 2], [3, 2], [3, 2], [1, 3], [2, 1]]]))

通过仅保留沿最后一个维度进行比较的前k大的值,来获得稀疏门控的输出。用负无穷值填充其余部分,在使用softmax激活函数。负无穷会被映射至零,而最大的前两个值会更加突出,且和为1。要求和为1是为了对专家输出的内容进行加权。

zeros = torch.full_like(logits, float('-inf')) sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)sparse_logits
tensor([[[   -inf,  0.2479,    -inf,  0.9578],         [-0.4220, -0.4273,    -inf,    -inf],         [ 0.1390,    -inf,    -inf,  0.1366],         [   -inf,  0.2133,    -inf,  0.2477],         [   -inf,  0.4563,  0.7880,    -inf],         [ 0.1496,    -inf,    -inf, -0.0354],         [-0.0292,    -inf,  1.0569,    -inf],         [ 0.5379,    -inf,  0.1495,    -inf]],
[[ 0.5540, -inf, 0.2338, -inf], [ 0.2380, -inf, 0.8278, -inf], [ -inf, 0.1439, 0.6391, -inf], [ 0.2268, 0.4964, -inf, -inf], [ 0.1168, -0.1425, -inf, -inf], [ 0.3889, 0.4419, -inf, -inf], [ 1.0024, -inf, 0.0241, -inf], [ -inf, -0.1975, 0.5915, -inf]],
[[ 0.6596, -inf, -inf, 0.7077], [ -inf, -inf, 0.6383, 0.7035], [ -inf, -inf, 0.2839, 1.1395], [ 0.2079, -inf, 0.5646, -inf], [ -inf, 0.0644, 0.8486, -inf], [ 0.1942, -inf, -inf, -0.0916], [ -inf, -inf, 1.0266, 0.3353],... [ -inf, 0.1115, -inf, 0.5310], [ -inf, 0.1746, 0.6859, -inf], [ -inf, 0.1188, 0.9117, -inf ], [ 0.4967, -inf, 0.4524, -inf], [ 0.4938, 0.4175, -inf, -inf]]], grad_fn=)
gating_output= F.softmax(sparse_logits, dim=-1)gating_output
tensor([[[0.0000, 0.3296, 0.0000, 0.6704],         [0.5013, 0.4987, 0.0000, 0.0000],         [0.5006, 0.0000, 0.0000, 0.4994],         [0.0000, 0.4914, 0.0000, 0.5086],         [0.0000, 0.4178, 0.5822, 0.0000],         [0.5461, 0.0000, 0.0000, 0.4539],         [0.2523, 0.0000, 0.7477, 0.0000],         [0.5959, 0.0000, 0.4041, 0.0000]],
[[0.5794, 0.0000, 0.4206, 0.0000], [0.3567, 0.0000, 0.6433, 0.0000], [0.0000, 0.3787, 0.6213, 0.0000], [0.4330, 0.5670, 0.0000, 0.0000], [0.5645, 0.4355, 0.0000, 0.0000], [0.4868, 0.5132, 0.0000, 0.0000], [0.7268, 0.0000, 0.2732, 0.0000], [0.0000, 0.3124, 0.6876, 0.0000]],
[[0.4880, 0.0000, 0.0000, 0.5120], [0.0000, 0.0000, 0.4837, 0.5163], [0.0000, 0.0000, 0.2983, 0.7017], [0.4118, 0.0000, 0.5882, 0.0000], [0.0000, 0.3134, 0.6866, 0.0000], [0.5710, 0.0000, 0.0000, 0.4290], [0.0000, 0.0000, 0.6663, 0.3337],... [0.0000, 0.3966, 0.0000, 0.6034], [0.0000, 0.3749, 0.6251, 0.0000], [0.0000, 0.3116, 0.6884, 0.0000], [0.5111, 0.0000, 0.4889, 0.0000], [0.5191, 0.4809, 0.0000, 0.0000]]], grad_fn=)
# 将 top-k 门控机制整理成一个函数class TopkRouter(nn.Module):    def __init__(self, n_embed, num_experts, top_k):        super(TopkRouter, self).__init__()        self.top_k = top_k        self.linear =nn.Linear(n_embed, num_experts)
def forward(self, mh_output): # mh_ouput is the output tensor from multi-head self attention block logits = self.linear(mh_output) top_k_logits, indices = logits.topk(self.top_k, dim=-1) zeros = torch.full_like(logits, float('-inf')) sparse_logits = zeros.scatter(-1, indices, top_k_logits) router_output = F.softmax(sparse_logits, dim=-1) return router_output, indices
# 测试用例num_experts = 4top_k = 2batch_size = 4seq_len = 8n_embed = 32
mh_output = torch.randn(batch_size, seq_len, n_embed)top_k_gate = TopkRouter(n_embed, num_experts, top_k)gating_output, indices = top_k_gate(mh_output)gating_output.shape, gating_output, indices
(torch.Size([4, 8, 4]), tensor([[[0.4262, 0.0000, 0.0000, 0.5738],          [0.0000, 0.3942, 0.0000, 0.6058],          [0.0000, 0.5768, 0.4232, 0.0000],          [0.3226, 0.0000, 0.0000, 0.6774],          [0.3806, 0.0000, 0.6194, 0.0000],          [0.7507, 0.0000, 0.2493, 0.0000],          [0.0000, 0.0000, 0.4478, 0.5522],          [0.0000, 0.0000, 0.3864, 0.6136]],
[[0.2937, 0.7063, 0.0000, 0.0000], [0.3889, 0.6111, 0.0000, 0.0000], [0.0000, 0.5416, 0.4584, 0.0000], [0.0000, 0.0000, 0.4122, 0.5878], [0.3586, 0.0000, 0.6414, 0.0000], [0.2137, 0.0000, 0.0000, 0.7863], [0.3996, 0.0000, 0.6004, 0.0000], [0.0000, 0.6880, 0.0000, 0.3120]],
[[0.5106, 0.4894, 0.0000, 0.0000], [0.0000, 0.5588, 0.4412, 0.0000], [0.0000, 0.3291, 0.6709, 0.0000], [0.0000, 0.5787, 0.0000, 0.4213], [0.5344, 0.0000, 0.0000, 0.4656], [0.3891, 0.0000, 0.6109, 0.0000],... [3, 0], [1, 0], [1, 0], [2, 0], [1, 2]]]))


4. 有噪声的Top-k门控—实现负载平衡

有噪声的Top-k门控机制是训练MoE模型的一个重要工具。从本质上讲,不会希望所有的token都发送给同一组「受欢迎」的专家网络。人们需要的是能在开发和探索之间取得良好平衡。为此,为了负载平衡,从门控的线性层向logits激活函数添加标准正态噪声是有帮助的,这使训练更有效率。

class NoisyTopkRouter(nn.Module):    def __init__(self, n_embed, num_experts, top_k):        super(NoisyTopkRouter, self).__init__()        self.top_k = top_k        # layer for router logits        self.topkroute_linear = nn.Linear(n_embed, num_experts)        self.noise_linear =nn.Linear(n_embed, num_experts)
def forward(self, mh_output): # mh_output is the output tensor from multihead self attention block logits = self.topkroute_linear(mh_output)
# Noise logits noise_logits = self.noise_linear(mh_output)
# Adding scaled unit gaussian noise to the logits noise = torch.randn_like(logits)*F.softplus(noise_logits) noisy_logits = logits + noise
top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1) zeros = torch.full_like(noisy_logits, float('-inf')) sparse_logits = zeros.scatter(-1, indices, top_k_logits) router_output = F.softmax(sparse_logits, dim=-1) return router_output, indices
# 测试用例num_experts = 8top_k = 2batch_size = 4seq_len = 8n_embed = 32
mh_output = torch.randn(batch_size, seq_len, n_embed)noisy_top_k_gate = NoisyTopkRouter(n_embed, num_experts, top_k)gating_output, indices = noisy_top_k_gate(mh_output)gating_output.shape, gating_output, indices
(torch.Size([4, 8, 8]), tensor([[[0.0000, 0.5102, 0.0000, 0.0000, 0.0000, 0.4898, 0.0000, 0.0000],          [0.4597, 0.0000, 0.0000, 0.5403, 0.0000, 0.0000, 0.0000, 0.0000],          [0.0000, 0.0000, 0.0000, 0.5533, 0.0000, 0.4467, 0.0000, 0.0000],          [0.0000, 0.5424, 0.0000, 0.4576, 0.0000, 0.0000, 0.0000, 0.0000],          [0.0000, 0.0710, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9290],          [0.0000, 0.0000, 0.9008, 0.0000, 0.0000, 0.0992, 0.0000, 0.0000],          [0.7085, 0.0000, 0.0000, 0.0000, 0.2915, 0.0000, 0.0000, 0.0000],          [0.3907, 0.0000, 0.6093, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.3858, 0.0000, 0.0000, 0.6142, 0.0000, 0.0000], [0.5314, 0.0000, 0.0000, 0.4686, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2911, 0.0000, 0.7089], [0.0000, 0.0000, 0.0000, 0.6592, 0.0000, 0.0000, 0.3408, 0.0000], [0.0000, 0.0000, 0.0000, 0.4964, 0.0000, 0.0000, 0.5036, 0.0000], [0.4978, 0.0000, 0.5022, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.5174, 0.0000, 0.0000, 0.0000, 0.4826, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6100, 0.0000, 0.3900]],
[[0.0000, 0.0000, 0.0000, 0.9672, 0.0000, 0.0000, 0.0000, 0.0328], [0.5358, 0.0000, 0.4642, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.3875, 0.0000, 0.0000, 0.6125, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6422, 0.0000, 0.3578], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6100, 0.3900], [0.0000, 0.0000, 0.4177, 0.5823, 0.0000, 0.0000, 0.0000, 0.0000],... [2, 0], [5, 2], [6, 5], [2, 5], [2, 0]]]))

5.稀疏化的混合专家模块

在获得门控网络的输出结果之后,对于给定的token,将前k个值选择性地与来自相应的前k个专家的输出相乘。这种选择性乘法的结果是一个加权和,该加权和构成SparseMoe模块的输出。这个过程的关键和难点是避免不必要的乘法运算,只为前k名专家进行正向转播。为每个专家执行前向传播将破坏使用稀疏MoE的目的,因为这个过程将不再是稀疏的。

class SparseMoE(nn.Module):    def __init__(self, n_embed, num_experts, top_k, dropout):        super(SparseMoE, self).__init__()        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)        self.experts = nn.ModuleList([Expert(n_embed, dropout) for _ in range(num_experts)])        self.top_k = top_k
def forward(self, x): gating_output, indices = self.router(x) final_output = torch.zeros_like(x)
# Reshape inputs for batch processing flat_x = x.view(-1, x.size(-1)) flat_gating_output = gating_output.view(-1, gating_output.size(-1))
# Process each expert in parallel for i, expert in enumerate(self.experts): # Create a mask for the inputs where the current expert is in top-k expert_mask = (indices == i).any(dim=-1) flat_mask = expert_mask.view(-1)
if flat_mask.any(): expert_input = flat_x[flat_mask] expert_output = expert(expert_input)






请到「今天看啥」查看全文