本文目录
1 EDM2:分析和改进扩散模型的训练过程
(来自 NVIDIA)
1 论文解读
1.1 改良主流扩散模型的架构,以改善其训练过程
1.2 Baseline 架构介绍
1.3 一些初步的变化
1.4 标准化激活值的幅值
1.5 标准化权重和更新量
1.6 去除分组卷积 (配置 F)
1.7 保持激活值幅度的固定功能的层 (配置 G)
1.8 事后 EMA
1.9 实验结果
太长不看版
扩散模型在当前可以说主导了图像生成这个领域,也对于大数据集展现出了强大的缩放性。在本文中,作者在不改变 high-level 架构的前提下,识别和纠正了流行的 ADM 扩散模型中的几个训练方面不均匀的原因。在模型的训练过程中,作者观察到网络的激活值和权重值变化的幅度不受控制。因此,作者重新设计了网络架构来保持这个激活值和权重值变化的幅度稳定。这样可以消除在训练过程中观察到的漂移和不均衡现象,且没有太多改变网络原本的计算复杂度。本文的方法把 ImageNet 512×512 图像生成任务的 FID 由原来的 2.41 提高到了 1.81。生成质量和模型复杂度可视化如下图1所示。
此外,本文还提出一种在事后设置 exponential moving average (EMA) 的方法,即即在完成训练运行后设置 EMA。这允许在不执行多次训练的情况下精确调整 EMA 长度,并揭示了它与网络架构、训练时间和指导的交互。
图1:本文的工作显著提高了生成结果的质量,在 5 倍小的模型下超过了之前的最新技术
本文做了什么工作
在不改变整体架构的情况下,对 ADM 的 UNet 架构提出了一系列的改进,并展示出相当大的质量改进。
提出了一种事后 EMA 的策略,在训练结束之后使用 EMA,利用训练期间存储的权重快速地得到模型的权重。
1
EDM2:分析和改进扩散模型的训练过程
论文名称:Analyzing and Improving the Training Dynamics of Diffusion Models (CVPR 2024)
论文地址:
http://arxiv.org/pdf/2312.02696.pdf
代码地址:
http://github.com/NVlabs/edm2
1.1 改良主流扩散模型的架构,以改善其训练过程
基于文本,示例图片等等提示的高质量图像合成因为去噪扩散模型[1]的出现变得很流行。基于扩散模型的方法不仅能够产生高质量的图像,而且还可以提供多功能的控制[2],或者扩展到其他模态注入音频[3],视频[4]和 3D 形状[5]。
由于损失函数的高度随机性,扩散模型的训练过程十分具有挑战性。最终的图像质量由在整个采样链中预测的微弱图像细节决定,中间步骤的小错误在随后的迭代中可能会产生滚雪球效应。扩散模型的网络必须在不同的噪声等级和条件输入中准确地估计出下一步的干净的图片。这个过程十分困难,因为这些信号本身就是混沌且随机的。
为了在如此嘈杂的训练环境中有效地学习,理想情况下,网络应该对参数更新具有可预测的响应。作者认为,这种理想在当前最先进的设计中不满足,损害了模型的质量,并且由于超参数、网络设计和训练设置之间的复杂交互,因此很难改进它们。
本文的首要目标是了解扩散模型的训练动态为什么,或者说会因为什么意外的现象而变得不均衡,并且逐步去删除这些影响。本文方法的核心是权重值、激活值、梯度和权重更新的预期幅度,这些也在之前的工作[6][7][8]中被研究过。粗略地说,本文的方法是通过一组干净的设计来标准化所有幅度,这些设计解决它们的相互依赖性。
具体来说,作者
对 ADM
[9]
的 U-Net 架构进行了一系列修改,同时不改变其整体的结构
,并在此过程中展示出了很大的质量改进。最终网络可以认为是 ADM 架构的替代品。对于 ImageNet 512×512 图像生成任务,本文方法在使用或者不使用 guidance 的情况下达到了 1.81 和 1.91 的 FID,之前是 2.41 和 2.99。
本文还提出了一种在训练运行完成后设置 Exponential Moving Average (EMA) 参数的方法。模型平均[10]是所有高质量图像合成方法中不可或缺的技术[9][11][12][13]。但是,调节 EMA 超参数是一个很繁琐的过程,因为只有当训练接近收敛时,小范围的变化才会非常明显。本文提出的事后 EMA 允许根据训练期间存储的权重快速高效地重建网络,同时在计算上也比较高效。
1.2 Baseline 架构介绍
基线模型作者使用的是 ADM[9]架构,是通过 EDM[14]框架实现的。如下图2(a)所示。ADM 架构由 U-Net[15]和 Self-Attention[16]混合组成。作者使用 ImageNet[17]512×512 图像生成任务进行评估。与大多数高分辨率扩散模型一样,在预训练的Decoder[18]的 latent space 中运行,该 Decoder 执行 8× 的上采样。因此解码之前的输出维度是 64×64×4。在探索过程中,作者使用大小适中的网络配置,约 300M 的可训练参数,对 2147M 的图像进行训练,Batch Size 为 2048。
Baseline (配置 A):
由于原始 EDM 针对 RGB 图像,作者将输出通道计数增加到 4,并将训练数据集替换为 ImageNet-512 图像的 64×64×4 的 latent representation,全局标准化为零均值和标准偏差
。在此设置中,Baseline 的 FID 为 8.00。配置 A 的架构如下图2所示。
图2:EDM Basline 架构
1.3 一些初步的变化
改进的 Baseline (配置 B):
作者首先调整超参数 (学习率、EMA 长度、训练噪声水平分布等) 来优化 Baseline 模型的性能。作者还遵循之前的工作[1][19][20],禁用了 32×32 分辨率的 Self-Attention。
然后,作者解决了原始 EDM 训练设置中的一个缺点:虽然 EDM 中的损失权重在初始化时将所有噪声水平的损失幅度标准化为 1.0,但随着训练的进行,这种情况不再成立。然后,梯度反馈的大小在噪声水平之间变化,以不受控制的方式重新加权它们的相对贡献。
为了抵消这种影响,作者采用了 Kendall 等人[21]提出的多任务损失的连续泛化。作者将原始损失值跟踪为噪声水平的函数,并通过其倒数缩放训练损失。总之,这些变化将 FID 从 8.00 降低到 7.24。配置 B 的架构如下图 3(b) 和图4所示。
图3:本文基于的 ADM 架构。(a) Encoder 通过 Skip connections 连接到 Decoder,同时辅助的 Embedding 提供了噪声等级和类别的 Condition。(b) 原始构建块遵循 ResNet 的 Pre-Activation 设计
图4:配置 B 的架构
架构流线化 (配置 C):
为了便于对训练动力学的分析,作者继续简化架构。为了避免处理多种不同类型的可训练参数,作者从所有卷积层和线性层以及调节路径中去除加性 bias。为了恢复网络偏移数据的能力,作者将常数 1 的附加通道连接到网络的输入。作者使用[22]的初始化方法统一所有权重的初始化,从 ADM 的原始位置编码方案切换到更标准的傅里叶特征[23],并简化 Group Normalization 层。
最后,作者观察到,在训练过程中,由于 Key 和 Query 向量的大小增长,注意力图通常表现出尖刺。作者使用余弦注意力机制[24][25][26]在计算点积之前对向量进行归一化。这允许在整个网络中使用 16 位浮点数,提高了整体的效率。总之,这些变化将 FID 从 7.24 降低到 6.96。配置 C 的架构如下图5所示。
图5:配置 C 的架构
1.4 标准化激活值的幅值
通过简化架构,作者现在修复训练动态中的第1个问题:激活幅值。如下图6的第1行所示,尽管每个块中使用了 Group Normalization,但随着训练的进行,激活幅值会出现不可控的增长。作者认为这是由于 Encoder,Decoder 和 Self-Attention 的残差结构,ADM 网络包含较长的信号路径,且没有任何归一化。这些路径从残差分支累积,并且可以通过重复的卷积放大激活值。作者觉得这种激活值幅度的增长会将整个模型置于非最佳状态下训练。
图6:不同深度的激活值和权重大小随训练时间的变化
作者尝试将 Group Normalization 引入主路径中一起训练,但这会导致结果质量出现显著的下降。这可能与之前关于 StyleGAN[27]的发现有关,即网络的能力受到过度归一化的影响。受到 StyleGAN2[28]和其他一些工作[29][30]的启发,作者选择修改网络,以使得各个层和路径保持期望的激活幅值,目标是减少对归一化层的需求。
保持激活值幅度的层 (配置 D):
为了保留预期的激活幅度,作者将每一层的输出除以该层引起的激活幅度的预期缩放。为了恢复输入激活幅值,作者将每层的权重 channel-wise 地除以
。这个操作像不包含可学习输出缩放的权重归一化 (Weight Normalization[31]) 操作。由于整体权重大小不再对激活产生影响,因此作者使用单位高斯分布初始化所有权重。结果如上图3的中间所示,成功地消除了激活幅值的漂移。FID 也从 6.96 大幅提升至 3.75。配置 D 的架构如图7所示。
图7:配置 D 的架构
1.5 标准化权重和更新量
从上图3的中间可以看出,随着训练过程的进行,网络权重出现了明显的增长趋势。即使通过 Adam 优化器对梯度进行标准化,有效学习率 (即权重更新量的相对大小) 仍然随着训练的进行而衰减。虽然有人建议这种有效学习率衰减是一个理想的效果[31],但本文作者认为应该显式地控制它,而不是让它在层之间不可控和不均匀地漂移。因此,作者将其视为训练动态的另一个不平衡的问题。
控制有效的学习率 (配置 E):
作者在这里提出了一种 Forced Weight Normalization 技术,在每个训练步骤之前显式地将每个权重向量 \textbf{w}_i\textbf{w}_i 归一化为单位方差。同时,在训练期间仍然在此之上应用 "标准" 的权重归一化。
作者还引入了逆平方根的学习率衰减策略:
, 其中
是当前的 training iteration,
和
是超参数。
结果如图6的下方所示,在训练期间成功地保留了激活和权重大小,FID 从 3.75 提高到 3.02。配置 E 的架构如图8所示。
图8:配置 E 的架构
1.6 去除分组卷积 (配置 F)
在这一步中,作者去除具有潜在的有害结果的 Group Normalization[32] 。尽管网络在没有任何归一化层的情况下可以成功训练,但作者发现在 Encoder 的主路径引入更弱的 Pixel Normalization 层[33]仍然有好处。作者还从 Embedding 网络中删除了第2个线性层和网络输出的非线性,并将残差块中的重采样操作合入主路径中。FID 从 3.02 提高到 2.71。配置 F 的架构如图9所示。
图9:配置 F 的架构
1.7 保持激活值幅度的固定功能的层 (配置 G)
作者注意到网络里面仍然有没有保留激活值幅度的层。首先, 傅里叶特征的正弦函数和余弦函数没有单位方差,作者通过将它们放大
倍来纠正。其次,
非线性激活函数衰减了方差。因此, 作者将将输出除以
来补偿。
此外,每个分支之间可以通过可控制的参数来取得平衡[35],作者把加法操作换成加权求和。
还有两点改进:第1,作者在整个模型的结尾增加了一个可学习的,零初始化的标量的增益。第2,在每个残差块内的 Condition 信号应用类似的增益。因此在初始化时,相当于没有使用 Condition 信号。
如图10所示是最终的架构设计,它比基线更简单、更容易推理。FID 结果为 2.56,与当前技术水平相比极具竞争力。图11是配置 G 的架构。
图10:最终架构
图11:配置 G 的架构
如图12所示是在 ImageNet 512×512 生成任务中评估变化的影响。作者报告了没有 guidance 的 FID[36]结果,在 50000 个随机生成的图像和整个训练集之间计算。
图12:ImageNet 512×512 图像生成任务中评估变化的影响
1.8 事后 EMA
指数移动平均 (Exponential Moving Average, EMA) 是[19]在图像生成中扮演重要角色。但关于衰减参数与训练和采样的其他方面之间的关系知之甚少。作者开发了一种事后选择 EMA 文件的方法,即不需要在训练期间指定它。
EMA 的更新方式是:
, 式中,
是当前训练步骤。EMA 使得早期训练步骤的贡献呈现出指数级衰减。衰减率由通常接近 1 的常数
决定。
出于两个原因,作者提出使用基于幂函数而不是指数函数的衰减。其一,非常长的指数 EMA 对网络参数的初始阶段施加了不可忽略的权重,而初始阶段的参数常常是随机的。其次,作者观察到一个明显的趋势,即更长的训练运行受益于更长的 EMA 衰减,因此平均配置文件理想情况下应该随着训练时间自动扩展。
作者将时间步
的平均参数定义为:
其中常数
控制轮廓的锐度。通过这个公式,
的权重始终为零。
为了在实际中计算
, 作者在每个训练步骤之后执行增量更新, 如下所示:
这个更新方式类似于传统的 EMA, 但
取决于当前的训练时间。事后 EMA 的目标是为了在训练之后, 可以灵活选择
, 或者相对标准差
。为此, 作者在训练的过程中维持了两套平均参数向量
和
。其中,
, 对应的
分别是 0.05 和 0.1 。这些平均参数向量定期存储在训练期间保存的 Snapshots 中。在所有的实验中, 作者每大约 800 万个训练图像存储一次快照, 即每 4096 个训练步骤, Batch Size 为 2048 。
为了得到在训练期间或之后任意一点的任何 EMA profile 的参数
, 作者找到了存储的
和需求的 EMA profile 之间的最小二乘最优拟合, 并进行相应线性组合。结果如下图所示。
图13:通过 Snapshots 重建权重
图13的上方是为了模拟训练后任意长度的 EMA,作者在训练期间存储了许多平均网络参数的 Snapshots。每个阴影区域对应于网络参数的加权平均值。训练期间,维护了两个具有不同幂函数的 EMA 配置文件,存储在 8 个 Snapshots 中。下方虚线表示要合成的事后 EMA,紫色区域表示基于存储的 Snapshots 得到的最小二乘最优近似。使用存储在每个 Snapshot 的两组平均参数,重建权重的均方误差随着
的增加而减小, 实验上的阶数在
。在实践中,几十个 Snapshots 足够进行完美的重建。
下图14为 FID 如何根据配置 B-G 中的 EMA 长度而变化。可以看到,配置之间的最佳 EMA 长度差异很大。此外,当我们接近最终配置 G 时,最优值变得更窄,最初看起来很令人担忧。
图14:不同配置下的 FID 随着 EMA 长度的变化
但是作者随后又进行了另一个实验,结果如图15所示。在这个实验中,作者首先从网络的不同部分选择权重张量的子集。然后,对于每个选定的张量都执行一个扫描,其中只有所选张量的 EMA 发生了变化,而其他所有张量都保持在全局最优值。每个张量一行,把结果显示在图中,显示出对 FID 巨大的影响:在配置 B 中,FID 的提升可以达到 10%。一个实例使用非常短的 EMA,另一个使用非常长的 EMA。配置 B 对于最优 EMA 长度不敏感,因为其权重张量对于最优 EMA 长度没有达成一致。对于最终配置 G,这种效应消失,最优值更清晰:FID 没有显著的改进,张量现在就最优 EMA 达成一致。在配置 G 中,逐张量扫描改变 EMA 长度的效果很小。
图15:每层对于 EMA 长度的敏感度
图16说明了训练过程中最佳 EMA 长度的演变。尽管 EMA 长度的定义是相对于训练的长度,但作者观察到随着训练的进行,最优值在缓慢向着相对较长的 EMA 长度移动。
图16:训练过程中最佳 EMA 长度的演变
1.9 实验结果
作者使用 ImageNet 512×512 图像生成任务作为主要实验。图17对比了主要模型的结果。首先考虑不使用 guidance[37]的情况,之前最好的结果是 VDM++[38],FID 的值为 2.99。即使是使用小模型 EDM2-S 也取得了 2.56 的 FID,缩放模型尺寸之后可以进一步将 FID 提高到 1.91,大大超过了之前的记录。
图17:ImageNet-512 实验结果
作者还发现 Dropout[39][40]在表现出过拟合的情况下改善我们的结果。因此,作者在较大的配置 (M-XXL) 中使用 Dropout,这些配置显示过度拟合的迹象,同时在有害的较小配置 (XS, S) 中禁用 Dropout。