近日,来自清华大学智能产业研究院(AIR)助理教授赵昊老师的团队,联合戴姆勒公司,提出了一种无需训练的多域感知模型融合新方法。研究重点关注场景理解模型的多目标域自适应,并提出了一个挑战性的问题:
如何在无需训练数据的条件下,合并在不同域上独立训练的模型实现跨领域的感知能力?
团队给出了“Merging Parameters + Merging Buffers”的解决方案,这一方法简单有效,在无须访问训练数据的条件下,能够实现与多目标域数据混合训练相当的结果。
论文链接:
https://arxiv.org/pdf/2407.13771
项目地址:
https://air-discover.github.io/ModelMerging/
1 背景介绍
一个适用于世界各地自动驾驶场景的感知模型,需要能够在各个领域(比如不同时间、天气和城市)中都输出可靠的结果。然而,典型的监督学习方法严重依赖于需要大量人力标注的像素级注释,这严重阻碍了这些场景的可扩展性。因此,多目标域自适应(Multi-target Domain Adaptation, MTDA)的研究变得越来越重要。多目标域自适应通过设计某种策略,在训练期间同时利用来自多个目标域的无标签数据以及源域的有标签合成数据,来增强这些模型在不同目标域上的鲁棒性。
与传统的单目标域自适应 (Single-target Domain Adaptation, STDA)相比,MTDA 面临更大的挑战——一个模型需要在多个目标域中都能很好工作。为了解决这个问题,以前的方法采用了各种专家模型之间的一致性学习和在线知识蒸馏来构建各目标域通用的学生模型。尽管如此,这些方法的一个重大限制是它们需要同时使用所有目标数据,如图1(b) 所示。
但是,
同时访问到所有目标数据是不切实际的
。一方面原因是数据传输成本限制,因为包含数千张图像的数据集可能会达到数百 GB。另一方面,从数据隐私保护的角度出发,不同地域间自动驾驶街景数据的共享或传输可能会受到限制。面对这些挑战,在本文中,我们聚焦于一个全新的问题,如图1(c) 所示。我们的研究任务仍然是MTDA,但我们并没有来自多个目标域的数据,而是只能获得各自独立训练的模型。我们的目标是,通过某种融合方式,将这些模型集成为一个能够适用于各个目标域的模型。
图1:不同实验设置的对比
2 方法
如何将多个模型合并为一个,同时保留它们在各自领域的能力?我们提出的解决方案主要包括两部分:Merging Parameters(即可学习层的weight和bias)和 Merging Buffers(即normalization layers的参数)。在第一阶段,我们从针对不同单目标域的无监督域自适应模型中,得到训练后的感知模型。然后,在第二阶段,利用我们提出的方法,
在无须获取任何训练数据的条件下,只对模型做合并
,得到一个在多目标域都能工作的感知模型。
图2:整体实验流程
下面,我们将详细介绍这两种合并的技术细节和研究动机。
2.1 Merging Parameters
2.1.1 Permutation-based的方法出现退化
事实上, 如何将模型之间可学习层的 weight 和 bias 合并一直是一个前沿研究领域。在之前的工作中, 有一种称为基于置换 (Permutation-based)的方法。这些方法基于这样的假设:当考虑神经网络隐藏层的所有潜在排列对称性时, loss landscape 通常形成单个盆地 (single basin) 。因此, 在合并模型参数
和
时, 这类方法的主要目标是找到一组置换变换
, 确保
在功能上等同于
, 同时也位于参考模型
附近的近似凸盆地 (convex basin) 内。之后, 通过简单的中点合并
以获得一个合并后的模型
,该模型能够表现出比单个模型更好的泛化能力,
在我们的实验中, 模型
和
在第一阶段都使用相同的网络架构进行训练, 并且, 源数据都使用相同的合成图像和标签。我们最初尝试采用了一种 Permutation-based 的代表性方法——Git Re-Basin, 该方法将寻找置换对称变换的问题转化为线性分配问题 (LAP), 是目前最高效实用的算法。
图3:Git Re-basin和mid-point的实验结果对比
但是,如图3所示,我们的实验结果出乎意料地表明,
不同网络架构(ResNet50、ResNet101 和 MiT-B5)下 Git Re-Basin 的性能与简单中点合并相同
。进一步的研究表明,Git Re-Basin 发现的排列变换在解决 LAP 的迭代中保持相同的排列,这表明在我们的领域适应场景下,Git Re-Basin 退化为一种简单的中点合并方法。
2.1.2 线性模式连通性的分析
我们从线性模式连通性(linear mode connectivity)的视角进一步研究上述退化问题。具体来说, 我们使用连续曲线
在参数空间中连接模型
和模型
。在这种特定情况下, 我们考虑如下线性路径,
接下来, 我们通过对
做插值遍历评估模型的性能。为了衡量这些模型在两个指定目标域 (分别表示为
和
) 上的有效性, 我们使用调和平均值 (Harmonic Mean) 作为主要评估指标,
我们之所以选择调和平均值作为指标,是因为它能够赋予较小的值更大的权重,这能够更好应对世界各地各个城市中最差的情况。它有效地惩罚了模型在一个目标域(例如,在发达的大城市)的表现异常高,而其他目标域(例如,在第三世界乡村)表现低的情况。不同插值的实验结果如图4(a)所示。“CS”和“IDD”分别表示目标数据集 Cityscapes 和 Indian Driving Dataset。
图4:线性模式连通性的分析实验
2.1.3 理解线性模式连通性的原因
在上述实验结果的基础上,我们进一步探究:在先前域自适应方法中观察到的线性模式连通性,背后的根本原因是什么?为此,我们进行了消融实验,来研究第一阶段训练
和
期间的几个影响因素。
-
合成数据。
使用相同的合成数据可以作为两个域之间的桥梁。为了评估这一点,我们将合成数据集 GTA 中的训练数据划分为两个不同的非重叠子集,每个子子集包含原始训练样本的 30%。在划分过程中,我们将合成数据集提供的具有相同场景标识的图像分组到同一个子集中,而具有显着差异的场景则放在单独的子集中。我们使用这两个不同子集分别作为源域,训练两个单目标域自适应模型(目标域为 CityScapes 数据集)。随后,我们研究这两个 STDA 模型的线性模式连通性。结果如图 4(b) 所示,可以观察到,在参数空间内连接两个模型的线性曲线上,性能没有明显下降。这一观察结果表明,使用相同的合成数据并不是影响线性模式连通性的主要因素。
-
自训练架构。
使用教师-学生模型可能会将最后的模型限制在 loss landscape 的同一 basin 中。为了评估这种可能性,我们禁用了教师模型的指数移动平均 (EMA) 更新。相应地,我们在每次迭代中将学生权重直接复制到教师模型中。随后,我们继续训练两个单目标域自适应模型,分别利用 GTA 作为源域,Cityscapes 和 IDD 作为目标域。然后,我们研究在参数空间内连接两个模型的线性曲线,结果如图 4(c) 所示。我们可以看到线性模式连接属性保持不变。
-
初始化和预训练。
使用相同的预训练权重初始化 backbone 的做法,可能会使模型在训练过程中难以摆脱的某一 basin。为了验证这种潜在情况,我们初始化两个具有不同权重的独立 backbone,然后继续针对 Cityscapes 和 IDD 进行域自适应。在评估两个收敛模型之间的线性插值模型时,我们观察到性能明显下降,如图 4(d) 所示。为了更深入地了解潜在因素,我们继续探究,
是相同的初始权重,还是预训练过程导致了这种影响?
我们初始化两个具有相同权重但没有预训练的主干,然后再次进行实验。有趣的是,我们发现,在参数空间的线性连接曲线仍然遇到了巨大的性能障碍,如图 4(e) 所示。这意味着
预训练过程在模型中的线性模式连接方面起着关键作用。
2.1.4 关于合并参数的小结
我们通过大量实验证明,当领域自适应模型
从相同的预训练权重开始时
,模型可以有效地过渡到不同的目标领域,同时仍然保持参数空间中的线性模式连通性。因此,这些训练模型可以通过
简单的中点合并
,得到在两个领域都有效的合并模型。
2.2 Merging Buffers
Buffers,即批量归一化 (BN) 层的均值和方差,与数据域密切相关。因为数据不同的方差和均值代表了域的某些特定特征。在合并模型时如何有效地合并 Buffers 的问题通常被忽视,因为现有方法主要探究
如何合并在同一域内的不同子集上训练的两个模型
。在这样的前提下,之前的合并方法不考虑 Buffers 是合理的,因为来自任何给定模型的 Buffers 都可以被视为对整个总体的无偏估计,尽管它完全来自随机数据子样本。
但是, 在我们的实验环境中, 我们正在研究如何合并在完全不同的目标域中训练的两个模型, 这使得 Buffers 合并的问题不再简单。由于我们假设在模型 A 和模型 B 的合并阶段无法访问任何形式的训练数据, 因此我们可用的信息仅限于 Buffers 集
。其中,
表示 BN 层的数量, 而
和
分别表示第
层的平均值、标准差和 tracked 的批次数。生成 BN 层的统计数据如下:
以上方程背后的原理可以解释如下:引入 BN 层是为了缓解内部协变量偏移 (internal covariate shift) 问题, 其中输入的均值和方差在通过内部可学习层时会发生变化。在这种情况下, 我们的基本假设是, 后续可学习层合并的 BN 层的输出遵循正态分布。由于生成的 BN 层保持符合高斯先验的输入归纳偏差, 我们根据从
和
得到的结果估计
和
。如图5所示, 我们获得了从该高斯先验中采样的两组数据点的均值和方差, 以及这些集合的大小。我们利用这些值来估计该分布的参数。
图5:合并BN层的示意图
当将 Merging Buffers 方法扩展到
个高斯分布时, tracked 的批次数