-Distilling the Knowledge in a Neural Network
Geoffrey Hinton
∗
†
Google Inc. Mountain View [email protected]
Oriol Vinyals† Google Inc. Mountain View [email protected]
Jeff Dean Google Inc. Mountain
[email protected]
摘要
几乎任何机器学习算法提高性能的一种简单方式是在相同数据上训练许多不同模型,然后对它们进行平均预测
[3]
。不幸的是,使用整个模型集合进行预测很麻烦,可能会因为计算成本过高而无法部署到大量用户中,特别是如果单个模型是庞大的神经网络。
Caruana
和他的合作者
[1]
已经证明可以将模型集合中的知识压缩到一个单一模型中,这样更容易部署,我们使用一种不同的压缩技术进一步发展这种方法。我们在
MNIST
数据集上取得了一些令人惊讶的结果,
我们展示了将模型集合的知识提炼成一个单一模型可以显著提高一个被广泛使用的商业系统的声学模型
。
我们还引入了一种新类型的集合,由一个或多个全模型和许多专家模型组成,这些专家模型学会区分全模型混淆的细粒度类别。与专家混合不同,这些专家模型可以快速并行训练。
目录
1
简介
2
蒸馏
3
初步在MNIST
上的实验
4
语音识别实验
4.1
结果
5
在非常大的数据集上训练专家集合
5.1
JFT
数据集
5.2
专业模型
5.3
分配专家类别
5.4
专家集成进行推理
5.5
结果
6
软目标作为正则化项(
SOFT TARGETS AS REGULARIZERS
)
7
专家混合模型的相关性
8
讨论
致谢
参考文献
1
简介
许多昆虫都有一种儿童形态,它们能够从环境中提取能量和营养,还有一种完全不同的成年形态,它们专门用于旅行和繁殖
。
在大规模机器学习中,尽管训练阶段和部署阶段的要求完全不同,但我们通常使用非常相似的模型
:对于语音和物体识别等任务,训练阶段必须从非常大的、高度冗余的数据集中提取结构,但不需要实时操作,并且可以使用大量计算资源。然而,部署给大量用户的情况下,响应时间和计算资源要求会更严格。
昆虫的类比表明,如果这能更容易从数据中提取结构,我们应该愿意训练非常笨重的模型
。这种笨重的模型可以是一组分别训练的模型的集合,也可以是一个使用了非常强的正则化方法(如
dropout
)训练的非常大的模型。一旦训练完成笨重的模型,我们可以使用一种称为
“
蒸馏
”
的不同类型的训练,将知识从笨重的模型转移到更适合部署的小模型上。这种策略的一个版本已经由
Rich Caruana
和他的合作者们开创了。在他们的重要论文中,他们有力地证明了大量模型获得的知识可以转移到单个小模型上。
一个可能阻止进一步调查这种非常有前途方法的概念障碍是,我们倾向于将训练模型中的知识与学到的参数值进行等同,这使我们很难看到如何改变模型的形式但保持相同的知识。对知识的更抽象视角,使其摆脱任何特定的实例化,是它是一个学习到的。
将输入向量映射到输出向量
。对于那些学会区分大量类别的繁琐模型,常规的训练目标是最大化正确答案的平均对数概率,但学习的副作用是训练模型会为所有不正确的答案分配概率,即使这些概率非常小,其中一些比其他的大得多。不正确答案的相对概率告诉我们有关繁琐模型如何倾向于泛化的重要信息。例如,一张宝马车的图像可能被误认为是垃圾车的几率非常小,但这个错误仍然比将其误认为是一根胡萝卜的概率大许多倍。
通常认为,用于训练的目标函数应尽可能反映用户的真实目标
。尽管如此,模型通常被训练以在训练数据上优化性能,而真正的目标是良好地泛化到新数据。显然,最好是训练模型以良好泛化,但这需要关于正确泛化方式的信息,而这种信息通常不可用。然而,当我们将大型模型中的知识提炼成小型模型时,我们可以训练小型模型以与大型模型相同的方式泛化。如果臃肿的模型能很好地泛化,例如它是多个不同模型的平均值,那么以相同方式训练以泛化的小型模型通常在测试数据上要比以正常方式在同一训练集上训练的小型模型更好。
将繁琐模型的泛化能力转移到小模型的一个明显方法是使用繁琐模型产生的类别概率作为训练小模型的
“
软目标
”
。
对于这种转移阶段,我们可以使用相同的训练集或单独的
“
转移
”
集。当繁琐模型是由一组简单模型组成的大型集合时,我们可以使用它们各自预测分布的算术或几何平均作为软目标。当软目标具有很高的熵时,在每个训练案例中它们提供的信息比硬目标多得多,并且在训练案例之间的梯度变化要小得多,因此小模型通常可以在比原始繁琐模型少得多的数据上进行训练,并且使用更高的学习速率。
对于像
MNIST
这样在其中笨重模型几乎总能以极高的置信度产生正确答案的任务,关于学习函数的大部分信息都存在于软目标中非常小的概率比例中。例如,对于一个数字
2
的版本来说,被认为是数字
3
的概率可能为
10^-6
,被认为是数字
7
的概率可能为
10^-9
,而对于另一个版本来说可能正好相反。这是有价值的信息,它定义了数据上的丰富相似结构(即告诉我们哪些数字
2
看起来像数字
3
,哪些看起来像数字
7
),但在转移阶段中对交叉熵损失函数的影响非常小,因为概率接近于零。
Caruana
和他的合作者通过使用
logits
(最终软最大值函数的输入)而不是由软最大值函数产生的概率作为学习小模型的目标,并最小化笨重模型产生的
logits
与小模型产生的
logits
之间的平方差,绕过了这个问题。我们更一般的解决方案,称为
"
蒸馏
"
,是将最终软最大值函数的温度提高,直到笨重模型产生一组合适的软目标。然后在训练小模型时使用相同的高温度来匹配这些软目标。我们后面将展示笨重模型
logits
的匹配实际上是蒸馏的一个特例。
用于训练小模型的传递集可以完全由未标记的数据
[1]
组成,也可以使用原始训练集
。我们发现,使用原始训练集效果很好,特别是如果我们在目标函数中添加一个小项,鼓励小模型预测真实目标,并匹配臃肿模型提供的软目标。通常,小模型无法完全匹配软目标,并且朝着正确答案错误地前进会发现是有帮助的。
2
蒸馏
神经网络通常通过使用
“softmax”
输出层来产生类别概率,该输出层将计算出的每个类别的逻辑值
zi
转换为概率
qi
,通过将
zi
与其他逻辑值进行比较的方式。
T
是一个通常设置为
1
的温度。使用一个更高的数值可以产生一个更软性的类别概率分布。
在最简单的蒸馏形式中,通过将蒸馏模型训练于一个转移集,并使用高温下
softmax
函数在每个转移集案例中生成的软目标分布来传递知识给蒸馏模型。
在训练蒸馏模型时,使用相同的高温,但在训练完成后,蒸馏模型使用温度为
1
。
当所有或部分转移集的正确标签已知时,通过训练精炼模型以产生正确标签,可以显著改善此方法。一种方法是使用正确标签来修改软目标,但我们发现一种更好的方法是简单地使用两种不同目标函数的加权平均。第一个目标函数是与软目标的交叉熵,这个交叉熵是使用从笨拙模型生成软目标时
softmax
中的高温相同的温度计算的。第二个目标函数是与正确标签的交叉熵,这是使用与
softmax
中精炼模型相同
logits
但温度为
1
计算的。我们发现,通常通过在第二个目标函数上使用较低的权重获得最佳结果。由于由软目标产生的梯度的大小按比例缩放为
1/T^2
,因此在同时使用硬目标和软目标时,重要的是将它们乘以
T^2
。这可确保如果在尝试元参数时更改蒸馏的温度,则硬目标和软目标的相对贡献大致保持不变。
2.1
匹配的
logits
是蒸馏的特殊情况
转移集中的每个案例都会对蒸馏模型的每个
logit zi
贡献一个交叉熵梯度
dC/dzi
。如果繁琐的模型具有产生软目标概率
pi
的
logits vi
,而转移训练是以温度
T
进行的话,该梯度由以下公式给出:
如果温度高于
logits
的幅度,我们可以近似计算:
如果我们现在假设
logits
已经针对每个转移案例单独进行了零均值处理,以便
j zj = Pj vj = 0
,方程
3
简化为:
在高温极限下,蒸馏等价于最小化
1/2(zi −vi)2
,前提是对每个转移情况的
logits
进行零均值处理。在较低温度下,蒸馏对比平均值更低的
logits
匹配关注度明显降低。这可能是有利的,因为这些
logits
几乎完全不受用于训练庞大模型的成本函数的限制,因此可能非常嘈杂。另一方面,这些非常低的
logits
可能传达了庞大模型所获知的有用信息。哪一种效应占主导地位是一个经验问题。我们展示,当精炼模型太小以捕捉庞大模型中的所有知识时,中间温度效果最好,这强烈暗示忽略较大的负
logits
可能是有帮助的。
3
初步在
MNIST
上的实验
为了看到蒸馏的效果如何,我们训练了一个单独的大型神经网络,该网络有两个隐藏层,每个隐藏层有
1200
个修正线性隐藏单元,用于所有的
60,000
个训练案例。该网络通过使用
dropout
和权重约束进行强大的正则化,如
[5]
中所描述的。
Dropout
可以看作是训练具有共享权重的指数级模型集合的一种方法。此外,输入图像在任何方向上最多抖动了两个像素。该网络实现了
67
个测试错误,而一个更小的网络,有两个隐藏层,每个隐藏层有
800
个修正线性隐藏单元,没有正则化,实现了
146
个错误。但如果这个较小的网络仅通过添加与大网络产生的软目标相匹配的附加任务来进行正则化,温度为
20
,它实现了
74
个测试错误。这表明,软目标可以将大量知识传输给蒸馏模型,包括从翻译训练数据中学到的泛化知识,尽管转移集中不包含任何翻译。
当蒸馏网络的两个隐藏层中每层有
300
个单位或更多时,所有大于
8
的温度得到的结果都相当相似。但当这被彻底减少到每层
30
个单位时,范围在
2.5
到
4
之间的温度比较高或较低的温度表现明显更好。
然后我们尝试在转移集中省略所有数字
3
的示例
。因此,从精炼模型的角度来看,
3
是一个它从未见过的神话数字。尽管如此,精炼模型仅在测试中产生
206
个错误,其中
133
个是在测试集中的
1010
个数字
3
上。大多数错误是由于对
3
类的学习偏差太低。如果将该偏差增加
3.5
(这会优化测试集的整体性能),那么精炼模型会产生
109
个错误,其中
14
个是在
3
个数字上。因此,在正确的偏差下,尽管在训练过程中从未见过数字
3
,精炼模型正确识别了
98.6%
的测试数字
3
。如果转移集仅包含来自训练集的
7
和
8
,那么精炼模型会产生
47.3%
的测试错误,但是当将
7
和
8
的偏差减小
7.6
以优化测试性能时,这一数字降至
13.2%
测试错误。
4
语音识别实验
在本节中,我们研究了集成用于自动语音识别(
ASR
)中的深度神经网络(
DNN
)声学模型的影响。我们表明,我们在本文中提出的蒸馏策略实现了将一组模型蒸馏成一个单一模型的预期效果,该模型比直接从相同训练数据中学习的相同大小的模型表现显著更好。
目前,最先进的自动语音识别(
ASR
)系统使用深度神经网络(
DNN
)将从波形中提取的(短暂的)时间上下文特征映射到隐马尔可夫模型(
HMM
)的离散状态的概率分布
[4]
。具体来说,
DNN
在每个时间点上产生对三音素状态群的概率分布,然后解码器找到一条穿过
HMM
状态的路径,这条路径在使用高概率状态和生成符合语言模型的转录之间达到最佳平衡。
尽管可能(并且是可取的)通过训练
DNN
,使解码器(以及语言模型)考虑到通过对所有可能的路径进行边际化来训练它,但通常训练
DNN
以通过(局部地)最小化网络所做预测与每个观测的状态的地面真实序列的强制对齐给定标签之间的交叉熵来执行逐帧分类:
θ
是我们的声学模型
P
的参数,该模型将时间
t
处的声学观测
st
映射到一个概率
P(ht|st;θ′)
,表示
“
正确
”
的
HMM
状态
ht
,这由与正确单词序列的强制对齐确定。该模型采用分布式随机梯度下降方法进行训练。
我们采用了一个具有
8
个隐藏层的架构,每个隐藏层包含
2560
个修正线性单元,以及一个最终的
softmax
层,具有
14,000
个标签(
HMM
目标
ht
)。
输入是
26
帧的
40
个
Mel-scaled
滤波器组系数,每帧间隔
10
毫秒,我们预测第
21
帧的
HMM
状态。总参数数量约为
85M
。这是
Android
语音搜索使用的声学模型的略旧版本,应该被视为一个非常强大的基线。为了训练
DNN
声学模型,我们使用了大约
2000
小时的英语口语数据,产生了大约
700M
个训练样本。该系统在我们的开发集上实现了
58.9
%的帧准确率和
10.9
%的词错误率(
WER
)。
表
1
:分类准确率和
WER
显示,精简的单一模型的表现与用于创建软目标的
10
个模型的平均预测相当。
4.1
结果
我们训练了
10
个单独的模型来预测
P(ht|st;θ)
,使用完全相同的架构和训练程序作为基准。 这些模型是随机初始化的,具有不同的初始参数值,我们发现这样可以在训练的模型中产生足够的多样性,使得整体模型的平均预测能够明显优于个别模型。 我们已经尝试通过改变每个模型看到的数据集来为模型增加多样性,但我们发现这并没有显著改变我们的结果,所以我们选择了更简单的方法。 对于蒸馏,我们尝试了温度为
[1,2,5,10]
,并在硬目标的交叉熵上使用了相对权重
0.5
,其中粗体表示表
1
中使用的最佳值。
表
1
显示,实际上,我们的蒸馏方法能够从训练集中提取比仅使用硬标签训练单个模型更多的有用信息。使用
10
个模型的集成模型在帧分类准确度方面取得的
80%
以上的改进传递给了与我们在
MNIST
上的初步实验中观察到的改进类似的蒸馏模型。由于目标函数不匹配,集成模型在
23K
词测试集上对
WER
的最终目标改进较小,但同样,集成模型对
WER
的改进也传递给了蒸馏模型。
最近,我们了解到一个相关工作,通过匹配已经训练好的大型模型的类别概率来学习一个小型声学模型
[8]
。然而,他们在温度为
1
时使用大规模无标签数据集进行蒸馏,他们最好的蒸馏模型只能将小型模型的错误率降低
28%
,这个百分比是大型模型和小型模型在使用硬标签训练时错误率之间的差距。
5
在非常大的数据集上训练专家集合
训练一个模型集合是利用并行计算的非常简单的方法,通常的反对意见是在测试时模型集合需要太多计算量,可以通过使用蒸馏来解决。然而,对于模型集合还有另一个重要的反对意见:如果个体模型是大型神经网络,数据集非常大,那么在训练时需要的计算量是过多的,即使容易并行化。
在本节中,我们提供了一个这样的数据集的示例,并展示了如何学习专家模型,每个模型都专注于不同的可混淆类别子集,可以减少学习集合所需的总计算量。专注于进行细粒度区分的专家的主要问题是它们很容易出现过拟合,我们将介绍如何使用软目标来防止过拟合。
5.1
JFT
数据集
JFT
是
Google
内部的一个数据集,包含了
1
亿张带有
1.5
万个标签的标记图片。
在我们进行这项工作时,
Google
对
JFT
的基准模型是一个深度卷积神经网络,使用了大量核心进行了约六个月的异步随机梯度下降训练。这种训练使用了两种类型的并行处理。
首先,神经网络的许多副本在不同的核心集上运行,处理来自训练集的不同的小批量数据。每个副本计算其当前小批量数据的平均梯度,并将该梯度发送到一个分片参数服务器,该服务器发送回参数的新值。这些新值反映了参数服务器自上次向副本发送参数以来接收的所有梯度。其次,将每个副本分布在多个核心上,通过将不同的神经元子集放在每个核心上。集成训练是另一种可以实现的并行处理类型。
表
2
:由我们的协方差矩阵聚类算法计算得出的示例类别。
相对于其他两种类型,只有在有更多核心可用的情况下,才能更好地运行。等待数年来训练一组模型并不可行,因此我们需要一种更快的方法来改进基线模型。
5.2
专业模型
当类的数量非常大时,让臃肿的模型变成一个集成模型是有意义的,其中包含一个在所有数据上训练的通用模型和许多
“
专家
”
模型,每个模型都是在高度富含来自一个容易混淆的类别子集(例如不同类型的蘑菇)的示例的数据上进行训练的。这种类型的专家的
softmax
可以通过将其不关心的所有类别合并成一个垃圾箱类别而变得更小。
为了减少过拟合并分享学习较低级特征探测器的工作,每个专家模型都以通用模型的权重初始化。然后,通过训练专家模型,将这些权重稍微修改,其中一半示例来自其特殊子集,另一半来自训练集的剩余部分的随机采样。训练后,我们可以通过将垃圾箱类别的逻辑增加专家类别过采样的比例的对数来校正对训练集的偏倚。
5.3
分配专家类别
为了为专家们推导对象类别的分组,我们决定将重点放在我们的整个网络经常混淆的类别上。虽然我们可以计算混淆矩阵并将其用作查找这些群集的一种方法,但我们选择了一种更简单的方法,不需要真实标签来构建这些群集。
特别是,我们对我们的综合模型的预测的协方差矩阵应用了聚类算法,以便一个经常一起预测的类别集合
Sm
将被用作我们的一个专家模型
m
的目标。我们对协方差矩阵的列应用了在线版本的
K-means
算法,并得到了合理的聚类结果(如表
2
所示)。我们尝试了几种产生类似结果的聚类算法。
5.4
专家集成进行推理
在研究专家模型蒸馏的情况之前,我们想要看看包含专家的集合模型表现如何。除了专家模型之外,我们总是有一个通用模型,以便处理我们没有专家的类别,并决定使用哪些专家模型。给定输入图像
x
,我们通过两个步骤进行
top-one
分类。
步骤
1
:对于每个测试案例,我们根据通用模型找到概率最高的
n
个类别。将这个类别集合称为
k
。在我们的实验中,我们使用
n = 1
。
步骤
2
:然后我们采取所有专家模型
m
,其可混淆类的特殊子集
Sm
与
k
有非空交集,并将其称为专家活动集
Ak
(注意该集合可能为空)。然后找到所有类别的概率分布
q
,使其最小化:
KL
表示
KL
散度,而
pm
和
pg
表示专家模型或全模型的概率分布。
pm
是对
m
个专家类别及一个垃圾箱类别的分布,因此在计算其与全局
q
分布的
KL
散度时,我们要对
m
的垃圾箱中的所有类别的概率进行求和。
表
3