太长不看版
端侧文生图扩散模型的成功范式。
现有的文生图 (T2I) 扩散模型有几个限制:
1) 模型尺寸过大不适合移动设备 (Mobile Devices),2) 时延高,3) 生成质量很低
。
本文开发了一个
很小,快速的 T2I 模型,旨在在移动平台上生成高分辨率和高质量的图像
。本文提出了几个技术来实现这个目的。
首先,作者系统地检查了网络架构的设计选择,以减少模型参数和延迟,同时确保高质量的生成。其次,为了进一步提高生成质量,使用来自更大模型的跨架构知识蒸馏,使用多级策略从头开始指导小模型的训练。然后,通过将对抗性指导与知识蒸馏相结合来实现 Few-step 生成。
本文的模型 SnapGen 可以约 1.4s 在移动设备上生成 1024px 的图像。在 ImageNet-1K 上,本文的模型仅使用 372M 参数,在 256 px 生成中实现了 2.06 的 FID。在 T2I 基准测试中 (GenEval 和 DPG-Bench),本文的模型只有 379M 参数,虽然尺寸很小,却超过了具有数十亿个参数的大模型 (比 SDXL 小 7 倍,比 IF-XL 小 14 倍)。
图1:各种文生图模型在模型大小、移动设备兼容性和视觉输出质量方面的比较。本文模型仅使用 379M 参数,展示了具有竞争力的视觉质量,同时与移动设备相兼容。所有图像分辨率均为 1024px
下面是对本文的详细介绍。
本文目录
1 SnapGen:轻量化架构和训练策略实现端侧文生图
(来自 Snap,墨尔本大学,HKUST,MBZUAI)
1 SnapGen 论文解读
1.1 SnapGen 研究背景
1.2 高效 U-Net 架构
1.3 更小更快的解码器
1.4 训练配方以及多级知识蒸馏
1.5 Step 蒸馏
1.6 实验设置
1.7 实验结果
1
SnapGen:轻量化架构和训练策略实现端侧文生图
论文名称:SnapGen: Taming High-Resolution Text-to-Image Models for Mobile Devices with Efficient Architectures and Training
论文地址:
http://arxiv.org/pdf/2412.09619
Project Page:
http://snap-research.github.io/snapgen/
1.1 SnapGen 研究背景
大规模文生图 (T2I) 扩散模型在内容生成方面取得了显著的成功,为图像编辑和视频生成等许多应用提供支持。然而,T2I 模型往往伴随着较大的模型尺寸,较慢的运行时间。如果将它们部署在云上,会引发与数据安全性问题,和高成本的问题。
为了应对这些挑战,人们通过模型压缩 (例如剪枝和量化) 等技术开发更小更快的 T2I 模型,比如通过蒸馏减少 steps
[1]
,以及减轻二次方计算复杂度的高效注意力机制
[2]
。但是,目前的工作仍然会遇到局限性,例如移动设备上的低分辨率生成,这限制了它们更广泛的应用。
更重要的是,一个关键问题尚未探索:
如何从头开始训练 T2I 模型,以在移动上生成高质量高分辨率图像?
这样的模型将在速度、紧凑性、成本效益和安全性部署方面提供实质性优势。为了构建这个模型,本文引入几个创新:
-
高效的网络架构:
作者对网络架构进行了深入的检查,包括 Denoising UNet 和 AutoEncoder (AE),以获得资源使用和性能之间最优的权衡。与优化和压缩预训练扩散模型的先前工作
[3][4][5]
不同,本文直接专注于宏观和微观级别的设计选择,以实现一种新颖的架构,该架构大大降低了模型大小和计算复杂度,同时保留了高质量的生成。
-
改进的训练技术:
作者引入了几个改进来从头开始训练紧凑的 T2I 模型。利用 Flow Matching
[6][7]
作为目标,与SD3 和 SD3.5 等更大的模型对齐。这种设计实现了高效的 Knowledge Distillation 和 Step Distillation,将大规模扩散模型的丰富表示转移到小模型。此外,本文提出了一种多级知识蒸馏和一个结合了多个训练目标的时间步感知缩放。没有通过线性组合对目标进行加权,而是考虑流匹配中不同时间步长的目标预测难度 (即学生-教师的差异)。
-
高级 Step Distillation:
作者通过使用 Few-step 的教师模型 (即 SD3.5-Large-Turbo
[8]
) 将对抗训练和知识蒸馏相结合,对本文模型执行 Step Distillation,从而实现仅 4 或 8 步的超快高质量生成。
本文介绍如何制作和训练高效的 T2I 模型以进行高分辨率生成。具体来说,从 latent diffusion model 架构开始,优化了 denoising backbone 和 autoencoder,使它们又紧凑又快速,即使在移动设备上也是如此。然后,本文提出了改进的训练配方和知识蒸馏,得到高性能 T2I 模型。最后介绍步骤蒸馏,显著降低更快的 T2I 模型的 denoising steps 数量。
1.2 高效 U-Net 架构
Baseline 架构
作者从 SDXL
[9]
中选择 UNet 作为 Baseline,因为它比纯基于 Transformer 的模型具有更高的效率和更快的收敛。将 U-Net 调整为更薄和更短的模型 (将变换器块的数量从 [0, 2, 10] 分为三个阶段到 [0, 2, 4],它们的通道维度从 [320, 640, 1280] 减少到 [256, 512, 896]),并在其之上迭代设计选择。
评估指标
作者在 ImageNet-1K 上训练模型 120 个 epoch,除非另有说明,并报告 256px 生成的 FID 分数。与现有工作
[10]
类似,作者通过文本模板 "a photo of
"。然后,使用文本编码器对其进行编码,以对齐 T2I 生成的管道。作者还计算了不同模型的参数量,FLOPs (以 128 × 128 的 latent 大小测量,相当于解码后的 1024×1024 图像),以及移动设备上的运行时间 (在 iPhone 15 Pro 上测试)。下面将介绍改进模型的关键架构更改。
图2:高效 U-Net 架构。从 SDXL 的 U-Net 的更薄和更短的版本开始 (a),探索了一系列架构变化,即 (b)-(f),以在保留高质量生成性能的同时开发一个更小更快的模型
图3:高效 U-Net 各种设计选择的性能和效率比较。使用在 ImageNet-1K 上计算的 FID 来评估生成质量,生成 256px 的图。效率指标包括模型参数、时延和 FLOPs。FLOPs 和时延 (在 iPhone 15 Pro 上) 是用 128×128 latent 测量一次前向推理的,相当于解码后 1024×1024 的图像
1) 高分辨率阶段去掉 Self-Attention
Self-Attention 受二次计算复杂度的限制,对高分辨率输入会带来较高的计算成本和内存消耗。所以,只在最低分辨率保留 SA,在高分辨率阶段删除,如上图 2(b) 所示。这使得 FLOPs 减少了 17%,时延减少了 24%,结果如图 3 所示。有趣的是,甚至观察到性能改进,FID 从 3.76 降到 3.12。作者假设原因是高分辨率阶段 SA 的模型收敛得更慢。
2) 将 Conv 替换为扩展 Separable Conv
常规卷积 (Conv) 在参数和计算中都是多余的。为了解决这个问题,将所有 Conv 层替换为 Separable Conv
[11]
,由 Depthwise Convolution (DW) 和 Pointwise Convolution (PW) 组成,如图 2(c) 所示。这种替换将参数减少了 24%,时延减少了 62%,但也会导致性能下降 (FID 从 3.12 增加到 3.38)。为了解决这个问题,作者扩展了中间通道。具体来说,第 1 个 PW 层之后的通道数随着扩展比的增加而增加,在第 2 个 PW 层之后减少到原始数。Expansion ratio 设置为 2 以平衡性能、延迟和模型参数之间的权衡。这样的设计使得残差块与 Universal Inverted Bottleneck (UIB) 对齐。因此,本文模型在获得较低的 FID 的同时,实现了 15% 的参数、27% 的计算量和 2.4 倍的加速。
3) Trim FFN 层
对于 FFN 层,默认将隐藏通道 expansion ratio 设置为 4,并使用门控单元进一步加倍。这大大增加了模型参数、计算和内存使用。继 MobileDiffusion
[12]
之后,作者检查了简单地减少扩展比的功效,如图 2(d) 所示。本文表明,将扩展比减少到 3 可以保持可比的 FID 性能,同时将参数和 FLOP 减少 12%。
4) MQA 替换 MHSA
Multi-Head Self-Attention (MHSA) 为每个注意力头需要多组键和值。相比之下,Multi-Query Attention (MQA)
[13]
在所有 head 之间共享一组键和值更有效。用 MQA 替换 MHSA 将参数减少了 16%,时延减少了 9%,对性能的影响最小。有趣的是,减少时延的 9% 超过了减少 FLOPs 的 6%,因为减少的内存访问可实现更高的计算强度 (FLOPs/Byte)。因此,模型的计算吞吐量 FLOPS 提升了。所以,尽管 FLOPs 仅仅降低 6%,但是时延下降了 9%。
5) 将 Condition 注入第 1 阶段
交叉注意力 (Cross-Attention) 将 Condition 信息 (如纹理描述) 与空间特征混合,生成与条件一致的图像。然而,SDXL 的 UNet 仅在从第 2 阶段开始的 Transformer Block 中应用 CA,导致第 1 阶段的条件指导缺失。本文建议从第 1 阶段引入条件嵌入,如图 2(e) 所示。具体来说,将残差块替换为包含 CA 和 FFN 的 Transformer Block,而在没有 SA 层的情况下。这种调整使模型更小、更快、更高效,同时提高了 FID。
6) 使用 QK RMSNorm 和 RoPE 位置编码
作者扩展了最初为语言模型开发的两种先进技术,使用 RMSNorm
[14]
的 Query-Key (QK) Normalization
[15]
和 Rotary Position Embedding (RoPE)
[16]
,以增强模型 (图 2(f))。RMSNorm 在注意力机制中的 Query-Key 投影之后应用,在不牺牲模型表达能力的情况下降低了 Softmax 饱和的风险,同时稳定训练以实现更快的收敛。此外,作者将 RoPE 从一维调整为二维以更好地支持更高的分辨率,因为它显著减少了重复对象等伪影。总之,RMSNorm 和 RoPE 引入的计算和内存开销可以忽略不计,同时在 FID 性能方面提供了增益。
讨论
经过上述的优化,得到了一个高效而强大的扩散 Backbone,能够在移动设备上生成高分辨率图像。在进行大规模 T2I 训练之前,作者将本文模型与 ImageNet-1K 上的现有工作进行了比较。作者遵循先前工作的设置来训练 1000 Epochs。作者在不同的推理步骤中使用不同的 CFG 评估模型。如图 4 所示。高效的 U-Net 实现了与 SiT-XL 相当的 FID,而小近 45%。
图4:使用 CFG 在 ImageNet256×256 上进行类条件图像生成
1.3 更小更快的解码器
除了去噪模型外,解码器还占总运行时间的很大一部分,特别是对于 On-device 部署。在这里,作者介绍了一种新的解码器架构,如图 5 所示,以实现高效的高分辨率生成。
Baseline Decoder
由于优越的重建质量,作者使用来自 SD3
[17]
的自动编码器 (AE) 作为 Baseline 模型 (即来自 SD3 的 AE 的相同编码器)。AE 将图像
映射到低维 latent 空间
在 SD3 中为 8,16 )。然后通过 Decoder 将编码的 latent
解码回图像。对于高分辨率生成,作者观察到 SD3 中的解码器在移动设备上非常慢。具体来说,当在 iPhone 15 Pro 和移动 GPU 的 ANE处理器上生成 10242px 图像时,它会遇到内存不足(OOM)错误。为了克服延迟问题,作者提出了一个更小更快的 Decoder。
图5:(a) SDXL/SD3 解码器和 (b) 本文 tiny decoder 之间的架构比较
图6:Decoder 性能比较。PSNR 在 COCO 2017 验证集上计算。测量 FLOPs 和延迟 (在 iPhone 15 Pro 上),将 128×128 的潜在解码为 1024×1024 图像。SDXL 和 SD3 的 Decoder 无法在移动的神经引擎上运行
高效的 Decoder
作者进行了一系列实验来决定具有以下关键变化的高效 Decoder,与基线架构相比:
-
移除注意力层:
以大大减少峰值内存,而不会对解码质量产生显着影响。
-
保留最少的 GroupNorm (GN):
来找到延迟和性能之间的权衡(即减轻颜色移动)。
-
使解码器更薄
(更少的通道或更窄的宽度),并用 SepConvs 替换 Conv。
-
在高分辨率阶段
使用更少的 Residual Block
。
-
在 Residual Block 中
去除 Conv Shortcuts
,并使用 Upsampling 层进行通道转换。
Decoder 的训练
作者用均方误差 (MSE) 损失、lpips 损失、对抗性损失训练本文解码器,并丢弃 KL 项,因为 Encoder 是固定的。Decoder 在 256px 的图像 patch 上进行训练,Batch Size 为 256,迭代次数为 1M。Tiny Decoder 实现了具有竞争力的 PSNR 分数,与传统的解码器 (例如,来自 SDXL 和 SD3 的解码器) 相比,在移动设备上高分辨率生成速度提高了 35.9 倍和 54.4 倍。
On-device 时延的讨论
作者最后测量了 iPhone 16 Pro-Max 上 10242px 生成的 T2I 模型时延。Decoder 需要 119ms,U-Net 的每步延迟为 274ms。这导致 4 到 8 步生成的运行时间为 1.2~2.3s。注意,与其他组件相比,文本编码器运行时间可以忽略不计。
1.4 训练配方以及多级知识蒸馏
为了提高高效扩散模型的生成质量,本文提出了一系列训练技术。
基于流的训练和推理
Rectified Flows (RFs)
[18][19]
将正向过程定义为将数据分布连接到标准正态分布的直线路径,即:
其中,
是干净(潜在)图像,
是时间步长,
是时间步长相关因子,
是从
中采样的随机噪声。去噪 U-Net 被制定为预测目标速度场为:
其中,
是 U-Net 的预测速度。为了进一步增强训练稳定性,作者在训练期间对时间步长应用 logit 正态采样,将更多的样本分配给中间步骤。在推理阶段,使用 Flow-Euler 采样器,它根据速度预测下一个样本,即:
为了在高分辨率 (即 1024px) 图像上实现较低的信噪比,作者应用了类似于 SD3 的时间步长移位来调整训练和推理过程中的调度因子
。
多级知识蒸馏
为了提高小模型的生成质量,一种常见做法是应用知识蒸馏。得益于对齐的流匹配目标和 (AE) 潜在空间,强大的SD3.5-Large
[20]
模型可以作为教师进行输出蒸馏。然而,由于 1) U-Net 和 DiT 之间的异构架构,2) 蒸馏损失和任务损失之间的尺度差异,以及 3) 不同时间步的不同预测难度,仍然面临挑战。为了解决这些问题,本文提出了一种新的多级蒸馏损失,以及时间步长感知缩放以稳定和加速蒸馏。本文知识蒸馏的方案概述如图 7 所示。
除了式2 中定义的任务损失外,知识蒸馏的主要目标是使用教师模型
的输出来监督小模型
,可以表示为:
鉴于教师和学生模型之间的容量差距,单独应用输出级监督会导致不稳定和收敛速度慢。因此,作者进一步做了特征蒸馏:
其中,
和
分别表示教师模型和学生模型中第
层和第
层的特征输出。与之前的工作[21] 不同,本文考虑从 DiT 到 UNet 的跨结构蒸馏。由于 Transformer 最丰富的信息位于最后一层,作者将蒸馏目标设置为两个模型中的这一层,并使用只有 2 个 Conv 层的轻量级可训练投影
来映射学生特征以匹配教师特征的维度。所提出的特征级蒸馏损失为学生模型提供了额外的监督,更快地对齐教师模型的生成质量。
图7:多级知识蒸馏概述,作者执行输出蒸馏和特征蒸馏
时间步长感知缩放
加权多个目标一直是知识蒸馏的主要挑战,尤其是在扩散模型中。之前的工作
[4][21]
的总体训练目标是多个损失项的简单线性组合,即:
其中,加权系数
和
根据经验设置为常数。但是,这个 Baseline 设置无法考虑不同时间步长的预测难度。作者研究了模型训练期间不同时间步
下
和
的幅度分布。可以发现在中间步骤中,与更接近 0 或 1 的
相比,预测难度较低。
图8:task loss 和 kd loss 的平均损失幅值
基于这一重要观察,作者提出了目标的时间步长感知缩放,以缩小不同
值的损失幅值的差距,并考虑每个时间步的预测困难,如下所示:
其中,
是标准归一化 logit-norm 密度函数,
表示幅值。在
中,首先确保不同
的任务损失和蒸馏损失之间的相同比例,然后预测难度更高(
更接近 0 或 1 )时使用更多的教师模型监督,预测难度更低(中间的时间步)时使用更多真实数据监督。这个方案考虑了时间步
的变化,有助于加速蒸馏训练。最终的多级蒸馏目标
可以定义为:
1.5 Step 蒸馏
本文通过基于分布匹配的 Step 蒸馏方案进一步提高模型的采样效率。借助 Latent Adversarial Diffusion Distillation (LADD)
[22]
的方案,作者使用 diffusion-GAN 混合结构蒸馏本文模型,使之变为更少的 steps,优化目标为:
其中,
是使用预训练的 Few-step 教师模型
(SD3.5-Large-Turbo
[8]
) 部分初始化的判别器模型。教师模型仅用作特征提取器,并在蒸馏过程中被冻结。在特征提取后,只训练判别器的几个线性层。
采样过程是
和