点击下方
卡片
,关注
「3DCV」
公众号
选择
星标
,干货第一时间送达
来源:3DCV
添加小助理:cv3d008,备注:方向+学校/公司+昵称,拉你入群。文末附3D视觉行业细分群。
扫描下方二维码,加入「
3D视觉从入门到精通
」知识星球
,星球内凝聚了众多3D视觉实战问题,以及各个模块的学习资料:
近20门独家秘制视频课程
、
最新顶会论文
、计算机视觉书籍
、
优质3D视觉算法源码
等。想要入门3D视觉、做项目、搞科研,欢迎扫码加入!
0. 论文信息
标题:DDIL: Improved Diffusion Distillation With Imitation Learning
作者:Risheek Garrepalli, Shweta Mahajan, Munawar Hayat, Fatih Porikli
机构:Qualcomm AI Research
原文链接:https://arxiv.org/abs/2410.11971
1. 摘要
扩散模型擅长生成性建模(例如,文本到图像),但是采样需要多个去噪网络通道,从而限制了实用性。诸如渐进蒸馏或稠度蒸馏之类的努力已经显示出通过以所产生样品的质量为代价来减少通过次数的前景。在这项工作中,我们确定了共变移位是多步提取模型在推理时由于复合误差而表现不佳的原因之一。为了解决共变量转移,我们在模仿学习(DDIL)框架内制定扩散提取,并增强训练分布,以提取数据分布(前向扩散)和学生诱导分布(后向扩散)上的扩散模型。关于数据分布的培训有助于通过保持边际数据分布来使世代多样化,关于学生分布的培训通过纠正协变量偏移来解决复合误差。此外,我们采用反射扩散公式进行蒸馏,并在不同的蒸馏方法中证明了改进的性能和稳定的训练。我们表明,DDIL一致性改进了基线算法的渐进蒸馏(PD),潜在的一致性模型(LCM)和分布匹配蒸馏(DMD2)。
2. 引言
扩散模型虽然能够生成高质量的图像,但由于其迭代去噪过程,采样速度较慢。为解决这一问题,已提出蒸馏技术来减少去噪步骤的数量。这些技术大致可分为轨迹级和分布匹配方法。前者侧重于在每个样本层面保留教师模型的轨迹,而后者则匹配边缘分布。
推荐课程:
国内首个面向具身智能方向的理论与实战课程
。
多步学生模型在平衡质量和计算效率方面提供了一种有前景的方法。然而,它们通常面临一个关键挑战:协变量偏移。这发生在学生模型在训练期间遇到的噪声输入潜在变量的分布与学生进行推理时遇到的分布不同时。这种不匹配会显著影响生成质量,尤其是在去噪步骤数量较少的情况下。近期工作仅考虑反向轨迹以获得对生成质量的反馈,但这些方法通常对数据分布漠不关心,并可能出现模式崩溃。
在本工作中,我们确定了“协变量偏移”是影响多步蒸馏扩散模型生成质量的关键因素。为解决协变量偏移并保留多样性,我们通过改进蒸馏的训练分布,在模仿学习(DDIL)框架内引入了扩散蒸馏。我们通过结合数据分布(正向扩散)和学生模型的预测分布(推理时的反向轨迹)来实现这一点。这种方法结合了以下优点:(1)保留边缘数据分布:基于数据分布的训练确保了学生模型保持了原始数据的固有统计特性;(2)校正协变量偏移:基于反向轨迹的训练使学生模型能够识别和适应协变量偏移,从而提高了得分估计的准确性,特别是在少步设置下。
3. 效果展示
不同提取技术生成图像的定性比较。
4. 主要贡献
我们做出了以下贡献:
• 我们提出了一种新颖的DDIL框架,该框架通过在数据集聚合“DAgger”框架内对数据分布(正向)和学生诱导分布(反向轨迹)进行蒸馏,增强了扩散蒸馏的训练分布,从而产生了改进的聚合预测分布和更好的覆盖率。
• 为提高扩散模型中蒸馏过程的稳定性,我们对教师和学生扩散模型采用阈值处理,并结合反射扩散来强化数据分布的支持进行蒸馏。因此,这种方法在与DDIL结合时进一步缓解了协变量偏移,带来了更实质性的改进。
• 我们证明了我们的DDIL方法能够生成多样化的样本,并在计算效率高的框架中持续优于不同的蒸馏技术,如渐进蒸馏(PD)、潜在一致性蒸馏(LCM)和分布匹配蒸馏(DMD2)。
5. 方法
我们在图2中展示了在渐进蒸馏背景下DDIL框架的实例化。
我们引入了模仿学习中的扩散蒸馏(DDIL),这是一个受模仿学习中DAgger算法启发的新框架,用于增强中间噪声潜在变量的采样分布,以蒸馏扩散模型。扩散模型蒸馏涉及两个关键考虑因素:(1)学生模型遇到的潜在状态的训练分布;(2)蒸馏期间采用的反馈机制。DDIL特别侧重于改进训练分布,同时对学生采用的不同蒸馏技术所使用的具体反馈机制保持中立。
为实现这一目标,DDIL从三个来源战略性地采样中间潜在变量:(1)数据集的正向扩散,由采样先验βfrwd捕获(如算法1所示);(2)学生模型的反向轨迹(展开潜在变量),由采样先验βstudent_bckwrd表示;(3)教师模型的反向轨迹,由采样先验βteacher_bckwrd表示,这在无数据设置下特别有利于保留边缘数据分布。结合这些采样策略可提高蒸馏性能。
DDIL是一个关于蒸馏扩散模型的统一训练框架,就蒸馏的采样先验而言。DDIL结合了教师对学生轨迹的反馈,在渐进蒸馏和潜在一致性模型(LCM)的情况下,这与DAgger Ross et al.的原则相一致。此外,虽然Kohler et al.和Yin et al.等方法仅对反向轨迹进行蒸馏,且在蒸馏期间不考虑边缘数据分布,但DDIL通过持续结合所选蒸馏算法在正向和反向轨迹上的反馈来解决这一问题。因此,我们的灵活框架允许改进训练分布,从而提升扩散蒸馏方法的性能。
算法1概述了模仿学习中的扩散蒸馏(DDIL)的通用框架。该框架利用预训练的扩散模型(教师)和学生扩散模型,后者通常用教师的参数初始化。此外,假设可以访问真实数据,在蒸馏过程中提供边缘数据分布的代表性样本。该框架需要为教师和学生模型指定超参数,包括它们各自的离散化方案。为简化起见,我们假设在1中使用DDIM求解器。蒸馏通过随机选择三种方法中的一种来采样输入到学生模型的中间噪声潜在变量来进行。此选择由用户定义的采样先验控制:βfrwd、βteach_bckwrd和βstudent_bckwrd,它们分别对应于之前讨论过的三种中间潜在变量的来源。这些采样先验(表示为βi)的选择和更新可以根据训练阶段、目标函数和整体任务目标进行定制。
6. 实验结果
7. 总结 & 未来工作
本文提出了DDIL,这是一种新颖的扩散模型蒸馏框架,旨在解决协变量偏移的挑战,同时保持边际数据分布。将DDIL与已建立的蒸馏技术相结合,包括渐进蒸馏(Progressive Distillation)、一致性蒸馏(LCM,即Label Consistency Matching)和基于分布匹配的蒸馏(DMD2,即Distribution Matching based Distillation 2),可以持续获得定量和定性的改进。此外,我们还表明,在DMD2框架内整合DDIL,能够增强训练稳定性,减小所需的批量大小,并提高计算效率,从而证明了其更广泛的适用性和实用价值。
对更多实验结果和文章细节感兴趣的读者,可以阅读一下论文原文~