24年5月来自香港中文大学深圳分校和其他几个研究机构(包括浙江大学和西湖大学)的论文“Dynamic Mixture of Experts: An Auto-tuning Approach for Efficient Transformer Models”。
稀疏混合专家 (SMoE) 已被广泛用于提高基于 Transformer 基础模型的训练和推理效率,并取得了令人欣喜的效果。然而,SMoE 的性能在很大程度上取决于超参的选择,例如专家数和要激活的专家数(称为 top-k),由于通过搜索各种超参配置进行大量模型训练,导致计算开销巨大。为了解决这个问题,引入了 动态混合专家 (DYN-MOE) 技术。DYN-MOE 结合了 (1) 一种新门控方法,使每个 token 能够自动确定要激活的专家数。(2) 自适应过程会在训练期间自动调优专家数。视觉、语言和视觉-语言任务的大量数值结果表明,该方法在视觉和语言任务上的表现优于 GMoE(论文“ Sparse mixture-of-experts are domain generalizable learners ”),在视觉-语言任务上优于 MoE-LLaVA,同时通过激活更少的参数保持效率。
代码开源在 https://github.com/LINs-lab/DynMoE 。
动态混合专家结合两个关键组件:
(1) top-any门控方法(如图所示:输入 token 经过与每个专家 e 对应的门控权重 Wge,得到门控分 Ge。然后将这些门控得分与门控 Ge 进行比较,确定是否激活后续专家。最后,将专家输出组合起来,产生输出 token),将门控机制建模为多标签分类问题,允许 token 自行决定要激活的专家数。这使得不同的 token 可以激活不同数的专家,包括不激活任何专家的零选项。
(2) 精心设计的自适应过程,当 token 选择不激活任何现有专家时,它会添加新专家,并删除任何未被一个 token 激活的多余专家。
整个过程总结在如下算法中:
传统的 top-k 门控方法使用 token 嵌入 x 作为输入,并使用额外的门控网络 g 来预测输入 token 嵌入分配给每个专家的分数。通常,给定 token x 作为输入,门控过程定义如下:
其中 Wg 是门控网络的参数,K 是专家数。MoE 层的输出定义为:
其中 Ee(x) 是给定输入 x 时第 e 个 专家的输出,g(x)e 是 g(x) 的第 e 项。
尽管 top-k 门控方法在提高训练和推理效率方面取得了相当大的成功,但仍存在两个限制:
1. 必须对 k 值进行微调以优化模型性能。如上图所示,MoE 模型的性能会随着不同的 top-k 值而发生显著变化。最近的研究也注意到了这一观察结果 [6, 12, 53]。因此,需要大量的计算资源来确定 k 的最佳值。
2. top-k 门控方法假设每个 token 必须激活相同数的专家,但在实践中可能并不总是如此。例如,在考虑不同的任务时,可能存在所有任务共享的 token 和特定于某些任务的 token,即不同的 token 可以激活不同数的专家。
通过无需调整的 top-any 门控方法解决 top-k 门控的局限性 。为了解决上述局限性, top-any 门控不需要预定义k 值,允许不同的 token 在训练和推理阶段激活不同数的专家。
top-any 门控方法的设计灵感来自多标签分类问题。将每个专家视为一个单独类,并独立计算每个类(专家)的分类(门控)分数。随后,所有分数超过阈值的类(专家)都被视为positive 类(激活)。具体来说,给定专家表示矩阵 Wg,其中 Wg 的第 k 行作为专家 k 的表示,以及输入token x,top-any 门控的关键步骤可以通过以下等式来表示:
先计算token和专家表征矩阵Wg之间的cosine相似性,得到分数s(x)。然后应用sigmoid函数σ在相似度得分 s(x),得到 0 到 1 之间的得分。最后,相似度得分大于可训练阈值 G 的专家将被视为对 token x 的激活专家。需要注意的是,符号函数不支持反向传播,因此定制这部分的反向传播过程,即直接将 g(x) 的梯度复制到 σ (s(x)) − σ(G),以有效绕过符号函数。
给定门控得分g(x) ,那么激活的专家数目计算如下:
其中 k 表示针对 token x 激活的专家数。采用 top-any 门控方法的 MoE 层模型输出可以推导出如下:
在测试中改进top-any门控防止token丢失 。为了便于设计自适应专家数过程,没有对 k 施加最小值。因此,某些 token 可能不会激活任何专家。为了解决这个问题,在模型性能评估期间,修改top-any 门控,对不选择激活任何专家的 token 启用 top-1 门控。具体来说,对于 sum(g(x)) = 0 的输入 token x,修改其门控得分如下:
通过辅助损失确保 top-any 门控的效率 。使用 MoE 模型的主要目标是提高训练和推理效率。然而,在没有对激活专家的最大数进行限制的情况下,token 可能会激活所有专家,这与主要的目标背道而驰。
使用辅助损失作为专家的正则化可能会缓解这个问题。然而,现有的辅助损失方法 [24, 13, 51] 主要用于确保专家之间的负载平衡,因此无法与目标保持一致。虽然激活所有专家确实可以实现负载平衡,但这与限制激活专家数来提高效率的目标相矛盾。因此,需要一个不仅能确保负载平衡,还能限制激活专家数的解决方案。
作为补救措施,提出一种新的辅助损失,即稀疏和简单门控损失。多样性损失和简单性损失共同通过解决专家表示的不同方面来提高模型效率。一方面, 多样性损失 鼓励各个专家Wg 表示之间的独立性。它有两个作用:一是防止专家之间过于相似,从而提升模型的表征能力;二是引导 token 避免所有专家同时激活,从而促进稀疏门控提高效率。另一方面, 简单性损失 对 Wg 进行归一化,避免矩阵各个值过大,这有助于保持数值稳定性,并防止由于参数值过大而导致过拟合。详细的损失函数定义如下:
自适应训练进程
自适应训练过程旨在自动确定专家数。如图所示,自适应过程由三个部分组成,即 (1) 路由记录:记录训练期间的路由结果;(2) 添加专家:当 token 选择不激活任何现有专家时添加新专家;(3) 删除专家:删除未被任何 token 选择的专家。
路由记录 。为了方便删除和添加专家,跟踪路由状态至关重要。具体来说,为每个 MoE 层记录两个关键信息:(1) 对于每个专家 e,记录专家 e 被激活的时间,表示为 RE(如算法第 9 行所示)。(2) 对于未激活任何专家的输入数据,计算它们的嵌入 x 总和为 RS (如算法第 10 行所述)。请注意,这种方法简化专家添加过程:通过使用token嵌入来初始化专家表示 Wg ,可以在这些tokens和新专家之间实现较高的相似度得分,从而确保新专家在添加时将被这些tokens激活。
如算法所示,利用flags和 flagf 来确定何时开始和停止路由记录。用户可以根据需要控制这两个标志。
当存在选择不激活任何专家的tokens时添加专家 。当记录的 RS ̸= 0 时,添加新专家,因为有些tokens不会激活任何专家,而 RS 是这些tokens的总和。因此,给定 K 个激活的专家和新专家 K + 1,初始化 Wg,K+1 = RS/∥RS∥ 和 GK+1 = 0。
当存在未被任何tokens激活的专家时删除专家。当存在专家 e 使得 ReE = 0 时,会删除专家(如算法中的第 13 行所示)。
实验设置
实验的讨论问题:
• Q1:DYN-MOE 能否在不同的 MoE 设置中实现具有竞争力的性能?
• Q2:DYN-MOE 能否处理具有不同模态和规模的任务?
• Q3:DYN-MOE 训练的模型是否会保持稀疏性以确保效率?
• Q4:DYN-MOE 能否提供可以指导 MoE 模型设计的见解?
为了回答上述四个问题,对视觉、语言和视觉-语言任务进行实验。详细信息如下所示。
• 视觉任务 。对于视觉任务,遵循与 GMoE [26] 相同的设置。用预先训练的 ViT-S/16 [10] 模型并在 DomainBed [16] 基准上对其进行评估。实验涵盖四个领域泛化数据集:PACS [27]、VLCS [2]、OfficeHome [48] 和 DomainNet [36]。所有结果均使用训练-验证选择标准报告。
• 语言任务 。语言任务遵循与 MoEfication [56] 和 EMoE [38] 相同的设置。MoE 模型基于 BERT-large [8] 架构,使用 MoEfication 方法构建,并在 GLUE [49] 任务上进行微调,这些任务包括 COLA [50]、QNLI [49]、RTE [5]、MNLI [52] 和 MRPC [9]。
• 视觉-语言任务 。视觉-语言任务遵循 MoE-LLaVA [31] 中的设置,其中用 StableLM-2-1.6B [4]、Qwen-1.8B [3] 和 Phi-2-2.7B [19] 作为主干语言模型,并使用 clip-vit-large-patch14-336 [39] 作为视觉编码器。这些模型在图像理解基准上进行评估,包括 VQA-v2 [14]、GQA [18]、VisWiz [17]、ScienceQA-IMG [34]、TextVQA [45]、POPE [30]、MME [54]、MMBench [33]、LLaVA-Bench (in-the-Wild) [32] 和 MM-Vet [55]。此外,在测试期间将路由记录保存在模型中。对于每个基准,收集每个 MoE 层的专家激活数和测试期间处理的 token 总数。
下表是视觉任务的比较:
下表是视觉-语言任务的比较:
下表是语言任务的比较: