专栏名称: 极市平台
极市平台是由深圳极视角推出的专业的视觉算法开发与分发平台,为视觉开发者提供多领域实景训练数据库等开发工具和规模化销售渠道。本公众号将会分享视觉相关的技术资讯,行业动态,在线分享信息,线下活动等。 网站: http://cvmart.net/
目录
51好读  ›  专栏  ›  极市平台

修改一行代码就能实现高效微调!上海交大 & 腾讯开源SaRA:兼顾原始生成和下游任务

极市平台  · 公众号  · 科技自媒体  · 2024-09-19 22:00

正文

↑ 点击 蓝字 关注极市平台
作者丨AI生成未来
来源丨AI生成未来
编辑丨极市平台

极市导读

仅修改一行训练代码即可实现微调过程。 >> 加入极市CV技术交流群,走在计算机视觉的最前沿

文章链接: https://arxiv.org/pdf/2409.06633

项目链接: https://sjtuplayer.github.io/projects/SaRA/

1.引言

SaRA 是一种针对预训练扩散模型的高效微调方法。通过微调预训练扩散模型中的无效参数,赋予模型对下游任务的处理能力。SaRA能够显著节省计算显存开销与代码复杂度,仅修改一行训练代码即可实现微调过程。该方法的核心创新在于:

  • 参数重要性分析 :SaRA首先对预训练模型中的参数重要性进行分析,发现预训练扩散模型中绝对值最小的10%至20%的参数在生成过程中的作用微乎其微。并且这些参数的无效性并非模型固有属性,而是由于训练过程中的不稳定性导致。
  • 稀疏低秩训练 :基于上述发现,SaRA提出利用这些暂时无效的参数,通过优化稀疏权重矩阵来学习特定任务的知识。为了避免过拟合,SaRA采用了基于核范数的低秩稀疏训练方案,有效约束了学习过程中的参数秩。
  • 渐进式参数调整策略 :SaRA设计了一种参数重调整策略,通过在微调过程中重定向可训练参数,确保几乎所有参数都能有效地贡献于新任务的学习。
  • 非结构化反向传播策略 :SaRA提出了一种新颖的反向传播策略,显著降低了微调过程中的内存成本。

SaRA在多个下游任务上进行了广泛的实验验证,包括基模型能力提升、下游数据微调、图像定制化、可控视频生成等。实验结果表明SaRA不仅能够提升基础模型在原始任务的生成能力,在下游任务中,能兼顾下游任务的学习以及预训练先验的维护,实现优越的模型微调效果。

2. 参数重要性分析

2.1 预训练模型中的无效参数

在深度学习模型中,参数的数量往往非常庞大,但根据模型剪枝理论,并非所有参数都对模型的输出有积极的影响。作者首先研究了多个版本的预训练Stable Diffusion(包括1.4,1.5,2.0,与3.0)中,绝对值权重较小的参数对生成结果的影响。通过将绝对值权重小于 的参数置为0后,让模型根据GPT-4o生成的1000个文本,生成对应的1000张图像,计算生成图像的FID指标以及CLIP Score(如图图 1所示),发现将模型10%~20%的参数置为0时,模型的生成结果并没有受到影响,甚至有些情况下还能略微提升,证明了这些小参数在预训练Stable Diffusion模型中的无效性。

图 1:Stable Diffusion预训练模型参数分布与小参数对生成结果的影响

2.2 无效参数的潜在有效性

2.1中导致无效参数的原因可能有两个:一是由于模型结构设计的原因,这些参数天生就是冗余、无效的参数,因此无法在训练过程中起到作用,另外一个原因则可能是由于模型训练过程中的随机性,导致这些参数恰好在训练结束的时候停留在0附近。因此,作者进一步对参数无效的原因展开研究。选取了Stable Diffusion在FFHQ的预训练模型,标记了初始权重最小的1%参数,将该模型继续在FFHQ上训练,并在训练过程中实时跟踪这1%参数的变化,结果如图 2所示,可见,随着训练的进行,初始的小参数(蓝色线条)逐渐跳出了1%的阈值,而初始大于1%阈值的参数,大部分跌入了1%以内,并且小于该阈值的参数总量始终维持在1%左右,证明了在训练过程中,所有参数都以一定的概率跳出或者跌入1%阈值中,说明初始的小参数是由训练过程的随机性导致的,因此,可以在微调过程中利用这些暂时无效的参数,赋予模型针对下游任务的生成能力。

图 2:训练过程中权重绝对值小于初始1%阈值θ_t的参数分布变化

3. 方法介绍

