论文
:
https://arxiv.org/pdf/2403.01427.pdf
代码
:
https://github.com/sunshangquan/logit-standardization-KD
代码已开源,欢迎star :)
0. 背景介绍
什么是知识蒸馏?2015年,Hinton[1]注意到深度学习模型变得越来越大,率先想到是否可以利用一个训练好的大模型(俗称Teacher、教师模型),教授一个小模型(俗称Student、学生模型)进行学习。
以常见的分类问题举例, 给定一个包含
个样本的图像分类数据集
是其中第
个样本图像,
是
对应的标签(数据集如果有
个类, 则
为 1 至
之间的一个整数, 代表图像属于第
个类), 学生网络
和教师网络
会读取一张输入图像
, 输出各自的logit:
我们用带有温度
的softmax函数, 就可以将logit转换为概率密度的形式:
其中
和
分别是学生和教师预测的概率密度, 其满足
和
。
随后,就可以用KL散度将学生的输出和教师的输出进行定量对比,以此作为损失函数对学生网络进行优化:整个过程可以看下面的图1a。
知识蒸馏经典工作解读:https://zhuanlan.zhihu.com/p/102038521
1. 动机
距离Hinton[1]2015年提出知识蒸馏已经过去了9年,温度这个超参数最早就被设定为教师、学生之间共享的,并且是对所有样本都全局不变的,而这样的设置并没有理论支持。已有工作中,CTKD[2]引入了对抗学习,针对不同难度的样本选择不一样的温度,但是它仍然让学生和教师共享温度;ATKD[3]则是引入一种锐利度指标,针对性地选择学生和教师之间的温度平衡。但是他们均没有讨论学生和教师的温度来自于哪里,也没有理论性地讨论它们是否可以全局不一致、学生教师之间不共享。因此针对这个问题,文章做出了三个贡献:
-
文章基于信息论中的熵最大化理论,推导了含有超参数温度的softmax函数表达式。基于推导过程,发现温度并没有明显的约束条件,即没有理论强制学生和教师在全局范围内共享温度(见1.1.1、1.1.2)
-
-
学生网络被迫输出和教师网络相当的logit(见1.2)
-
-
文章提出了logit标准化,可作为预处理辅助现有基于logit的知识蒸馏算法。
图1. logit标准差辅助的知识蒸馏和传统知识蒸馏的对比
1.1 超参数-温度的来源
这部分推导了带有超参数温度的softmax函数,如果只想看结论,可直接跳到1.3小节
1.1.1 教师网络的温度
基于 Edwin[5] 1957年提出的最大熵理论 (Maximum-Entropy Principle), 可以证明出分关任务的softmax函数是求解条件摘最大化问题的唯一解形式, 也就是说, 对于一个训绕好的教师网婉, 面对以下两个条件, 其概率密度
应该使得熵处于最大值, 两个条件分别为
-
变量向量
需要求和为 1 。也就是说, 它满足概率密度的形式
-
原始logit向量
的期望为正确的logit项。也就是说, 它需要预测正确
其数学表达为以下式子:
针对这个求解问题, 可以利用拉格朗日乘子法, 引入拉格朗日乘子
(条件1) 和
(条件 2), 将条件优化变为单一表达式:
对目标函数求导后取零, 即可得到优化问题的解的形式。于是我们对
求偏导, 得到:
对其取 0 ,我们得到
此处的
就变成了我们常见的softmax的分母,公式(7)就变为了我们常见的softmax函数。而
就是我们常见的温度变量, 它的值取 1 时, 就是分类任务中不带有温度的KL散度函数。
1.1.2 学生网络的温度
与上一小节类似,我们针对蒸馏任务,引入一个新的约束条件:
-
学生logit向量的期望需要等于教师logit的期望。也就是说,需要学生学习到教师的知识
加入这第三个约束的求解问题的数学表达为:
类似地引入拉格朗日乘子
(条件1)、
(条件2) 、
(条件3), 可以将条件优化变为单一表达式:
对
求偏导, 得到:
对其取 0 , 并且设
, 我们得到
其中为了简洁, 分母为
。公式(11)变为了我们常见的softmax函数。而
就是我们常见的温度变量, 它与
取等时, 就是蒸馏任务中最常见的学生、教师共享温度的情况。
讨论:
问题1: 学生和教师之间是否可以取不同的温度?
答: 如果对公式(5)和公式(8)分别对 和 求偏导,则其偏导表达式均会退回到对应的条件约束表达式,表达式恒成立,其结果与这三个变量 和 也因此无关,所以其取值井没有明显的约束形式。如果我们取 ,就是我们常见知识蒸馏中共享温度的情况。
问题2: 蒸馏过程中是否可以对不同样本取不同的温度?
答:与问题 1 类似,其取值井没有明显的约束形式,因此可以针对样本选择温度的取值。
1.2 共享温度的弊端(1/2)
上一小节讨论了“学生和教师网络是否可以针对样本选择不同的温度”的问题,但是我们还并不知道
是否有必要
选择不同的温度值,因此本节展示传统知识蒸馏共享温度带来的弊端。
首先我们将之前的softmax表达式统一得到一个一般形式,其表示为:
其中
和
分别为学生
和教师
的 softmax 表达式中的偏置项
和缩放项(
, 其中偏置项虽然可以通过分子分母相消, 但是其有稳定
it均值的作用 (之后2.2节中提到)。
对于一个蒸馏好的理想学生, 我们假设对于给定样本, 其损失函数(
散度)达到最小值。也就是说, 这个理想学生可以完美学习教师的概率, 那对于任意索引
, 都可以得到
那么对于任意一对索引
, 我们有:
将上面的式子按
从 1 到
求和, 然后除以
, 可以得到:
其中
和
分别是学生和教师logit的均值。然后我们对公式(14)按
从 1 到
求和, 我们可以得到:
其中,
是标准差函数。假设我们设定教师和学生的温度相同, 即
, 那么公式(14)就会变为:
由此可以看出经典知识蒸馏问题中共享温度的设定, 最终会强制学生和教师logit之间存在一个固定的差
, 但是考虑到学生网络和教师网络之间的能力差距, 学生很难生成与教师 logit 的均值相当的logit
(见图2横轴)。
图2. 不同尺寸网络的logit均值和标准差的双变量散点图
图2展示了不同尺寸网络输出的logit均值和标准差,可以看出尺寸较大的网络均值更接近于0,标准差也越小,也就是logit更紧凑。由此规律可以看出,较小的学生网络确实难以输出和较大的教师网络相当的logit范围。
1.3 小节
我们从本节可以得到以下结论:
-
学生、教师网络的温度没有明显的约束条件,不必一定全局共享,可以人为指定
-
在温度共享的情况下,学生、教师网络之间有一个logit范围的强制性匹配
-
基于学生、教师网络的能力差距, 上述强制性匹配很可能限制学生网络的蒸馏效果
2. 提出方法:Logit标准化
2.1 算法
为了打破上述的强制性匹配, 文章基于公式(14)的形式, 提出了logit标准化, 即把
、
、
和
代入softmax函数:
其中
函数就是一种加权
-score标准化函数, 其表达形式如算法1所示, 通过引入一个基础温度
来控制标准化后的logit值域(见2.2优势第四条)。而完整的logit标准化知识烝馏算法则如算法2所示。
算法1:Z-score函数算法;算法2:logit标准化的知识蒸馏算法
2.2 优势
这样
-score标准化后的logit,
, 有至少四个好处(所有证明可见文章补充材料):
之前的工作
常常有学生/教师logit的均值为 0 的假设, 但是如图 2 所示, 其几乎是不可能真实实现的,而基于提出的logit标准化函数,logit均值会自动变为 0 。
这个性质也是
-score自带的性质, 证明简单。这条性质使学生、教师logit被投影到同一范围的类-高斯分布内, 而由于其投影过程是多对一的, 意味着其反向过程是不确定的, 确保了学生的原始logit可以不受“强制性匹配”的副作用影响。
其定义为给定一串索引序列
, 如果其可以将原始logit进行由小到大的排序, 即
, 那么其也能够将变 换后 logiti 进行相对应的 排序, 即
。由于
-score属于线性变换函数, 这条性质自动满足。这条性质确保了学生能够学习到教师网络logit必要的内在关系。