0. 这篇文章干了啥?
近年来,随着扩散模型(Diffusion Models)的发展,人工智能生成内容(AIGC)取得了显著进步。一方面,与经典的生成对抗网络(GAN)不同,扩散模型通过迭代细化噪声向量来生成高质量且细节丰富的结果。另一方面,这些模型在大规模数据对上训练后,能够在输入条件和输出结果之间展现出令人满意的一致性。这些能力推动了文本到图像生成领域的最新进展。得益于其出色的性能和开源社区的支持,Stable Diffusion(SD)成为最受欢迎的模型之一。
SD等模型的成功在很大程度上归功于其强大的去噪骨干结构。从带有注意力层的UNet架构到视觉Transformer,现有设计都严重依赖自注意力机制来管理空间标记之间的复杂关系。尽管它们表现优异,但自注意力操作固有的二次时间和内存复杂度对高分辨率视觉生成构成了重大挑战。使用FP16精度时,SD-v1.5因内存不足而无法在具有80GB内存的A100 GPU上生成2048分辨率的图像,这使得更高分辨率或更大模型的问题更加突出。
为了解决这些问题,本文旨在提出一种线性复杂度的全新标记混合机制,作为经典自注意力方法的替代方案。受最近引入的具有线性复杂度的模型(如Mamba)的启发,这些模型在序列生成任务中展现了巨大潜力,我们首先研究了它们在扩散模型中作为标记混合器的适用性。
然而,Mamba扩散模型存在两个缺点。一方面,当扩散模型在其训练尺度以外的分辨率下操作时,我们的理论分析表明特征分布容易发生偏移,导致跨分辨率推理困难。另一方面,扩散模型执行的是去噪任务而非自回归任务,允许模型同时访问所有噪声空间标记并基于整个输入生成去噪标记。相比之下,Mamba本质上是一个按顺序处理标记的RNN,意味着生成的标记仅基于先前的标记,这种约束被称为因果限制。将Mamba直接应用于扩散模型将对去噪过程施加不必要的因果限制,这既不合理又适得其反。虽然双向扫描分支可以在一定程度上缓解这个问题,但每个分支内的问题仍然存在。
针对Mamba在扩散模型中的上述缺点,我们提出了一种广义线性注意力范式。首先,为了解决训练分辨率与较大推理分辨率之间的分布偏移问题,我们为Mamba设计了一个归一化器,该归一化器由所有标记对当前标记的累积影响定义,并应用于聚合特征,确保无论输入尺度如何,总影响都保持一致。其次,我们旨在开发Mamba的非因果版本。我们从简单地移除遗忘门上的下三角因果掩码开始探索,但发现所有标记最终都会得到相同的隐藏状态,这削弱了模型的容量。为了解决这个问题,我们为不同标记引入了不同的遗忘门组,并提出了一种有效的低秩近似方法,使模型能够以线性注意力形式优雅地实现。我们对所提出的方法进行了技术分析,并与最近引入的线性复杂度标记混合器进行了比较,结果表明我们的模型可以视为这些流行模型的广义非因果版本。
将提出的广义线性注意力模块集成到SD的架构中,替换原始的自注意力层,得到的模型称为线性复杂度扩散模型(Linear-Complexity Diffusion Model),简称LinFusion。通过仅在知识蒸馏框架中训练线性注意力模块50k次迭代,LinFusion在性能上与原始SD相当甚至更优,同时显著降低了时间和内存复杂度。此外,它还提供了令人满意的零样本跨分辨率生成性能,并能在单个GPU上生成16K分辨率的图像。它还与SD的现有组件(如ControlNet)兼容,允许用户灵活地向提出的LinFusion注入额外控制,而无需任何额外的训练成本。
下面一起来阅读一下这项工作~
1. 论文信息
标题:LinFusion: 1 GPU, 1 Minute, 16K Image
作者:Songhua Liu, Weihao Yu, Zhenxiong Tan, Xinchao Wang
机构:National University of Singapore
原文链接:https://arxiv.org/abs/2409.02097
代码链接:https://github.com/Huage001/LinFusion
2. 摘要
现代扩散模型,特别是那些利用基于Transformer的UNet进行去噪的模型,严重依赖自注意力操作来处理复杂的空间关系,从而实现了令人印象深刻的生成性能。然而,这种现有范式在生成高分辨率视觉内容时面临着重大挑战,因为其时间和空间复杂度与空间标记的数量呈二次关系。为了克服这一限制,本文旨在提出一种新颖的线性注意力机制作为替代方案。具体而言,我们从最近引入的具有线性复杂度的模型(如Mamba、Mamba2和Gated Linear Attention)开始探索,并确定了两个关键特征——注意力归一化和非因果推断,这两个特征能够提升高分辨率视觉生成的性能。基于这些见解,我们引入了一种广义线性注意力范式,它作为多种流行线性标记混合器的低秩近似。为了节省训练成本并更好地利用预训练模型,我们初始化了我们的模型,并从预训练的StableDiffusion(SD)中提炼知识。我们发现,经过适度训练后,这种提炼后的模型(称为LinFusion)在性能上可与原始SD相媲美或更优,同时显著降低了时间和内存复杂度。在SD-v1.5、SD-v2.1和SD-XL上的大量实验表明,LinFusion提供了令人满意的零样本跨分辨率生成性能,能够生成高达16K分辨率的高分辨率图像。此外,它与预训练的SD组件(如ControlNet和IP-Adapter)高度兼容,无需进行任何适配工作。
3. 效果展示
如图1所示,在SD-v1.5、SD-v2.1和SD-XL上的广泛实验验证了所提出模型和方法的有效性。
4. 主要贡献
我们的贡献可以总结如下:
• 我们研究了Mamba的非因果和归一化感知版本,并提出了一种新颖的线性注意力机制,解决了使用扩散模型进行高分辨率视觉生成面临的挑战。
• 我们的理论分析表明,所提出的模型在技术上是对现有流行线性复杂度标记混合器的广义且高效的低秩近似。
• 在SD上的广泛实验表明,所提出的LinFusion不仅能达到甚至超过原始SD的结果,还展现了令人满意的零样本跨分辨率生成性能和与SD现有组件的兼容性。据我们所知,这是首次在SD系列模型上探索用于文本到图像生成的线性复杂度标记混合器。
5. 基本原理是啥?
在本文中,我们的目标是针对一般文本到图像的问题,提出一个具有与图像像素数量成线性复杂度的扩散骨干网络。为此,我们并没有从头开始训练一个新模型,而是从预训练的Stable Diffusion(SD)模型中初始化和提炼模型。具体来说,我们默认使用SD-v1.5模型,并将其自注意力机制——二次复杂度的主要来源——替换为我们提出的LinFusion模块。仅这些模块中的参数是可训练的,而模型的其他部分则保持不变。我们从原始SD模型中提炼知识到LinFusion中,以便在给定相同输入的情况下,它们的输出尽可能接近。图3提供了这一流程的概述。
这种方法带来了两个主要好处:(1)训练难度和计算开销显著降低,因为学生模型只需学习空间关系,而无需处理文本图像对齐等其他方面的复杂性;(2)所得模型与在原始SD模型及其微调变体上训练的现有组件高度兼容,因为我们仅将自注意力层替换为LinFusion模块,这些模块经过训练以在功能上与原始模块相似,同时保持整体架构不变。
从技术上讲,为了得到一个具有线性复杂度的扩散骨干网络,一个简单的解决方案是将所有自注意力块替换为Mamba2,如图4(a)所示。我们应用双向SSM以确保当前位置可以访问后续位置的信息。此外,Stable Diffusion中的自注意力模块没有像Mamba2中那样使用门控操作。如图4(b)所示,我们移除了这些结构以保持一致性,并稍微提高了性能。在本节的后续部分中,我们将深入探讨将Mamba2中的核心模块SSM应用于扩散模型的问题,并据此介绍LinFusion的关键特性:归一化和非因果性。最后,我们提供了训练目标以优化LinFusion模块中的参数。
6. 实验结果
7. 总结 & 未来工作
本文介绍了一种名为LinFusion的扩散骨干网络,用于生成文本到图像的模型,其复杂度与像素数量成线性关系。LinFusion的核心在于一种广义的线性注意力机制,该机制以其特有的归一化感知和非因果操作而区别于最近提出的如Mamba、Mamba2和GLA等线性复杂度标记混合器,这些方面在之前的研究中往往被忽视。
我们从理论上证明,所提出的范式为最近模型中的非因果变体提供了一种通用的低秩近似。基于Stable Diffusion(SD),经过知识提炼的LinFusion模块可以无缝替换原始模型中的自注意力层,从而确保LinFusion与Stable Diffusion的现有组件(如ControlNet、IP-Adapter和LoRA)高度兼容,而无需进一步训练。在SD-v1.5、SDv2.1和SD-XL上的广泛实验表明,所提出的模型优于现有基线,并在计算开销显著降低的情况下,实现了与原始SD相当或更优的性能。在单个GPU上,它能够支持高达16K分辨率的图像生成。
对更多实验结果和文章细节感兴趣的读者,可以阅读一下论文原文~
本文仅做学术分享,如有侵权,请联系删文。
3D视觉交流群,成立啦!
目前我们已经建立了3D视觉方向多个社群,包括
2D计算机视觉
、
最前沿
、
工业3D视觉
、
SLAM
、
自动驾驶
、
三维重建
、
无人机
等方向,细分群包括:
工业3D视觉
:相机标定、立体匹配、三维点云、结构光、机械臂抓取、缺陷检测、6D位姿估计、相位偏折术、Halcon、摄影测量、阵列相机、光度立体视觉等。
SLAM
:视觉SLAM、激光SLAM、语义SLAM、滤波算法、多传感器融合、多传感器标定、动态SLAM、MOT SLAM、NeRF SLAM、机器人导航等。
自动驾驶:深度估计、Transformer、毫米波|激光雷达|视觉摄像头传感器、多传感器标定、多传感器融合、自动驾驶综合群等、3D目标检测、路径规划、轨迹预测、3D点云分割、模型部署、车道线检测、Occupancy、目标跟踪等。
三维重建
:3DGS、NeRF、多视图几何、OpenMVS、MVSNet、colmap、纹理贴图等
无人机
:四旋翼建模、无人机飞控等
2D计算机视觉
:图像分类/分割、目标/检测、医学影像、GAN、OCR、2D缺陷检测、遥感测绘、超分辨率、人脸检测、行为识别、模型量化剪枝、迁移学习、人体姿态估计等
最前沿
:具身智能、大模型、Mamba、扩散模型等
除了这些,还有
求职
、
硬件选型
、
视觉产品落地、产品、行业新闻
等交流群
添加小助理: dddvision,备注:
研究方向+学校/公司+昵称
(如
3D点云+清华+小草莓
), 拉你入群。
▲长按扫码添加助理:cv3d008
3D视觉知识星球