为了充分利用这些暂时无效的参数,SaRA提出了一种渐进式稀疏低秩训练方法。该方法的核心思想是,通过对这些无效参数进行微调,使其在下游任务中发挥作用。具体来说,SaRA首先确定一个阈值 ,将小于该阈值的参数视为暂时无效的参数。然后,通过优化一个稀疏权重矩阵,使得这些参数能够学习到新任务的知识。为了避免在训练过程中出现过拟合,SaRA引入了基于核范数的低秩约束。核范数是一种用于估计矩阵秩的凸松弛,通过最小化核范数,可以有效地限制稀疏矩阵的秩,从而避免模型在训练过程中学习到过多的噪声信息。

3.1 稀疏矩阵训练

SaRA致力于微调预训练模型中暂时无效的参数(即权重绝对值小于一定阈值 的参数),使预训练的扩散模型适应下游任务,同时尽可能地保留原始的模型先验。具体而言,首先为初始参数P计算一个稀疏掩码M,满足:

SaRA基于该稀疏掩码来更新初始无效的参数 , 同时保持初始有效参数 不变。在训练期间, 对于所有参数的梯度 , 利用该稀疏掩码 来保留所需要更新参数的梯度, 并更新相应的可训练参数

3.2 基于核范数的低秩约束

稀疏参数矩阵可能会具有较高的秩,从而导致过于强大的表征能力,使得模型在下游任务训练过程中出现过拟合的问题。为了缓解这个问题,我们在稀疏矩阵上引入了基于核范数的低秩约束,以防止约束系数矩阵的秩。低秩约束的一种直接方法是最小化稀疏参数矩阵的秩Rank(P_M )。然而,由于其矩阵秩求解的非凸性, 难以实现可微的求解。因此,我们使用矩阵的核范数来估计系数参数矩阵的秩:

其中 是参数矩阵 的第 i 个奇异值, 通过最小化该式子, 可以实现对参数矩阵的低质约束。

为了计算核范数 , 对参数矩阵进行奇异值分解 , 其中 U 和 V 是正交矩阵, 是包含奇异值 的对角矩阵。 关于 的次梯度可以由下式求出:

基于核范数梯度的推导, 可以低质约束的可微性, 从而实现基于核范数的低秩约束损失:

3.3 渐进式参数调整

在模型的微调过程中,由于训练的随机性,仍然会存在部分参数停留在阈值以下,尤其是微调过程的总轮次往往较少,导致最终存在一部分的参数仍然无效。如图 2 所示,初始的小参数在训练初期会快速跳出阈值,而后期的趋势逐渐放缓,当微调轮次较少时,可训练参数中可能存在15%的参数仍然无效。因此,SaRA提出渐进式的参数调整策略,在微调的前半阶段,首先对初始的无效参数进行训练,使得大部分的无效参数跳出阈值,而在后半阶段,再次对剩余的无效参数进行训练,使其快速跳出阈值。通过这种分阶段的渐进式训练策略,SaRA可以更有效地利用这些无效参数,提高模型在新任务上的性能。

3.4 非结构化反向传播策略

目前,基于 LoRA 的方法和参数选择的高效微调方法都对计算资源造成了沉重的负担:1)对于基于 LoRA 的方法,由于 LoRA 模块的引入,它不需要存储模型参数的梯度,但却需要额外的内存成本来存储 LoRA 模块中的中间变量,如图3 (a) 所示。2)对于参数选择的方法,一个一直困扰的问题是:它们需要与全参数微调相同甚至更多的计算资源(尤其是 GPU 显存)。虽然它们只对模型参数的子集进行微调,但它们保留了整个参数矩阵P的梯度,因为主流的深度学习库(如PyTorch和TensorFlow)只支持对整个参数矩阵的梯度反向传播。因此,以往的基于参数选择的方法必须在整个参数矩阵P上执行梯度反向传播,然后使用预先计算的掩码矩阵 通过 屏蔽不需要的参数梯度,并通过 实现整体参数的更新(如图4(b) 所示)。这种方法需要存储所有模型参数的梯度和额外的掩码矩阵,导致比全参数微调更大的计算资源需求。因此,为了解决这些问题,SaRA提出了非结构化梯度回传策略, 通过将可训练参数从参数矩阵中剥离与非结构化前向和反向传播, 实现训练过程中显存的大幅降低。

具体地, SaRA首先将训练模型的所有参数变为非叶节点, 并通过系数矩阵 , 获取可学习参数 , 将可学习参数作为真实的叶节点。定义非结构化映射函数 , 在前向过程中将可学习参数映射到模型参数中:

在反向过程中,定义非结构化反向传播函数 ,将模型参数的梯度自动回传至可训练参数:

由于模型参数成为了非叶子节点,







请到「今天看啥」查看全文


推荐文章
果迷之家  ·  iOS 10.3 正式发布:你需要知道这些
7 年前
房地产投资融资俱乐部  ·  木子美-买房就跟约炮一样,是刚需?
7 年前