本文介绍上科大YesAI Lab 发表在ICML 2024关于Diffusion Model的工作《Guidance with Spherical Gaussian Constraint for Conditional Diffusion》。本工作旨在利用预训练的扩散模型实现损失函数引导的、无需训练的条件生成任务。本工作上海科技大学2023级研究生杨凌霄为第一作者,由石野教授指导完成。
论文地址:
https://arxiv.org/abs/2402.03201
代码链接:
https://github.com/LingxiaoYang2023/DSG2024
摘要
最近的Guidance方法试图通过利用预训练的扩散模型实现损失函数引导的、无需训练的条件生成。虽然这些方法取得了一定的成功,
但它们通常会损失生成样本的质量,并且只能使用较小的Guidance步长,从而导致较长的采样过程。
在本文中,我们揭示了导致这一现象的原因,即采样过程中的
流形偏离(Manifold Deviation)
。我们通过建立引导过程中估计误差的下界,从理论上证明了流形偏离的存在。
为了解决这个问题,我们提出了基于球形高斯约束的Guidance方法(DSG),通过解决一个优化问题将Guidance步长约束在中间数据流形内,使得更大的引导步长可以被使用。
此外,我们提出了该DSG的闭式解(Closed-Form Solution), 仅用几行代码,就能够使得DSG可以无缝地插入(Plug-and-Play)到现有的无需训练的条件扩散方法,在几乎不产生额外的计算开销的同时大幅改善了模型性能。我们在各个条件生成任务(Inpainting, Super Resolution, Gaussian Deblurring, Text-Segmentation Guidance, Style Guidance, Text-Style Guidance, and FaceID Guidance)中验证了DSG的有效性。
背景:无需训练的条件扩散模型
Classifier guidance首先提出使用预训练的扩散模型进行条件生成。它利用贝叶斯公式
, 通过引入额外的似然项
来实现条件生成:
目前无需训练的方法, 将time-dependent classifier替换成某个定义在
上的可微损失函数
,并利用Tweedie's formula求解额外的似然项:
这里
表示加噪t步的data,
表示引导步长。因此, 总体的采样过程可以被写成
损失函数引导过程中的流形偏离(Manifold Deviation)
尽管先前的工作由于其灵活的特性在各种条件生成任务中取得了巨大成功,但它们会牺牲生成样本的质量。在本文中,我们提出这种现象产生的原因是线性流形假设(Linear Manifold Assumption)和Jensen Gap导致的流形偏离:
-
线性流形假设:线性流形假设是一个相当强的假设, 因此在实践中通常会引入误差。
-
本文指出, 即使DPS提供了Jensen Gap的上界, 它仍然具有下界, 也会引入估计误差:
基于球面高斯约束引导的条件扩散模型(DSG)
既然无论Jensen Gap还是线性流形假设都会不可避免地引入估计误差,那么为什么不在已经无条件的中间数据流形(Intermediate Data Manifold)中,找到那个最接近条件采样的点呢?
因此,我们提出了DSG(
D
iffusion with
S
pherical
G
aussian constraint),一种在无条件中间流形M_t的高置信区间内进行Guidance的优化方法:
这里
表示高斯分布的概率为
的置信区间。在这个优化问题中目标函数倾向于让采样过程在梯度下降方向进行,约束则是将采样约束在高斯分布的高置信区间。
然而,当高置信区间包含n维空间中时,优化问题就变得具有挑战性。幸运的是,高维各向同性高斯分布的高置信区间集中在一个超球上,我们可以通过用这个超球近似它来简化约束,称为球面高斯约束(Spherical Gaussian Constraint):
这里
表示n维高斯分布近似的超球。通过这种近似方法,我们能够得到优化问题的闭式解:
这个闭式解的求得能够表明,DSG可以无缝插入目前的无需训练的条件扩散模型,如DPS、Freedom、UGD,而不造成额外的计算复杂度。并且,只需要修改几行代码就能够产生更好的样本和达到更快的推理速度。
另外, 从另一个角度看, DSG也可以看成在预测均值
上进行梯度下降。而且, 由于
与
正相关, DSG可以看作是自适应的梯度下降方法, 在一开始下降步长大, 在最后下降步长小。在实验中, 我们发现DSG最大的步长能够达到DPS的400倍, 因此能够在更小的DDIM steps下相比于 DPS更加鲁棒。
此外,我们发现DSG虽然增强了对齐能力和真实性,但是在多样性方面有所损失。因此,我们对原始采样方向和梯度下降方向的进行加权,就像Classifier-free Guidance那样:
这里