导读
本文是VCC刘烨同学对论文 Improving Visual Prompt Tuning by Gaussian Neighborhood Minimization for Long-Tailed Visual Recognition 的解读,该工作来自深圳大学可视计算研究中心及光明实验室黄惠教授课题组,和厦门大学、广东工业大学及香港浸会大学联合研究,已被机器学习顶级会议 NeurIPS 2024 录用,同时获得中国发明专利授权和软件著作权登记。
https://vcc.tech/research/2024/GNM-PT
该工作提出了
一种针对长尾问题的训练优化策略
,旨在平衡地提升视觉提示词微调对各个类别的泛化能力。
此训练优化策略新提出的基于高斯邻域最小化的损失,能够帮助模型在长尾数据上训练时收敛到更平坦的损失极小值点,平衡地提升模型对头类和尾类的泛化能力,并且几乎不引入额外的计算代价。
大量实验证明,提出的高斯邻域最小化方法能够使得模型在长尾分布数据上的损失平面更加平坦,且几乎不增加额外的计算开销。
该方法有效平衡了模型对头类和尾类的泛化能力,并在多个长尾任务中展现出卓越的性能和效率优势
。
从真实世界中采集的数据通常呈现长尾分布,其中少数类别(头部类)拥有丰富的样本,而大量类别(尾类)则仅占据极少的样本。
这种不平衡的分布对深度学习模型的训练构成了严重障碍。
因此近年来,长尾视觉识别问题引起了广泛关注,并促使研究者提出了许多有效的解决方案。
大多数现有方法集中于从头开始训练模型,主要从数据处理、表征能力提升和模型输出修正等角度着手,试图缓解长尾问题。
近期,一些研究开始探索在微调预训练模型的基础上进行长尾视觉识别的改进[1]这些方法借助参数有效微调 (PEFT) 技术和更具鲁棒性的预训练模型,取得了良好的性能。
然而,即使引入了大规模预训练知识,使用视觉提示词微调 (VPT) [2]等PEFT技术时,模型在尾类上的泛化能力依然远逊于头类。
Sharpness-Aware Minimization (SAM) [3]优化器能够使模型在训练过程中收敛到平坦的损失极小值点,从而提高其泛化能力。
然而,在长尾数据上应用SAM时,模型优化通常由头类主导,忽略了尾类的贡献。
此外,SAM需要计算两次梯度,带来了额外的计算代价。
因此,迫切需要一种能够提升模型在长尾数据上泛化能力且计算高效的方法。
本论文提出了一种针对长尾数据分布提升VPT泛化能力的新方法 — Gaussian Neighborhood Minimization Prompt Tuning (GNM-PT)。该方法的核心原理基于Sharpness-Aware Minimization (SAM),通过使损失平面更加平坦来增强模型的泛化能力。SAM优化器在训练模型时,通过最小化当前参数邻域内的最大损失值,使得模型极小值点附近的损失平面更平坦。然而,由于长尾数据中大量头类样本的主导,SAM优化策略使得修正后的梯度方向更偏向于优化头类。为了解决这一问题,GNM-PT提出了一种新的优化策略
—
Gaussian neighborhood minimization (GNM)。与SAM不同,GNM在优化过程中仅需要计算一次梯度,避免了额外的计算开销。通过最小化损失平面中高斯邻域内采样点的损失,GNM能够使模型收敛到一个平坦且不受头类主导的损失极小值点,从而平衡地提升模型对所有类别的泛化能力。此外,GNM-PT还进一步利用提示词中的信息,增强了分类器的鲁棒性。图1展示了损失平面[4]的可视化结果。图1 (a) 表明,GNM与SAM在效果上相似,能够使损失极小值点附近的损失平面更平坦。图1 (b) 表明,在长尾分布数据上,GNM凸性更好,进一步提高了模型的泛化能力。
GNM-PT方法使用VPT微调预训练的ViT模型[5],利用GNM优化器更新VPT的参数,使模型参数收敛到平坦且不受头类主导的损失极小值点,提升泛化能力;另外,将高水平提示词中的信息融合进ViT最后输出的特征中,增强分类性能。
GNM在长尾分布数据上优化模型时,通过最小化损失平面中高斯邻域内采样点的损失来更新模型参数。首先,从正态分布中采样出一个随机向量
并利用
和高斯邻域半径
生成高斯扰动
如公式所示:
之后,利用
在长尾分布上计算出当前参数高斯邻域内采样点的损失,并最小化该损失,更新模型参数,其过程如公式所示:
按照上述GNM的优化步骤进行训练,最终参数便可收敛到一个平坦且不受头类主导的损失极小值点。整个GNM优化器的参数更新示意图如图2所示:
(
和
分别表示在
轮参数更新时未使用和使用GNM情况下的梯度更新)
使用VPT微调预训练ViT模型时,提示词中也编码了大量与当前任务相关的信息。为了进一步提升分类性能,GNM-PT按照下面公式所示的方式进行提示词信息融合:
将最后一层Transformer block的提示词信息
融合进最后一层输出的
中得到
作为ViT最终输出的特征,再将
送入分类器中进行分类。
为了证明GNM在长尾分布上的优势,我们分别使用SAM和GNM两种优化器,在长尾分布数据上利用GCL[6]损失函数训练模型,并可视化两者的损失平面,结果如图3所示。
可以看出,GNM能使模型得到更小的损失值,且损失平面几乎没有波动,有助于提高模型泛化能力。
为了验证GNM方式在分类精度和计算效率上的优势,我们在保证其他设置相同的情况下分别使用SAM和GNM两种优化器进行训练,对比优化器的执行时间和得到的模型的分类精度。表1中的结果表明,SAM的计算时间比基线方法超出了1.8倍。相比之下,GNM只增加了不到两秒的计算时间,几乎可忽略不计,同时还能够提升分类精度。
图4中统计了SAM和GNM两种优化器分别对不同类的精度影响。
可以看出,在使用GCL损失时,SAM降低了模型对于尾类的性能,而论文提出GNM平衡地提升了模型对所有类的性能,更适合解决长尾问题。
为了展示GNM-PT方法与其他先进的长尾视觉识别算法的性能对比,我们在常见的长尾数据集上进行实验,其结果如表2-4所示。GNM-PT在各个数据集上均展现出了较好的分类性能。
表2 在CIFAR100-LT上的top-1分类精度 (%)