其中
ϵ
θ
是学习到的噪声预测模型,
α
t
是步长,
σ
t
是噪声方差,
𝐳
∼
𝒩
(
0
,
I
)
。
这种基于检索增强的扩散过程确保生成的图像既保持高保真度又具有事实准确性。
最近的方法,例如基于检索增强的扩散模型(RDM)
[27]
和kNN-Diffusion
[29]
,已经证明了这种方法的有效性,显著提高了生成图像的真实感和上下文一致性。
图3:
Magic 1-For-1的整体架构。
图像先验注入和多模态引导设计
图像到视频 (I2V) 任务涉及使用输入图像作为第一帧生成与给定文本描述一致的视频。
具体而言,T2V模型将Shape
T
×
C
×
H
×
W
的潜在张量
X
作为输入,其中
T
,
C
,
H
T4>和
W
分别对应于压缩视频的帧,通道,高度和宽度。
与 Emu Video
[10]
类似,为了结合图像条件
I
,我们将
I
作为视频的第一帧,并应用零填充来构建维度为
T
×
C
×
H
×
W
的张量
I
o
,如图
3
所示。
此外,我们引入了一个形状为
T
×
1
×
H
×
W
的二元掩码
m
,其中第一个时间位置设置为 1,所有后续位置设置为零。
然后将潜在张量
X
、填充张量
I
o
和掩码
m
沿通道维度连接起来,形成模型的输入。
如图
3
所示,由于输入张量的通道维度从
C
增加到
2
C
+
1
,我们调整模型第一个卷积模块的参数从
ϕ
=
(
C
i
n
(
=
C
)
,
C
o
u
t
,
s
h
,
s
w
)
到
ϕ
′
=
(
C
i
n
′
(
=
2
C
+
1
)
,
C
o
u
t
,
s
h
,
s
w
)
。
这里,
C
i
n
和
C
i
n
′
分别表示修改前后的输入通道数,
C
o
u
t
是输出通道数,
s
h
和
s
w
分别对应卷积核的高度和宽度。
为了保持 T2V 模型的表示能力,
ϕ
′
的前
C
个输入通道从
ϕ
复制,而其他通道初始化为零。
I2V 模型在与 T2V 模型相同的的数据集上进行预训练,以确保一致性。
扩散模型推理的迭代性质,其特征在于其多步采样过程,给推理速度带来了显著瓶颈。
在大型模型(例如我们的 130 亿参数扩散模型 Magic 1-For-1)中,这个问题尤其严重,因为每个单独采样步骤的计算成本很高。
如图
4
所示,我们通过实现一种双重蒸馏方法来应对这一挑战,该方法结合了步骤蒸馏和CFG蒸馏以实现更快的采样。
对于步骤蒸馏,我们利用DMD2,这是一种针对高效分布对齐和加速采样而设计的最新算法。
受分数蒸馏采样(SDS)
[25]
的启发,DMD2通过一个涉及三个不同模型的协调训练范式来促进步骤蒸馏。
这些模型包括:一/四步生成器
G
ϕ
,其参数被迭代优化;真实视频模型
u
θ
real
,其任务是逼近底层真实数据分布
p
real
;以及伪造视频模型
u
θ
fake
,其估计生成的(伪造)数据分布
p
fake
。
至关重要的是,所有三个模型都从同一个预训练模型初始化,确保一致性并简化训练过程。
步骤蒸馏的分布匹配目标可以用数学表达式表示为:
这里,
𝐳
t
代表时间步
t
处的视频潜在变量,
z
t
=
σ
t
z
1
+
(
1
−
σ
t
)
z
^
0
,其中
z
^
0
表示由几步生成器合成的输出,
σ
t
表示噪声调度。
此公式将传统的基于分数函数的分布匹配(标准DMD2中固有)重新构建为一种新颖的方法,该方法侧重于时间步
t
=
0
处的分布对齐。
此调整对于确保与Magic 1-For-1中使用的训练方法一致至关重要。
此外,DMD2需要实时更新
u
θ
fake
以保证对伪造数据分布
p
fake
的准确逼近。
此更新由以下损失函数控制:
在扩散模型的推理阶段,无分类器扩散引导 (CFG)
[14, 6]
经常在每个采样步骤中使用。
CFG 通过在 dropout 条件下执行额外的计算,提高了生成结果相对于指定条件的保真度。
为了消除这种计算开销并提高推理速度,我们实现了 CFG 蒸馏
[21]
。
我们定义了一个蒸馏目标,训练学生模型
u
θ
s
直接生成引导输出。
特别地,我们将以下关于时间步长和引导强度的期望最小化:
其中
表示条件输出和无条件输出之间线性插值的预测。
T
s
代表文本提示。
在这种公式中,
p
w
(
w
)
=
𝒰
[
w
min
,
w
max
]
表示引导强度参数在训练期间均匀采样,这使蒸馏模型能够有效地处理各种引导尺度,而无需重新训练。
为了整合引导权重
w
,我们将其作为额外输入提供给我们的学生模型。
此蒸馏过程有效地将传统的 CFG 计算压缩为单个简化的前向传递。
我们将整体蒸馏目标
L
distillation
构建为两个损失项的加权和。
CFG 蒸馏损失用于使学生模型的输出与教师的引导预测对齐,而基础预测损失则确保学生保持教师的底层生成能力。
因此,完整的蒸馏损失由下式给出:
其中
w
bf16
表示原始 bfloat16 权重,
w
i
n
t
8
表示量化后的 int8 权重,
max
(
|
w
bf16
|
)
表示权重张量中的最大绝对值。
在实践中,可以使用更复杂的方法,例如每通道量化或量化意识培训,以提高性能。
为减轻潜在的 CUDA 错误并在推理过程中确保数值稳定性,模型中的所有线性层在使用量化后的 int8 权重进行矩阵乘法之前,会先将其输入转换为 bfloat16。
这种 bfloat16-int8 乘法有助于保持精度,同时仍然受益于 int8 权重减少的内存占用。