本文介绍南京大学和阿里巴巴在扩散模型加速任务上的新工作:
SPLAM: Accelerating Image Generation with Sub-Path Linear Approximation Model
。本工作主要解决扩散模型在采样过程中需要多步导致推理速度较慢,针对现有的LCM存在的累积误差较大的问题进行优化,通过提出线性ODE采样方法,进一步提升了生图的质量和速度。在四步推理的设置下,在COCO30k和COCO5k上分别取得了10.06和20.77的FID分数,在加速模型方法中达到了SOTA效果。
论文标题:
SPLAM: Accelerating Image Generation with Sub-Path Linear Approximation Model
论文链接:
https://arxiv.org/abs/2404.13903
代码链接:
https://github.com/MCG-NJU/SPLA
项目主页:
https://subpath-linear-approx-model.github.io/
一、引言
扩散模型目前已经成为文本生成图片领域使用最为广泛的模型,其通过逐步去噪步骤来从一张高斯噪声采样生成真实分布中的图片。然而,扩散模型一直存在的一个问题是其运行速度,因为需要多步迭代推理,导致图片生成速度缓慢,计算开销大。
针对这个问题一直以来,也有非常多的工作在探索加速扩散模型的方法。在最初的DDPM中,模型的推理需要和训练时相同的1000步迭代,生成一张图片通常需要数分钟。一系列工作着重研究推理时的采样方法,如DDIM,DPM-Solver等,这些方法通过ODE等技术优化,将采样步数从1000步降低到了20~50步量级,大大提升了图片生成速度。另外一系列的工作着重研究基于现有预训练模型(比如Stable Diffusion),通过蒸馏等方法将步数进一步压缩,实现到了10步以下的采样迭代次数。
如一致性模型,通过将PF-ODE上的采样点映射到原点的思想,实现了2-4步的推理,然而压缩步数也会导致一定程度的图片质量下降。我们的论文主要分析了一致性优化学习的过程中的难点和导致性能下降的因素,提出了子路径线性近似模型(SPLAM)尝试缓解这些问题,实现了更小的累积误差,提升了模型性能。
二、方法简介
2.1 一致性模型
一致性模型(Consistency Model)[1] 是 OpenAI 的 Song Yang 博士在 ICML2023 提出的扩散模型加速方法,是这个领域中非常重要的一项工作,基于此在Stable Diffusion上开发的LCM模型 [2] 也是在用户社区中热度非常高加速功能插件,我们首先来回顾一下一致性模型的原理。
根据 Song Yang [3] 的理论,一个扩散模型的去噪过程可以建模为一条常微分方程ODE路径,称为概率流Probability-Flow ODE (PF-ODE):
而一致性模型的想法其实也非常简单,就是将ODE路径上每一个点都映射到原点,而原点来源于真实图片的分布,从而做到一步生成,如图所示:
具体地,我们希望学习一个函数
,对于一条ODE上的采样点
。在训练中,从
逐步采样到
通常时间开销过大,所以CM采取了一个训练技巧,在每一步训练迭代中通过缩小相邻两个点间的映射误差,来逐渐最终达到一致性。然而这也带来了问题,逐步的收敛导致了较大的累积误差:
使得在生图时的图片的细节丢失较多,生图质量较差。我们的方法也是针对这个问题,通过在每个子路径上通过随机线性插值采样,来进行连续的渐进式的误差估计,做到累计误差更小的去噪映射。
2.2 问题分析
对于上面提到的一步生成模型,我们通常把映射函数
参数化为:
根据EDM中的理论,我们可以设计一个 canonical denoiser function:
,而其去
噪目标就为
:
。
这时会存在一个问题,这个目标其实比较难以优化,原因在于随着时间步
的增加,
会逐渐趋向于零,这会使得训练不稳定有可能塌缩。一致性模型其实一定程度上缓解了这个问题,当我们假设模型理想地收敛,即
,这个性质能够对于上式进行一个预估:
。然后我们把
的表达式代入,得到一个基于
的误差估计:
因为额外有了系数
,所以上面所提到的问题被一定程度的缓解。
现在我们再来具体分析一下这个优化目标,我们可以把它解耦为两项:
(1)
,这一项衡量了由于漂移和扩散过程导致的从
到
的增量距离。
(2)
,这一项衡量了前一时间步的去噪贡献,这些贡献会连续地传播到后续时间步。
这时,我们我们就可以把这个优化目标重写为一个子路径(Sub-Path)的优化目标:
在这个目标式中,
这项是导致累积误差的关键,我们也着重对于这项进行优化。
三、SPLAM
基于此,我们提出了我们的加速方法
子路径线性近似模型
(Sub-Path Linear Approximation Model,SPLAM),如图所示。
首先,我们提出了
子路径线性
ODE(Sub-Path Linear ODE,SL-ODE),来近似原始PF-ODE上的子段,由此来进行对于
的递进式估计。具体来说,对于原始路径上的一段
,基于
我们对两个端点
进行插值形成线性路径,在这个线性路径上的采样点可以表示为:
因为
符合由PF-ODE控制的分布,我们的线性变换有效地定义了一个对于
的线性ODE:
即为SL-ODE。注意到这里对于端点多了一项漂移系数
,这项系数的引入具体可参考我们论文中的详细推导。据此,我们也有了对应
的Denoiser和生成表达式:
将这个式子代入上面的子路径优化目标,便得到了我们SPLAM的最终优化目标
:
这个目标对于原本较难优化的
项提供了一种递进式的拟合,这也使得我们我们的训练可以使用更大的推理步长。
由此,我们也以预训练好的Stable Diffusion模型作为PF-ODE,来建立我们的SL-ODE,并提出了基于SPLAM目标的蒸馏方法(Sub-Path Linear Approximation Distillation,SPLAD)。我们依然沿用CM中的生成函数的参数表达式,除了额外增加了一个维度
:
其中
为使用的教师模型作为 ODE Solver。