23年9月Hive AI的论文“REPAIR: REnormalizing Permuted Activations for Interpolation Repair”。
本文研究(Entezari 2021)的猜想,其发现如果考虑神经网络的排列不变性,那么 SGD 解之间的线性插值可能不存在损失屏障。首先,仅靠神经元对齐方法不足以在 SGD 解决方案之间建立低屏障线性连接,由于称为“方差崩溃”的现象:插值的深度网络,其激活的方差崩溃,导致性能不佳。提出
REPAIR(重新归一化排列的激活进行插值修复)
,它通过重新调整此类插值网络的预激活来减轻方差崩溃。该方法与归一化层、网络宽度和深度的选择之间进行交互,并证明在神经元对齐方法之上使用 REPAIR 可以在各种架构系列和任务中实现 60%-100% 的相对屏障减少 。
如图所示:REPAIR 通过减轻方差崩溃来提高插值网络的性能。在每个实验中,在两个独立训练网络的权重之间进行插值,这些网络的隐藏单元已进行对齐。然后,比较应用修正方法 REPAIR 之前和之后的插值网络。左:插值网络中激活的方差逐渐崩溃。其中每层的平均方差,通过原始端点网络中相应层的方差进行归一化。REPAIR 旨在纠正这种现象。中:在ImageNet上独立训练的对齐的ResNet50,REPAIR将其线性插值的屏障减少了 74%(从 76% 到 20%)。右:REPAIR 减少了架构、训练数据集和归一化层等多种选择之间的插值屏障。对于每个架构/数据集对,改变网络宽度;较大的marker表明更广泛的网络。
代码可如下获取
github.com/KellerJordan/REPAIR
独立训练的神经网络之间的插值问题:让 θ1, θ2 为两个这样网络权重向量,那么权重内插形式为 θα = (1 − α)θ1 + αθ2(0 < α < 1)的网络。这样的网络θα作为插值网络,并且以θ1、θ2作为端点网络。
一对网络之间的损失屏障 B(θ1,θ2) (Entezari et al., 2021)被定义为,即相对于两个端点损失的相应凸组合,沿 θ1 和 θ2 之间的线性路径
损失的最大增量
:
对于典型的神经架构,每层的神经元都可以在不改变网络功能的情况下进行排列;
这称为排列不变性,记P为排列组合矩阵。
在简单的前馈网络中,这相当于可以用 PWL 替换第 L 个权重矩阵,用 WL+1P^−1 替换第 (L + 1) 个权重矩阵,而不改变网络表示的函数。
因此,即使两个网络 θ1、θ2 在每一层都学习了一组功能相同的神经元,这些神经元也可能被任意排列或错位。
(Entezari2021) 推测,如果考虑排列不变性,那么在同一任务上训练的足够宽网络,其所有 SGD 解决方案都会变成线性模式连接,即它们之间没有屏障。
许多工作提出寻找一对神经网络的隐藏单元之间对齐的方法。(Li2015)提出最大化一批训练数据中配对神经元激活之间的相关性总和。也就是说,如果让 Xl,i (0) 和 Xl,i(1) 为对应于第 l 层(跨一批训练数据)的第 i 个隐单元激活的随机变量,则 (Li2015) 提出
优化排列
Pl ,
最大化以下目标
:
这相当于一个线性和分配(LSA)问题,对应于两个网络中隐单元对之间的相关矩阵;
这可以通过匈牙利算法来解决(Kuhn,1955)。
最近的工作提出了替代方法:
(He2018) 计算 Hessian 近似来对齐功能相似的神经元,(Singh & Jaggi 2020) 开发了一种基于传输的最佳软对齐方法。
(Ainsworth2022) 比较了三种方法,其中一种基于 (Li 2015)的方法和其他两个方法。
(Tatro2020)还基于最小化方程(2)来执行对齐,减少非线性插值的障碍。此外,用近端交替最小化(PAM)方案表明,这种对齐对目的来说几乎是最佳的。本文使用(Li 2015)的对齐方法。
对于具有剩余连接的网络,必须注意限制排列集以便网络表示的函数不会改变。特别是,隐单元的相同排列必须应用于馈入单个残差流的所有层。
对齐网络之间的屏障随宽度的增加而减小,对于在 MNIST 上训练的非常宽MLP 来说,屏障几乎为零。屏障也会随深度而急剧增加,对于 MLP 或超过几层的简单 CNN 来说会变得很大(Entezari,2021)。
即使是强大的基于最优传输的方法,其允许网络 A 中的每个神经元与网络 B 中神经元的加权和进行匹配来进行对齐,也不足以实现标准 ResNet (Singh & Jaggi,2020)之间的低屏障(低于 5% 测试误差)连接。
是什么导致多几层的插值网络性能快速下降?如图所示,对于深度 MLP,从一对已对齐的端点网络(在 MNIST 测试集上都具有高精度)进行插值,隐藏单元会经历方差崩溃(variance collapse)。也就是说,随着深入到网络中,它们激活的方差逐渐衰减,后面各层的激活几乎保持不变。对于每一层,按如下方式量化这种衰减。首先,测量一批训练数据中每个神经元激活的方差。然后,对该层中每个神经元的方差求和。最后,如果插值网络和两个端点网络分别表示为 vα,v1,v2,则插值比率为 (v1+v2)/2。计算网络中每一层的这个比率,给出在图(左)中报告的一系列值。
对于一个插值 ResNet18 的单层神经元方差集,参见图(左)。
一个插值的35 层 MLP 最后一层,方差衰减到接近零,这表明最后几层的激活几乎不变。
当直接在未对齐网络之间进行插值时,这种效应似乎会进一步加剧。
对 VGG(Simonyan & Zisserman,2014)和 ResNet50 架构重复此实验,分别在 CIFAR-10 和 ImageNet 上训练,发现最后层的方差衰减超过 上图的10 倍。
这是一个问题:
如果这些网络在其最后层中具有几乎不变的激活,那么它们甚至将不再能够区分输入。
对于插值网络第一层中的隐单元或通道,这样的单元在功能上等同于端点网络中各个单元之间的线性插值。
也就是说,如果用插值网络中的 Xα 和两个端点网络中的 X1,X2 表示单元的预激活(作为输入数据分布上的随机变量),则等式 Xα = (1 − α)X1 + αX2 成立 。
与 X1 或 X2 相比,Xα 的方差通常会减小。
如果两个端点网络完全对齐并且学习了相同的特征,那么应该有 corr(X1, X2) = 1。但在实践中,更典型的是成对的对齐单元(其对齐最小化方程(2)给出的成本函数) 相关性为 corr(X1, X2) ≈ 0.4。当考虑中点插值网络 (α = 0.5) 时,Xα 的
方差
由下式给出
通常有 std(X1) ≈ std(X2),因此可以简化为 Var(Xα) = (0.5 + 0.5 · corr(X1, X2)) · Var(X1)。
对于对齐网络,典型取值 corr(X1, X2) ≈ 0.4,这会产生 Var(Xα) = 0.7 · Var(X1):
与端点网络相比减少了 30%。
这种分析不能严格扩展到插值网络的更深层次,但直观上预计这种衰减会与深度的变化复合。
这种直觉与实验相符,随着 MLP、VGG 和 ResNet50 网络层的增加,方差崩溃变得更糟。
给定一个插值网络 θα = (1 − α) · θ1 + α · θ2,对于 0 < α < 1(具有对齐的端点网络 θ1, θ2),选择一组隐单元或通道,目标是纠正其统计数据。例如,对于 VGG 网络,纠正每个卷积层的预激活。对于 ResNet,纠正这些卷积预激活和每个残差块的输出。
目标是为每个选定通道计算一组仿射(重新缩放和平移)系数,以便校正所有选定通道的统计数据。考虑一个特定的通道,例如插值 ResNet18 中第 8 层第 45 个卷积通道。令 X1 和 X2 为两个端点网络中的通道值,将其视为输入训练数据上的随机变量,并令 Xα 为
插值网络
中的相同通道。那么希望以下两个
条件
成立:
而在进行任何校正之前,由于方差崩溃,通常会得到 std(Xα) ≪ min(std(X1), std(X2))。
如下提出两种算法来计算每个选定通道的合适仿射系数集,以引发这些条件。
首先提出一种高效的近似算法,该算法无需在内插网络中使用前向传播即可计算所需的仿射系数。考虑插值网络第一层中的隐藏单元。和之前一样,让 Xα 代表插值网络中的单元,X1, X2 分别代表两个端点网络中的相同单元。根据方程 Xα = (1 − α) · X1 + α · X2,该单元已经满足条件 (3)。给定 Var(X1)、Var(X2) 和 Cov(X1, X2) 值,可以根据以下公式精确计算 Xα 的方差
因此,为了满足该单元的条件(4),重新缩放系数β必须为