前言
从大模型爆火至今,我们所熟知的大多数模型基本上都遵循了GPT的模型结构,即Decoder-only的结构,更准确的说是Casual Decoder,例如LLaMA[1],OPT[2]以及BLOOM[3]等。这类模型结构由于单向注意力机制的限制,因此训练出来的模型只能基于前文以next token prediction的方式预测后续文本从而生成文本序列。由于类似于Encoder-only的结构(例如
BERT[4])
以及Encoder-Decoder的结构(例如
T5[5]
)通过双向注意力机制可以更好的理解文本,于是有些研究人员将双向注意力机制融入到
Decoder-only的结构中,便产生了Prefix Decoder的结构,相关的优秀工作包括GLM[6],U-PaLM[7]等。然而,近段时间,另一种模型结构逐渐展示出了优越性能,就是MoE(
Mixture of Experts
),例如Mistral
[8]
等。其实,MoE这个概念很早之前就已经被提出,只不过最近才逐渐又火起来。准确的说,
MoE是一种结构,它可以被用在
Decoder-only结构中,也可以用在
Encoder-only和Encoder-Decoder结构中。话不多说,下面将以代码的形式详细讲解
MoE结构,内容稍微有点多,希望大家耐心阅读。
MOE
MoE,全称为Mixture of Experts,翻译过来就是混合专家模型。MoE的一个显著优势是能够在远少于Dense模型所需的计算资源下进行有效的预训练。这意味着在相同的计算预算条件下,可以显著扩大模型或数据集的规模。特别是在预训练阶段,与Dense模型相比,混合专家模型通常能够更快地达到相同的质量水平。在最近使用
MoE结构的
混合专家语言模型中,大部分组件都与传统的Transformer相同。
MoE基于Transformer结构,主要由两部分组成:
-
MoE层:这些层代替了传统Transformer模型中的前馈网络 (FFN) 层。MoE层包含若干“专家”,每个专家本身是一个独立的神经网络。在实际应用中,这些专家通常是前馈网络 (FFN),但它们也可以是更复杂的网络结构。
-
门控网络或路由: 这个部分用于决定哪些token被发送到哪个专家。例如,在下图中,“More”这个token可能被发送到第二个专家,而“Parameters”这个token被发送到第一个专家。有时,一个token甚至可以被发送到多个专家。token的路由方式是MoE使用中的一个关键点,因为路由器由学习的参数组成,并且与网络的其他部分一同进行预训练。
总结来说,在混合专家模型 (MoE) 中,将传统Transformer模型中的每个前馈网络 (FFN) 层替换为MoE层,其中MoE层由两个核心部分组成:一个路由器(或者叫门控网络)和若干数量的专家。
MoE的优点如下:
MoE的缺点如下:
-
训练稳定性,MoE在训练过程中可能会遇到稳定性问题;
-
通信成本,在分布式训练环境中,MoE的专家路由机制可能会增加通信成本,尤其是在模型规模较大时;
-
模型复杂性,MoE的设计相对复杂,可能需要更多的工程努力来实现和优化;
-
下游任务性能,MoE由于其稀疏性,使得在Fine-Tuning过程中容易出现过拟合。
下面主要介绍几个模块的实现:
具体实现代码参考了Github开源项目
makeMoE
,代码链接在
https://github.com/AviSoori1x/makeMoE,本文在此基础上加上了文字解读,方便读者理解。
1.
Self-Attention
常规的Self-Attention实现方式使用的缩放点积自注意力,查询矩阵、键矩阵和值矩阵都来自相同的输入序列,同时为了确保自回归语言生成过程的完整性,特别是在纯解码器模型中,使用了一种因果自注意力,也叫因果掩码。它可以掩盖当前token所处位置之后的任何信息,从而引导模型只关注序列的前面部分。值得注意的是,稀疏混合专家模型并不局限于仅有解码器的Transformer结构。事实上,这一领域的许多重要的成果都是围绕 T5结构展开的,T5
结构
也包含了Transformer模型中的编码器和解码器组件。
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(42)
batch_size, seq_len, n_embed = 4, 8, 32
x = torch.randn(batch_size, seq_len, n_embed)
head_size = 16
key = nn.Linear(n_embed, head_size, bias=False)
query = nn.Linear(n_embed, head_size, bias=False)
value = nn.Linear(n_embed, head_size, bias=False)
k = key(x)
q = query(x)
weight = q @ k.transpose(-2, -1)
tril = torch.tril(torch.ones(seq_len, seq_len))
weight = weight.masked_fill(tril == 0, float('-inf'))
weight = F.softmax(weight, dim=-1)
v = value(x)
out = tril @ v
tril
tril:
tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1., 1., 1., 1.]])
class OneHead(nn.Module):
""" one head of self-attention """
def __init__(self, n_embed, head_size, seq_len, dropout=0.1):
super().__init__()
self.key = nn.Linear(n_embed, head_size, bias=False)
self.query = nn.Linear(n_embed, head_size, bias=False)
self.value = nn.Linear(n_embed, head_size, bias=False)
self.register_buffer('tril', torch.tril(torch.ones(seq_len, seq_len)))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
batch_size, seq_len, n_embed = x.shape
k = self.key(x)
q = self.query(x)
weight = q @ k.transpose(-2,-1) * n_embed ** -0.5
weight = weight.masked_fill(self.tril[:seq_len, :seq_len] == 0, float('-inf'))
weight = F.softmax(weight, dim=-1)
weight = self.dropout(weight)
v = self.value(x)
out = weight @ v
return
out
class MultiHeadAttention(nn.Module):
""" multiple heads of self-attention in parallel """
def __init__(self, n_embed, num_heads, head_size, dropout):
super().__init__()
self.heads = nn.ModuleList([OneHead(head_size) for _ in range(num_heads)])
self.proj = nn.Linear(n_embed, n_embed)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
out = torch.cat([h(x) for h in self.heads], dim=-1)
out = self.dropout(self.proj(out))
return out
2. 专家模块—多层感知器
在稀疏混合专家架构中,每个Transformer区块内的自注意力机制保持不变。不过,每个区块的结构发生了巨大的变化:标准的前馈神经网络被多个稀疏激活的前馈网络(即专家网络)所取代。所谓「稀疏激活」,是指序列中的每个token只被分配给有限数量的专家(通常是一个或两个)。
这有助于提高训练和推理速度,因为每次前向传递都会激活少数专家。不过,所有专家都必须存在GPU内存中,因此当参数总数达到数千亿甚至数万亿时,就会产生部署方面的问题。
# 专家模块的实现和FFN模块的实现相同
class Expert(nn.Module):
""" An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """
def __init__(self, n_embed, dropout):
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_embed, 4 * n_embed),
nn.ReLU(),
nn.Linear(4 * n_embed, n_embed),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
3.
Top-k门控
门控网络,也称为路由,确定哪个专家网络接收来自多头注意力的token的输出。假设有4个专家,token需要被路由到前2个专家中。首先需要通过线性层将token输入到门控网络中。
该层将对应于(batch_size, seq_len, n_embed)的输入张量从(4, 8, 32)维度,投影到对应于(batch_size, seq_len, num_expert)的新形状:(4, 8, 4)。其中n_embed是输入的通道维度,num_experts是专家网络的计数。然后,沿最后一个维度,找出最大的前两个值及其相应的索引。
num_experts = 4
top_k = 2
batch_size = 4
seq_len = 8
n_embed = 32
mh_output = torch.randn(batch_size, seq_len, n_embed)
topkgate_linear = nn.Linear(n_embed, num_experts)
logits = topkgate_linear(mh_output)
top_k_logits, top_k_indices = logits.topk(top_k, dim=-1)
top_k_logits, top_k_indices
(tensor([[[ 1.0700, 0.7206],
[ 0.2494, -0.0838],
[ 1.2022, 0.5802],
[ 0.8623, 0.6392],
[ 0.3154, 0.0610],
[ 0.8664, 0.6319],
[ 0.5692, 0.0469],
[ 1.3120, -0.3133]],
[[ 1.2228, -0.0321],
[ 1.1416, 0.3027],
[ 0.5253, 0.4374],
[ 0.1580, -0.0446],
[ 0.3139, 0.2930],
[ 0.3529, 0.2312],
[ 1.4150, 0.2912],
[ 0.5945, 0.1327]],
[[ 0.5750, -0.0629],
[ 0.6928, 0.2333],
[ 0.6365, 0.2649],
[ 0.4032, 0.1236],
[ 0.8245, -0.1826],
[ 1.3292, 0.2458],
[-0.0589, -0.0794],
...
[1, 2],
[3, 2],
[3, 2],
[1, 3],
[2, 1]]]))
通过仅保留沿最后一个维度进行比较的前k大的值,来获得稀疏门控的输出。用负无穷值填充其余部分,在使用softmax激活函数。负无穷会被映射至零,而最大的前两个值会更加突出,且和为1。要求和为1是为了对专家输出的内容进行加权。
zeros = torch.full_like(logits, float('-inf'))
sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)
sparse_logits
tensor([[[ -inf, 0.2479, -inf, 0.9578],
[-0.4220, -0.4273, -inf, -inf],
[ 0.1390, -inf, -inf, 0.1366],
[ -inf, 0.2133, -inf, 0.2477],
[ -inf, 0.4563, 0.7880, -inf],
[ 0.1496, -inf, -inf, -0.0354],
[-0.0292, -inf, 1.0569, -inf],
[ 0.5379, -inf, 0.1495, -inf]],
[[ 0.5540, -inf, 0.2338, -inf],
[ 0.2380, -inf, 0.8278, -inf],
[ -inf, 0.1439, 0.6391, -inf],
[ 0.2268, 0.4964, -inf, -inf],
[ 0.1168, -0.1425, -inf, -inf],
[ 0.3889, 0.4419, -inf, -inf],
[ 1.0024, -inf, 0.0241, -inf],
[ -inf, -0.1975, 0.5915, -inf]],
[[ 0.6596, -inf, -inf, 0.7077],
[ -inf, -inf, 0.6383, 0.7035],
[ -inf, -inf, 0.2839, 1.1395],
[ 0.2079, -inf, 0.5646, -inf],
[ -inf, 0.0644, 0.8486, -inf],
[ 0.1942, -inf, -inf, -0.0916],
[ -inf, -inf, 1.0266, 0.3353],
...
[ -inf, 0.1115, -inf, 0.5310],
[ -inf, 0.1746, 0.6859, -inf],
[ -inf, 0.1188, 0.9117, -inf
],
[ 0.4967, -inf, 0.4524, -inf],
[ 0.4938, 0.4175, -inf, -inf]]], grad_fn=)
gating_output= F.softmax(sparse_logits, dim=-1)
gating_output
tensor([[[0.0000, 0.3296, 0.0000, 0.6704],
[0.5013, 0.4987, 0.0000, 0.0000],
[0.5006, 0.0000, 0.0000, 0.4994],
[0.0000, 0.4914, 0.0000, 0.5086],
[0.0000, 0.4178, 0.5822, 0.0000],
[0.5461, 0.0000, 0.0000, 0.4539],
[0.2523, 0.0000, 0.7477, 0.0000],
[0.5959, 0.0000, 0.4041, 0.0000]],
[[0.5794, 0.0000, 0.4206, 0.0000],
[0.3567, 0.0000, 0.6433, 0.0000],
[0.0000, 0.3787, 0.6213, 0.0000],
[0.4330, 0.5670, 0.0000, 0.0000],
[0.5645, 0.4355, 0.0000, 0.0000],
[0.4868, 0.5132, 0.0000, 0.0000],
[0.7268, 0.0000, 0.2732, 0.0000],
[0.0000, 0.3124, 0.6876, 0.0000]],
[[0.4880, 0.0000, 0.0000, 0.5120],
[0.0000, 0.0000, 0.4837, 0.5163],
[0.0000, 0.0000, 0.2983, 0.7017],
[0.4118, 0.0000, 0.5882, 0.0000],
[0.0000, 0.3134, 0.6866, 0.0000],
[0.5710, 0.0000, 0.0000, 0.4290],
[0.0000, 0.0000, 0.6663, 0.3337],
...
[0.0000, 0.3966, 0.0000, 0.6034],
[0.0000, 0.3749, 0.6251, 0.0000],
[0.0000, 0.3116, 0.6884, 0.0000],
[0.5111, 0.0000, 0.4889, 0.0000],
[0.5191, 0.4809, 0.0000, 0.0000]]], grad_fn=)
class TopkRouter(nn.Module):
def __init__(self, n_embed, num_experts, top_k):
super(TopkRouter, self).__init__()
self.top_k = top_k
self.linear =nn.Linear(n_embed, num_experts)
def forward(self, mh_output):
logits = self.linear(mh_output)
top_k_logits, indices = logits.topk(self.top_k, dim=-1)
zeros = torch.full_like(logits, float('-inf'))
sparse_logits = zeros.scatter(-1, indices, top_k_logits)
router_output = F.softmax(sparse_logits, dim=-1)
return router_output, indices
num_experts = 4
top_k = 2
batch_size = 4
seq_len = 8
n_embed = 32
mh_output = torch.randn(batch_size, seq_len, n_embed)
top_k_gate = TopkRouter(n_embed, num_experts, top_k)
gating_output, indices = top_k_gate(mh_output)
gating_output.shape, gating_output, indices
(torch.Size([4, 8, 4]),
tensor([[[0.4262, 0.0000, 0.0000, 0.5738],
[0.0000, 0.3942, 0.0000, 0.6058],
[0.0000, 0.5768, 0.4232, 0.0000],
[0.3226, 0.0000, 0.0000, 0.6774],
[0.3806, 0.0000, 0.6194, 0.0000],
[0.7507, 0.0000, 0.2493, 0.0000],
[0.0000, 0.0000, 0.4478, 0.5522],
[0.0000, 0.0000, 0.3864, 0.6136]],
[[0.2937, 0.7063, 0.0000, 0.0000],
[0.3889, 0.6111, 0.0000, 0.0000],
[0.0000, 0.5416, 0.4584, 0.0000],
[0.0000, 0.0000, 0.4122, 0.5878],
[0.3586, 0.0000, 0.6414, 0.0000],
[0.2137, 0.0000, 0.0000, 0.7863],
[0.3996, 0.0000, 0.6004, 0.0000],
[0.0000, 0.6880, 0.0000, 0.3120]],
[[0.5106, 0.4894, 0.0000, 0.0000],
[0.0000, 0.5588, 0.4412, 0.0000],
[0.0000, 0.3291, 0.6709, 0.0000],
[0.0000, 0.5787, 0.0000, 0.4213],
[0.5344, 0.0000, 0.0000, 0.4656],
[0.3891, 0.0000, 0.6109, 0.0000],
...
[3, 0],
[1, 0],
[1, 0],
[2, 0],
[1, 2]]]))
4.
有噪声的Top-k门控—实现负载平衡
有噪声的Top-k门控机制是训练MoE模型的一个重要工具。从本质上讲,不会希望所有的token都发送给同一组「受欢迎」的专家网络。人们需要的是能在开发和探索之间取得良好平衡。为此,为了负载平衡,从门控的线性层向logits激活函数添加标准正态噪声是有帮助的,这使训练更有效率。
class NoisyTopkRouter(nn.Module):
def __init__(self, n_embed, num_experts, top_k):
super(NoisyTopkRouter, self).__init__()
self.top_k = top_k
self.topkroute_linear = nn.Linear(n_embed, num_experts)
self.noise_linear =nn.Linear(n_embed, num_experts)
def forward(self, mh_output):
logits = self.topkroute_linear(mh_output)
noise_logits = self.noise_linear(mh_output)
noise = torch.randn_like(logits)*F.softplus(noise_logits)
noisy_logits = logits + noise
top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
zeros = torch.full_like(noisy_logits, float('-inf'))
sparse_logits = zeros.scatter(-1, indices, top_k_logits)
router_output = F.softmax(sparse_logits, dim=-1)
return router_output, indices
# 测试用例
num_experts = 8
top_k = 2
batch_size = 4
seq_len = 8
n_embed = 32
mh_output = torch.randn(batch_size, seq_len, n_embed)
noisy_top_k_gate = NoisyTopkRouter(n_embed, num_experts, top_k)
gating_output, indices = noisy_top_k_gate(mh_output)
gating_output.shape, gating_output, indices
(torch.Size([4, 8, 8]),
tensor([[[0.0000, 0.5102, 0.0000, 0.0000, 0.0000, 0.4898, 0.0000, 0.0000],
[0.4597, 0.0000, 0.0000, 0.5403, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.5533, 0.0000, 0.4467, 0.0000, 0.0000],
[0.0000, 0.5424, 0.0000, 0.4576, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0710, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9290],
[0.0000, 0.0000, 0.9008, 0.0000, 0.0000, 0.0992, 0.0000, 0.0000],
[0.7085, 0.0000, 0.0000, 0.0000, 0.2915, 0.0000, 0.0000, 0.0000],
[0.3907, 0.0000, 0.6093, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.3858, 0.0000, 0.0000, 0.6142, 0.0000, 0.0000],
[0.5314, 0.0000, 0.0000, 0.4686, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2911, 0.0000, 0.7089],
[0.0000, 0.0000, 0.0000, 0.6592, 0.0000, 0.0000, 0.3408, 0.0000],
[0.0000, 0.0000, 0.0000, 0.4964, 0.0000, 0.0000, 0.5036, 0.0000],
[0.4978, 0.0000, 0.5022, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5174, 0.0000, 0.0000, 0.0000, 0.4826, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6100, 0.0000, 0.3900]],
[[0.0000, 0.0000, 0.0000, 0.9672, 0.0000, 0.0000, 0.0000, 0.0328],
[0.5358, 0.0000, 0.4642, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.3875, 0.0000, 0.0000, 0.6125, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6422, 0.0000, 0.3578],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6100, 0.3900],
[0.0000, 0.0000, 0.4177, 0.5823, 0.0000, 0.0000, 0.0000, 0.0000],
...
[2, 0],
[5, 2],
[6, 5],
[2, 5],
[2, 0]]]))
5.稀疏化的混合专家模块
在获得门控网络的输出结果之后,对于给定的token,将前k个值选择性地与来自相应的前k个专家的输出相乘。这种选择性乘法的结果是一个加权和,该加权和构成SparseMoe模块的输出。这个过程的关键和难点是避免不必要的乘法运算,只为前k名专家进行正向转播。为每个专家执行前向传播将破坏使用稀疏MoE的目的,因为这个过程将不再是稀疏的。
class SparseMoE(nn.Module):
def __init__(self, n_embed, num_experts, top_k, dropout):
super(SparseMoE, self).__init__()
self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
self.experts = nn.ModuleList([Expert(n_embed, dropout) for _ in range(num_experts)])
self.top_k = top_k
def forward(self, x):
gating_output, indices = self.router(x)
final_output = torch.zeros_like(x)
flat_x = x.view(-1, x.size(-1))
flat_gating_output = gating_output.view(-1, gating_output.size(-1))
for i, expert in enumerate(self.experts):
expert_mask = (indices == i).any(dim=-1)
flat_mask = expert_mask.view(-1)
if flat_mask.any():
expert_input = flat_x[flat_mask]
expert_output = expert(expert_input)