文章链接:
https://arxiv.org/pdf/2409.06633
项目链接:
https://sjtuplayer.github.io/projects/SaRA/
1.引言
SaRA
是一种针对预训练扩散模型的高效微调方法。通过微调预训练扩散模型中的无效参数,赋予模型对下游任务的处理能力。SaRA能够显著节省计算显存开销与代码复杂度,仅修改一行训练代码即可实现微调过程。该方法的核心创新在于:
-
参数重要性分析
:SaRA首先对预训练模型中的参数重要性进行分析,发现预训练扩散模型中绝对值最小的10%至20%的参数在生成过程中的作用微乎其微。并且这些参数的无效性并非模型固有属性,而是由于训练过程中的不稳定性导致。
-
稀疏低秩训练
:基于上述发现,SaRA提出利用这些暂时无效的参数,通过优化稀疏权重矩阵来学习特定任务的知识。为了避免过拟合,SaRA采用了基于核范数的低秩稀疏训练方案,有效约束了学习过程中的参数秩。
-
渐进式参数调整策略
:SaRA设计了一种参数重调整策略,通过在微调过程中重定向可训练参数,确保几乎所有参数都能有效地贡献于新任务的学习。
-
非结构化反向传播策略
:SaRA提出了一种新颖的反向传播策略,显著降低了微调过程中的内存成本。
SaRA在多个下游任务上进行了广泛的实验验证,包括基模型能力提升、下游数据微调、图像定制化、可控视频生成等。实验结果表明SaRA不仅能够提升基础模型在原始任务的生成能力,在下游任务中,能兼顾下游任务的学习以及预训练先验的维护,实现优越的模型微调效果。
2. 参数重要性分析
2.1 预训练模型中的无效参数
在深度学习模型中,参数的数量往往非常庞大,但根据模型剪枝理论,并非所有参数都对模型的输出有积极的影响。作者首先研究了多个版本的预训练Stable Diffusion(包括1.4,1.5,2.0,与3.0)中,绝对值权重较小的参数对生成结果的影响。通过将绝对值权重小于
的参数置为0后,让模型根据GPT-4o生成的1000个文本,生成对应的1000张图像,计算生成图像的FID指标以及CLIP Score(如图图 1所示),发现将模型10%~20%的参数置为0时,模型的生成结果并没有受到影响,甚至有些情况下还能略微提升,证明了这些小参数在预训练Stable Diffusion模型中的无效性。
图 1:Stable Diffusion预训练模型参数分布与小参数对生成结果的影响
2.2 无效参数的潜在有效性
2.1中导致无效参数的原因可能有两个:一是由于模型结构设计的原因,这些参数天生就是冗余、无效的参数,因此无法在训练过程中起到作用,另外一个原因则可能是由于模型训练过程中的随机性,导致这些参数恰好在训练结束的时候停留在0附近。因此,作者进一步对参数无效的原因展开研究。选取了Stable Diffusion在FFHQ的预训练模型,标记了初始权重最小的1%参数,将该模型继续在FFHQ上训练,并在训练过程中实时跟踪这1%参数的变化,结果如图 2所示,可见,随着训练的进行,初始的小参数(蓝色线条)逐渐跳出了1%的阈值,而初始大于1%阈值的参数,大部分跌入了1%以内,并且小于该阈值的参数总量始终维持在1%左右,证明了在训练过程中,所有参数都以一定的概率跳出或者跌入1%阈值中,说明初始的小参数是由训练过程的随机性导致的,因此,可以在微调过程中利用这些暂时无效的参数,赋予模型针对下游任务的生成能力。
图 2:训练过程中权重绝对值小于初始1%阈值θ_t的参数分布变化
3. 方法介绍
为了充分利用这些暂时无效的参数,SaRA提出了一种渐进式稀疏低秩训练方法。该方法的核心思想是,通过对这些无效参数进行微调,使其在下游任务中发挥作用。具体来说,SaRA首先确定一个阈值
,将小于该阈值的参数视为暂时无效的参数。然后,通过优化一个稀疏权重矩阵,使得这些参数能够学习到新任务的知识。为了避免在训练过程中出现过拟合,SaRA引入了基于核范数的低秩约束。核范数是一种用于估计矩阵秩的凸松弛,通过最小化核范数,可以有效地限制稀疏矩阵的秩,从而避免模型在训练过程中学习到过多的噪声信息。
3.1 稀疏矩阵训练
SaRA致力于微调预训练模型中暂时无效的参数(即权重绝对值小于一定阈值
的参数),使预训练的扩散模型适应下游任务,同时尽可能地保留原始的模型先验。具体而言,首先为初始参数P计算一个稀疏掩码M,满足:
SaRA基于该稀疏掩码来更新初始无效的参数
, 同时保持初始有效参数
不变。在训练期间, 对于所有参数的梯度
, 利用该稀疏掩码
来保留所需要更新参数的梯度, 并更新相应的可训练参数
3.2 基于核范数的低秩约束
稀疏参数矩阵可能会具有较高的秩,从而导致过于强大的表征能力,使得模型在下游任务训练过程中出现过拟合的问题。为了缓解这个问题,我们在稀疏矩阵上引入了基于核范数的低秩约束,以防止约束系数矩阵的秩。低秩约束的一种直接方法是最小化稀疏参数矩阵的秩Rank(P_M )。然而,由于其矩阵秩求解的非凸性,
难以实现可微的求解。因此,我们使用矩阵的核范数来估计系数参数矩阵的秩:
其中
是参数矩阵
的第 i 个奇异值, 通过最小化该式子, 可以实现对参数矩阵的低质约束。
为了计算核范数
, 对参数矩阵进行奇异值分解
, 其中 U 和 V 是正交矩阵,
是包含奇异值
的对角矩阵。
关于
的次梯度可以由下式求出:
基于核范数梯度的推导, 可以低质约束的可微性, 从而实现基于核范数的低秩约束损失:
3.3 渐进式参数调整
在模型的微调过程中,由于训练的随机性,仍然会存在部分参数停留在阈值以下,尤其是微调过程的总轮次往往较少,导致最终存在一部分的参数仍然无效。如图 2 所示,初始的小参数在训练初期会快速跳出阈值,而后期的趋势逐渐放缓,当微调轮次较少时,可训练参数中可能存在15%的参数仍然无效。因此,SaRA提出渐进式的参数调整策略,在微调的前半阶段,首先对初始的无效参数进行训练,使得大部分的无效参数跳出阈值,而在后半阶段,再次对剩余的无效参数进行训练,使其快速跳出阈值。通过这种分阶段的渐进式训练策略,SaRA可以更有效地利用这些无效参数,提高模型在新任务上的性能。
3.4 非结构化反向传播策略
目前,基于 LoRA 的方法和参数选择的高效微调方法都对计算资源造成了沉重的负担:1)对于基于 LoRA 的方法,由于 LoRA 模块的引入,它不需要存储模型参数的梯度,但却需要额外的内存成本来存储 LoRA 模块中的中间变量,如图3 (a) 所示。2)对于参数选择的方法,一个一直困扰的问题是:它们需要与全参数微调相同甚至更多的计算资源(尤其是 GPU 显存)。虽然它们只对模型参数的子集进行微调,但它们保留了整个参数矩阵P的梯度,因为主流的深度学习库(如PyTorch和TensorFlow)只支持对整个参数矩阵的梯度反向传播。因此,以往的基于参数选择的方法必须在整个参数矩阵P上执行梯度反向传播,然后使用预先计算的掩码矩阵
通过
屏蔽不需要的参数梯度,并通过
实现整体参数的更新(如图4(b) 所示)。这种方法需要存储所有模型参数的梯度和额外的掩码矩阵,导致比全参数微调更大的计算资源需求。因此,为了解决这些问题,SaRA提出了非结构化梯度回传策略, 通过将可训练参数从参数矩阵中剥离与非结构化前向和反向传播, 实现训练过程中显存的大幅降低。
具体地, SaRA首先将训练模型的所有参数变为非叶节点, 并通过系数矩阵
, 获取可学习参数
, 将可学习参数作为真实的叶节点。定义非结构化映射函数
, 在前向过程中将可学习参数映射到模型参数中:
在反向过程中,定义非结构化反向传播函数
,将模型参数的梯度自动回传至可训练参数:
由于模型参数成为了非叶子节点,