专栏名称: 极市平台
极市平台是由深圳极视角推出的专业的视觉算法开发与分发平台,为视觉开发者提供多领域实景训练数据库等开发工具和规模化销售渠道。本公众号将会分享视觉相关的技术资讯,行业动态,在线分享信息,线下活动等。 网站: http://cvmart.net/
目录
相关文章推荐
CareerIn投行PEVC求职  ·  说阿里投资DeepSeek的 都是在那YY呢 ·  2 天前  
人力资源管理  ·  单位里,和领导独处,少说这4种话:1、不要没 ... ·  2 天前  
HR成长社  ·  人事工作节点安排表.xls ·  2 天前  
51好读  ›  专栏  ›  极市平台

NeurIPS 2024 | 超越KL!大连理工提出WKD:基于WD距离的知识蒸馏新方法

极市平台  · 公众号  ·  · 2025-01-15 22:00

正文

↑ 点击 蓝字 关注极市平台
作者丨新智元
来源丨新智元
编辑丨极市平台

极市导读

大连理工大学的研究人员提出了一种基于Wasserstein距离的知识蒸馏方法,克服了传统KL散度在Logit和Feature知识迁移中的局限性,在图像分类和目标检测任务上表现更好。 >> 加入极市CV技术交流群,走在计算机视觉的最前沿

自Hinton等人的开创性工作以来,基于Kullback-Leibler散度(KL-Div)的知识蒸馏一直占主导地位。

然而,KL-Div仅比较教师和学生在相应类别上的概率,缺乏跨类别比较的机制,应用于中间层蒸馏时存在问题,其无法处理不重叠的分布且无法感知底层流形的几何结构。

为了解决这些问题,大连理工大学的研究人员提出了一种基于Wasserstein距离(WD)的知识蒸馏方法。所提出方法在图像分类和目标检测任务上均取得了当前最好的性能,论文已被NeurIPS 2024接受为Poster

论文地址: https://arxiv.org/abs/2412.08139

项目地址: https://peihuali.org/WKD/

代码地址:https://github.com/JiamingLv/WKD

背景与动机介绍

知识蒸馏(KD)旨在将具有大容量的高性能教师模型中的知识迁移到轻量级的学生模型中。近年来,知识蒸馏在深度学习中受到了越来越多的关注,并取得了显著进展,在视觉识别、目标检测等多个领域得到了广泛应用。在其开创性工作中,Hinton等人引入了Kullback-Leibler散度(KL-Div)用于知识蒸馏,约束学生模型的类别概率预测与教师模型相似。

从那时起,KL-Div在Logit蒸馏中占据主导地位,并且其变体方法DKD、NKD等也取得了令人瞩目的性能。此外,这些Logit蒸馏方法还可以与将知识从中间层传递的许多先进方法相互补充。

尽管KL-Div取得了巨大的成功,但它存在的两个缺点阻碍了教师模型知识的迁移。

首先,KL-Div仅比较教师和学生在相应类别上的概率,缺乏执行跨类别比较的机制。

然而,现实世界中的类别呈现不同程度的视觉相似性,例如,哺乳动物物种如狗和狼彼此间的相似度较高,而与汽车和自行车等人工制品则有很大的视觉差异,如图1所示。

不幸的是,由于KL-Div是类别对类别的比较,KD和其变体方法无法显式地利用这种丰富的跨类别知识。

图1 左图使用t-SNE展示了100个类别的嵌入分布。可以看出,这些类别在特征空间中表现出丰富的相互关系 (IR)。然而,右图中的KL散度无法显式地利用这些相互关系

其次,KL-Div在用于从中间层特征进行知识蒸馏时存在局限性。图像的深度特征通常是高维的且空间尺寸较小,因此其在特征空间中非常稀疏,不仅使得KL-Div在处理深度神经网络特征的分布时存在困难。

KL-Div无法处理不重叠的离散分布,并且由于其不是一个度量,在处理连续分布时能力有限,无法感知底层流形的几何结构。

图2 基于Wasserstein距离(WD)的知识蒸馏方法的总览图

为了解决这些问题,研究人员提出了一种基于Wasserstein距离的知识蒸馏方法,称为WKD,同时适用于Logit蒸馏(WKD-L)和Feature蒸馏(WKD-F),如图2所示。

在WKD-L中,通过离散WD最小化教师和学生之间预测概率的差异,从而进行知识转移。

通过这种方式,执行跨类别的比较,能够有效地利用类别间的相互关系(IRs),与KL-Div中的类别间比较形成鲜明对比。

对于WKD-F,研究人员利用WD从中间层特征中蒸馏知识,选择参数化方法来建模特征的分布,并让学生直接匹配教师的特征分布。

具体来说,利用一种最广泛使用的连续分布(高斯分布),该分布在给定特征的1阶和2阶矩的情况下具有最大熵。

论文的主要贡献可以总结如下:

  1. 提出了一种基于离散WD的Logit蒸馏方法(WKD-L),可以通过教师和学生预测概率之间的跨类别比较,利用类别间丰富的相互关系,克服KL-Div无法进行类别间比较的缺点。

  2. 将连续WD引入中间层进行Feature蒸馏(WKD-F),可以有效地利用高斯分布的Riemann空间几何结构,优于无法感知几何结构的KL-Div。

  3. 在图像分类和目标检测任务中,WKD-L优于非常强的基于KL-Div的Logit蒸馏方法,而WKD-F在特征蒸馏中优于KL-Div的对比方法和最先进的方法。WKD-L和WKD-F的结合可以进一步提高性能。

用于知识迁移的WD距离

用于Logit蒸馏的离散WD距离

类别之间的相互关系(IRs)

如图1所示,现实世界中的类别在特征空间中表现出复杂的拓扑关系。相同类别的特征会聚集并形成一个分布,而相邻类别的特征有重叠且不能完全分离。

因此,研究人员提出基于CKA量化类别间的相互关系(IRs),CKA是一种归一化的Hilbert-Schmidt独立性准则(HSIC),通过将两个特征集映射到再生核希尔伯特空间(RKHS)来建模统计关系。

首先将每个类别中所有训练样本的特征构成一个特征矩阵,之后通过计算任意两个类别特征矩阵之间的CKA得到类间相互关系(IR)。计算IR的成本可以忽略,因为在训练前仅需计算一次。

由于教师模型通常包含更丰富的知识,因此使用教师模型来计算类别间的相互关系图片

损失函数

分别表示教师模型和学生模型的预测类别概率,其通过softmax函数和温度对Logit计算得到。将离散的WD表示为一种熵正则化的线性规划:

其中 分别表示每单位质量的运输成本和在将概率质量从 移动到

时的运输量; 是正则化参数。

定义运输成本 与相似度度量 成负相关。

因此,WKD-L的损失函数可以定义为:

用于Feature蒸馏的连续WD距离

特征分布建模

将模型某个中间层输出的特征图重塑为一个矩阵,其中第 i 列 表示一个空间特征。

之后,估计这些特征的一阶矩 和二阶矩 ,并将二者作为高斯分布的参数来建模输入图像特征的分布。

损失函数

设教师的特征分布为高斯分布

设教师的特征分布为高斯分布 类似地,学生的分布记为

两者之间的连续Wasserstein距离(WD)定义为:

其中, 是高斯变量,q表示联合分布。 最小化上式可以得到闭集形式的WD距离。







请到「今天看啥」查看全文