Knowledge distillation methods fall into two broad categories: logits distillation and feature distillation.
Logits distillation applies a higher temperature in the softmax to amplify the information carried by the negative labels, and then uses the KL divergence between the student's and the teacher's logits under this high-temperature softmax as the loss.
Intermediate-feature distillation instead forces the student to learn the features of selected intermediate layers of the teacher, either by matching those features directly or by learning the transformation between features. For example, given two feature maps No.1 and No.2, the knowledge can be expressed as the transformation from one to the other: a matrix describing this transformation is computed from the teacher, and the student is trained to reproduce the same matrix.
This article collects commonly used knowledge distillation papers together with their code, to make later study and research easier.
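Before the individual methods, here is a rough sketch of how these loss modules are typically plugged into training: the teacher is frozen, and the student is optimized with its usual task loss plus a weighted distillation term. The snippet is illustrative only; `student`, `teacher`, `train_loader` and the weight `alpha` are placeholder names, and the hyper-parameter values are not taken from any particular paper.

import torch
import torch.nn as nn

ce_criterion = nn.CrossEntropyLoss()
kd_criterion = SoftTarget(T=4.0)   # any loss module below, e.g. SoftTarget from section 2
optimizer = torch.optim.SGD(student.parameters(), lr=0.1, momentum=0.9)

teacher.eval()
for img, target in train_loader:
    with torch.no_grad():          # the teacher only provides targets
        out_t = teacher(img)
    out_s = student(img)
    loss = ce_criterion(out_s, target) + alpha * kd_criterion(out_s, out_t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()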
1. Logits
Paper link: https://proceedings.neurips.cc/paper/2014/file/ea8fcd92d59581717e06eb187f10666d-Paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F


class Logits(nn.Module):
    '''
    Do Deep Nets Really Need to be Deep?
    http://papers.nips.cc/paper/5484-do-deep-nets-really-need-to-be-deep.pdf
    '''
    def __init__(self):
        super(Logits, self).__init__()

    def forward(self, out_s, out_t):
        loss = F.mse_loss(out_s, out_t)

        return loss
2. ST
Paper link: https://arxiv.org/pdf/1503.02531.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTarget(nn.Module):
    '''
    Distilling the Knowledge in a Neural Network
    https://arxiv.org/pdf/1503.02531.pdf
    '''
    def __init__(self, T):
        super(SoftTarget, self).__init__()
        self.T = T

    def forward(self, out_s, out_t):
        loss = F.kl_div(F.log_softmax(out_s/self.T, dim=1),
                        F.softmax(out_t/self.T, dim=1),
                        reduction='batchmean') * self.T * self.T

        return loss
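In the original recipe the soft-target term is combined with the ordinary cross-entropy on the hard labels; the `* self.T * self.T` factor above keeps the gradient scale of the soft term roughly constant as the temperature changes. Below is a minimal sketch of that combination, with `alpha` and `T` treated as assumed hyper-parameters rather than values from the paper.

kd = SoftTarget(T=4.0)
ce = nn.CrossEntropyLoss()

out_s = student(img)
out_t = teacher(img).detach()      # teacher logits, no gradient
alpha = 0.9
loss = (1 - alpha) * ce(out_s, target) + alpha * kd(out_s, out_t)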
3. AT
Paper link: https://arxiv.org/pdf/1612.03928.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F

'''
AT with sum of absolute values with power p
'''
class AT(nn.Module):
    '''
    Paying More Attention to Attention: Improving the Performance of Convolutional
    Neural Networks via Attention Transfer
    https://arxiv.org/pdf/1612.03928.pdf
    '''
    def __init__(self, p):
        super(AT, self).__init__()
        self.p = p

    def forward(self, fm_s, fm_t):
        loss = F.mse_loss(self.attention_map(fm_s), self.attention_map(fm_t))

        return loss

    def attention_map(self, fm, eps=1e-6):
        # |.|^p summed over channels, then L2-normalized over the spatial dimensions
        am = torch.pow(torch.abs(fm), self.p)
        am = torch.sum(am, dim=1, keepdim=True)
        norm = torch.norm(am, dim=(2,3), keepdim=True)
        am = torch.div(am, norm+eps)

        return am
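Attention transfer is normally applied at several depths and the per-layer losses are summed. A minimal sketch, assuming hypothetical helpers `student_features`/`teacher_features` that return matching lists of feature maps:

at = AT(p=2)
fms_s = student_features(img)   # hypothetical: e.g. outputs of each residual stage
fms_t = teacher_features(img)
loss_at = sum(at(fs, ft.detach()) for fs, ft in zip(fms_s, fms_t))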
4. FitNet
Paper link: https://arxiv.org/pdf/1412.6550.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F


class Hint(nn.Module):
    '''
    FitNets: Hints for Thin Deep Nets
    https://arxiv.org/pdf/1412.6550.pdf
    '''
    def __init__(self):
        super(Hint, self).__init__()

    def forward(self, fm_s, fm_t):
        loss = F.mse_loss(fm_s, fm_t)

        return loss
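The hint loss assumes the student and teacher feature maps have identical shapes; when the channel counts differ, FitNets trains a small regressor on top of the student's guided layer. A minimal sketch, assuming matching spatial sizes and using a 1x1 convolution as the regressor (an illustrative choice, not necessarily the paper's exact configuration; `s_channels`/`t_channels` are placeholders):

# Hypothetical regressor mapping student channels to teacher channels.
regressor = nn.Conv2d(s_channels, t_channels, kernel_size=1, bias=False)
hint = Hint()
loss_hint = hint(regressor(fm_s), fm_t.detach())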
5. NST
Paper link: https://arxiv.org/pdf/1707.01219.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F

'''
NST with Polynomial Kernel, where d=2 and c=0
'''
class NST(nn.Module):
    '''
    Like What You Like: Knowledge Distill via Neuron Selectivity Transfer
    https://arxiv.org/pdf/1707.01219.pdf
    '''
    def __init__(self):
        super(NST, self).__init__()

    def forward(self, fm_s, fm_t):
        fm_s = fm_s.view(fm_s.size(0), fm_s.size(1), -1)
        fm_s = F.normalize(fm_s, dim=2)

        fm_t = fm_t.view(fm_t.size(0), fm_t.size(1), -1)
        fm_t = F.normalize(fm_t, dim=2)

        loss = self.poly_kernel(fm_t, fm_t).mean() \
             + self.poly_kernel(fm_s, fm_s).mean() \
             - 2 * self.poly_kernel(fm_s, fm_t).mean()

        return loss

    def poly_kernel(self, fm1, fm2):
        fm1 = fm1.unsqueeze(1)
        fm2 = fm2.unsqueeze(2)
        out = (fm1 * fm2).sum(-1).pow(2)

        return out
6. PKT
Paper link: http://openaccess.thecvf.com/content_ECCV_2018/papers/Nikolaos_Passalis_Learning_Deep_Representations_ECCV_2018_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F

'''
Adopted from https://github.com/passalis/probabilistic_kt/blob/master/nn/pkt.py
'''
class PKTCosSim(nn.Module):
    '''
    Learning Deep Representations with Probabilistic Knowledge Transfer
    http://openaccess.thecvf.com/content_ECCV_2018/papers/Nikolaos_Passalis_Learning_Deep_Representations_ECCV_2018_paper.pdf
    '''
    def __init__(self):
        super(PKTCosSim, self).__init__()

    def forward(self, feat_s, feat_t, eps=1e-6):
        # Normalize each vector by its norm
        feat_s_norm = torch.sqrt(torch.sum(feat_s ** 2, dim=1, keepdim=True))
        feat_s = feat_s / (feat_s_norm + eps)
        feat_s[feat_s != feat_s] = 0

        feat_t_norm = torch.sqrt(torch.sum(feat_t ** 2, dim=1, keepdim=True))
        feat_t = feat_t / (feat_t_norm + eps)
        feat_t[feat_t != feat_t] = 0

        # Calculate the cosine similarity
        feat_s_cos_sim = torch.mm(feat_s, feat_s.transpose(0, 1))
        feat_t_cos_sim = torch.mm(feat_t, feat_t.transpose(0, 1))

        # Scale cosine similarity to [0,1]
        feat_s_cos_sim = (feat_s_cos_sim + 1.0) / 2.0
        feat_t_cos_sim = (feat_t_cos_sim + 1.0) / 2.0

        # Transform them into probabilities
        feat_s_cond_prob = feat_s_cos_sim / torch.sum(feat_s_cos_sim, dim=1, keepdim=True)
        feat_t_cond_prob = feat_t_cos_sim / torch.sum(feat_t_cos_sim, dim=1, keepdim=True)

        # Calculate the KL-divergence
        loss = torch.mean(feat_t_cond_prob * torch.log((feat_t_cond_prob + eps) / (feat_s_cond_prob + eps)))

        return loss
7. FSP
Paper link: http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F


class FSP(nn.Module):
    '''
    A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning
    http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf
    '''
    def __init__(self):
        super(FSP, self).__init__()

    def forward(self, fm_s1, fm_s2, fm_t1, fm_t2):
        loss = F.mse_loss(self.fsp_matrix(fm_s1, fm_s2), self.fsp_matrix(fm_t1, fm_t2))

        return loss

    def fsp_matrix(self, fm1, fm2):
        if fm1.size(2) > fm2.size(2):
            fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(2), fm2.size(3)))

        fm1 = fm1.view(fm1.size(0), fm1.size(1), -1)
        fm2 = fm2.view(fm2.size(0), fm2.size(1), -1).transpose(1, 2)

        fsp = torch.bmm(fm1, fm2) / fm1.size(2)

        return fsp
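FSP matrices are built from pairs of feature maps taken at the input and output of the same stage, so the loss is called with two student maps and the two corresponding teacher maps; the variable names below are placeholders:

fsp = FSP()
# fm_*1 is the input of a stage, fm_*2 its output, for student and teacher respectively.
loss_fsp = fsp(fm_s1, fm_s2, fm_t1.detach(), fm_t2.detach())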
8. FT
Paper link: http://papers.nips.cc/paper/7541-paraphrasing-complex-network-network-compression-via-factor-transfer.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F


class FT(nn.Module):
    '''
    Paraphrasing Complex Network: Network Compression via Factor Transfer
    http://papers.nips.cc/paper/7541-paraphrasing-complex-network-network-compression-via-factor-transfer.pdf
    '''
    def __init__(self):
        super(FT, self).__init__()

    def forward(self, factor_s, factor_t):
        loss = F.l1_loss(self.normalize(factor_s), self.normalize(factor_t))

        return loss

    def normalize(self, factor):
        norm_factor = F.normalize(factor.view(factor.size(0), -1))

        return norm_factor
9. RKD
Paper link: https://arxiv.org/pdf/1904.05068.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F

'''
From https://github.com/lenscloth/RKD/blob/master/metric/loss.py
'''
class RKD(nn.Module):
    '''
    Relational Knowledge Distillation
    https://arxiv.org/pdf/1904.05068.pdf
    '''
    def __init__(self, w_dist, w_angle):
        super(RKD, self).__init__()
        self.w_dist = w_dist
        self.w_angle = w_angle

    def forward(self, feat_s, feat_t):
        loss = self.w_dist * self.rkd_dist(feat_s, feat_t) + \
               self.w_angle * self.rkd_angle(feat_s, feat_t)

        return loss

    def rkd_dist(self, feat_s, feat_t):
        feat_t_dist = self.pdist(feat_t, squared=False)
        mean_feat_t_dist = feat_t_dist[feat_t_dist > 0].mean()
        feat_t_dist = feat_t_dist / mean_feat_t_dist

        feat_s_dist = self.pdist(feat_s, squared=False)
        mean_feat_s_dist = feat_s_dist[feat_s_dist > 0].mean()
        feat_s_dist = feat_s_dist / mean_feat_s_dist

        loss = F.smooth_l1_loss(feat_s_dist, feat_t_dist)

        return loss

    def rkd_angle(self, feat_s, feat_t):
        # N x C --> N x N x C
        feat_t_vd = (feat_t.unsqueeze(0) - feat_t.unsqueeze(1))
        norm_feat_t_vd = F.normalize(feat_t_vd, p=2, dim=2)
        feat_t_angle = torch.bmm(norm_feat_t_vd, norm_feat_t_vd.transpose(1, 2)).view(-1)

        feat_s_vd = (feat_s.unsqueeze(0) - feat_s.unsqueeze(1))
        norm_feat_s_vd = F.normalize(feat_s_vd, p=2, dim=2)
        feat_s_angle = torch.bmm(norm_feat_s_vd, norm_feat_s_vd.transpose(1, 2)).view(-1)

        loss = F.smooth_l1_loss(feat_s_angle, feat_t_angle)

        return loss

    def pdist(self, feat, squared=False, eps=1e-12):
        feat_square = feat.pow(2).sum(dim=1)
        feat_prod = torch.mm(feat, feat.t())
        feat_dist = (feat_square.unsqueeze(0) + feat_square.unsqueeze(1) - 2 * feat_prod).clamp(min=eps)

        if not squared:
            feat_dist = feat_dist.sqrt()

        feat_dist = feat_dist.clone()
        feat_dist[range(len(feat)), range(len(feat))] = 0

        return feat_dist
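RKD operates on flattened embedding vectors (N x C) rather than spatial maps, so feature maps are usually pooled and flattened first. The weights below (25 for the distance term, 50 for the angle term) are a commonly used setting and should be treated as an assumption, not a requirement:

rkd = RKD(w_dist=25.0, w_angle=50.0)
feat_s = F.adaptive_avg_pool2d(fm_s, 1).flatten(1)   # N x C embeddings
feat_t = F.adaptive_avg_pool2d(fm_t, 1).flatten(1)
loss_rkd = rkd(feat_s, feat_t.detach())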
10. AB
Paper link: https://arxiv.org/pdf/1811.03233.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F


class AB(nn.Module):
    '''
    Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
    https://arxiv.org/pdf/1811.03233.pdf
    '''
    def __init__(self, margin):
        super(AB, self).__init__()
        self.margin = margin

    def forward(self, fm_s, fm_t):
        # feature maps before the activation function
        loss = ((fm_s + self.margin).pow(2) * ((fm_s > -self.margin) & (fm_t <= 0)).float() +
                (fm_s - self.margin).pow(2) * ((fm_s <= self.margin) & (fm_t > 0)).float())
        loss = loss.mean()

        return loss
11. SP
Paper link: https://arxiv.org/pdf/1907.09682.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F


class SP(nn.Module):
    '''
    Similarity-Preserving Knowledge Distillation
    https://arxiv.org/pdf/1907.09682.pdf
    '''
    def __init__(self):
        super(SP, self).__init__()

    def forward(self, fm_s, fm_t):
        fm_s = fm_s.view(fm_s.size(0), -1)
        G_s = torch.mm(fm_s, fm_s.t())
        norm_G_s = F.normalize(G_s, p=2, dim=1)

        fm_t = fm_t.view(fm_t.size(0), -1)
        G_t = torch.mm(fm_t, fm_t.t())
        norm_G_t = F.normalize(G_t, p=2, dim=1)

        loss = F.mse_loss(norm_G_s, norm_G_t)

        return loss
12. Sobolev
Paper link: https://arxiv.org/pdf/1706.04859.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad


class Sobolev(nn.Module):
    '''
    Sobolev Training for Neural Networks
    https://arxiv.org/pdf/1706.04859.pdf
    Knowledge Transfer with Jacobian Matching
    http://de.arxiv.org/pdf/1803.00443
    '''
    def __init__(self):
        super(Sobolev, self).__init__()

    def forward(self, out_s, out_t, img, target):
        target_out_s = torch.gather(out_s, 1, target.view(-1, 1))
        grad_s = grad(outputs=target_out_s, inputs=img,
                      grad_outputs=torch.ones_like(target_out_s),
                      create_graph=True, retain_graph=True, only_inputs=True)[0]
        norm_grad_s = F.normalize(grad_s.view(grad_s.size(0), -1), p=2, dim=1)

        target_out_t = torch.gather(out_t, 1, target.view(-1, 1))
        grad_t = grad(outputs=target_out_t, inputs=img,
                      grad_outputs=torch.ones_like(target_out_t),
                      create_graph=True, retain_graph=True, only_inputs=True)[0]
        norm_grad_t = F.normalize(grad_t.view(grad_t.size(0), -1), p=2, dim=1)

        loss = F.mse_loss(norm_grad_s, norm_grad_t.detach())

        return loss
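Because this loss differentiates the selected logits with respect to the input, `img` must have `requires_grad=True` and the teacher forward pass must not be wrapped in `torch.no_grad()`. A minimal sketch:

sobolev = Sobolev()
img = img.requires_grad_(True)   # gradients are taken w.r.t. the input image
out_s = student(img)
out_t = teacher(img)             # keep the teacher graph so grad_t can be computed
loss_sob = sobolev(out_s, out_t, img, target)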
13. BSS
Paper link: https://arxiv.org/pdf/1805.05532.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.gradcheck import zero_gradients  # note: removed in recent PyTorch releases

'''
Modified from https://github.com/bhheo/BSS_distillation
'''

def reduce_sum(x, keepdim=True):
    for d in reversed(range(1, x.dim())):
        x = x.sum(d, keepdim=keepdim)
    return x

def l2_norm(x, keepdim=True):
    norm = reduce_sum(x*x, keepdim=keepdim)
    return norm.sqrt()

class BSS(nn.Module):
    '''
    Knowledge Distillation with Adversarial Samples Supporting Decision Boundary
    https://arxiv.org/pdf/1805.05532.pdf
    '''
    def __init__(self, T):
        super(BSS, self).__init__()
        self.T = T

    def forward(self, attacked_out_s, attacked_out_t):
        loss = F.kl_div(F.log_softmax(attacked_out_s/self.T, dim=1),
                        F.softmax(attacked_out_t/self.T, dim=1),
                        reduction='batchmean') #* self.T * self.T

        return loss

class BSSAttacker():
    def __init__(self, step_alpha, num_steps, eps=1e-4):
        self.step_alpha = step_alpha
        self.num_steps = num_steps
        self.eps = eps

    def attack(self, model, img, target, attack_class):
        img = img.detach().requires_grad_(True)

        step = 0
        while step < self.num_steps:
            zero_gradients(img)
            _, _, _, _, _, output = model(img)

            score = F.softmax(output, dim=1)
            score_target = score.gather(1, target.unsqueeze(1))
            score_attack_class = score.gather(1, attack_class.unsqueeze(1))

            loss = (score_attack_class - score_target).sum()
            loss.backward()

            step_alpha = self.step_alpha * (target == output.max(1)[1]).float()
            step_alpha = step_alpha.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            if step_alpha.sum() == 0:
                break

            pert = (score_target - score_attack_class).unsqueeze(1).unsqueeze(1)
            norm_pert = step_alpha * (pert + self.eps) * img.grad / l2_norm(img.grad)

            step_adv = img + norm_pert
            step_adv = torch.clamp(step_adv, -2.5, 2.5)
            img.data = step_adv.data

            step += 1

        return img
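Putting the two pieces together: the attacker pushes an image toward a chosen attack class across the student's decision boundary, and the BSS loss is then computed on the student's and teacher's outputs for that attacked image. The sketch below mirrors the tuple unpacking used in `attack` (logits returned last) and treats `attack_class` and the step settings as assumptions:

attacker = BSSAttacker(step_alpha=0.3, num_steps=10)
bss = BSS(T=4.0)

adv_img = attacker.attack(student, img, target, attack_class)
_, _, _, _, _, attacked_out_s = student(adv_img)
_, _, _, _, _, attacked_out_t = teacher(adv_img)
loss_bss = bss(attacked_out_s, attacked_out_t.detach())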
14. CC
Paper link: http://openaccess.thecvf.com/content_ICCV_2019/papers/Peng_Correlation_Congruence_for_Knowledge_Distillation_ICCV_2019_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

'''
CC with P-order Taylor Expansion of Gaussian RBF kernel
'''
class CC(nn.Module):
    '''
    Correlation Congruence for Knowledge Distillation
    http://openaccess.thecvf.com/content_ICCV_2019/papers/
    Peng_Correlation_Congruence_for_Knowledge_Distillation_ICCV_2019_paper.pdf
    '''
    def __init__(self, gamma, P_order):
        super(CC, self).__init__()
        self.gamma = gamma
        self.P_order = P_order

    def forward(self, feat_s, feat_t):
        corr_mat_s = self.get_correlation_matrix(feat_s)
        corr_mat_t = self.get_correlation_matrix(feat_t)

        loss = F.mse_loss(corr_mat_s, corr_mat_t)

        return loss

    def get_correlation_matrix(self, feat):
        feat = F.normalize(feat, p=2, dim=-1)
        sim_mat = torch.matmul(feat, feat.t())
        corr_mat = torch.zeros_like(sim_mat)

        for p in range(self.P_order+1):
            corr_mat += math.exp(-2*self.gamma) * (2*self.gamma)**p / \
                        math.factorial(p) * torch.pow(sim_mat, p)

        return corr_mat
15. LwM
Paper link: https://arxiv.org/pdf/1811.08051.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

'''
LwM is originally an incremental learning method with
classification/distillation/attention distillation losses.
Here, LwM is only defined as the Grad-CAM based attention distillation.
'''
class LwM(nn.Module):
    '''
    Learning without Memorizing
    https://arxiv.org/pdf/1811.08051.pdf
    '''
    def __init__(self):
        super(LwM, self).__init__()

    def forward(self, out_s, fm_s, out_t, fm_t, target):
        target_out_t = torch.gather(out_t, 1, target.view(-1, 1))
        grad_fm_t = grad(outputs=target_out_t, inputs=fm_t,
                         grad_outputs=torch.ones_like(target_out_t),
                         create_graph=True, retain_graph=True, only_inputs=True)[0]
        weights_t = F.adaptive_avg_pool2d(grad_fm_t, 1)
        cam_t = torch.sum(torch.mul(weights_t, grad_fm_t), dim=1, keepdim=True)
        cam_t = F.relu(cam_t)
        cam_t = cam_t.view(cam_t.size(0), -1)
        norm_cam_t = F.normalize(cam_t, p=2, dim=1)

        target_out_s = torch.gather(out_s, 1, target.view(-1, 1))
        grad_fm_s = grad(outputs=target_out_s, inputs=fm_s,
                         grad_outputs=torch.ones_like(target_out_s),
                         create_graph=True, retain_graph=True, only_inputs=True)[0]
        weights_s = F.adaptive_avg_pool2d(grad_fm_s, 1)
        cam_s = torch.sum(torch.mul(weights_s, grad_fm_s), dim=1, keepdim=True)
        cam_s = F.relu(cam_s)
        cam_s = cam_s.view(cam_s.size(0), -1)
        norm_cam_s = F.normalize(cam_s, p=2, dim=1)

        loss = F.l1_loss(norm_cam_s, norm_cam_t.detach())

        return loss
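Like the Sobolev loss, this term back-propagates through both networks to the feature maps, so the teacher forward pass must not be run under `torch.no_grad()` and both feature maps must stay attached to their graphs. A minimal sketch, where `forward_with_features` is a hypothetical helper returning a feature map together with the logits:

lwm = LwM()
fm_s, out_s = student.forward_with_features(img)   # hypothetical helper
fm_t, out_t = teacher.forward_with_features(img)
loss_lwm = lwm(out_s, fm_s, out_t, fm_t, target)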
16. IRG
Paper link: http://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_Knowledge_Distillation_via_Instance_Relationship_Graph_CVPR_2019_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F


class IRG(nn.Module):
    '''
    Knowledge Distillation via Instance Relationship Graph
    http://openaccess.thecvf.com/content_CVPR_2019/papers/
    Liu_Knowledge_Distillation_via_Instance_Relationship_Graph_CVPR_2019_paper.pdf
    The official code is written in Caffe:
    https://github.com/yufanLIU/IRG
    '''
    def __init__(self, w_irg_vert, w_irg_edge, w_irg_tran):
        super(IRG, self).__init__()
        self.w_irg_vert = w_irg_vert
        self.w_irg_edge = w_irg_edge
        self.w_irg_tran = w_irg_tran

    def forward(self, irg_s, irg_t):
        fm_s1, fm_s2, feat_s, out_s = irg_s
        fm_t1, fm_t2, feat_t, out_t = irg_t

        loss_irg_vert = F.mse_loss(out_s, out_t)

        irg_edge_feat_s = self.euclidean_dist_feat(feat_s, squared=True)
        irg_edge_feat_t = self.euclidean_dist_feat(feat_t, squared=True)
        irg_edge_fm_s1 = self.euclidean_dist_fm(fm_s1, squared=True)
        irg_edge_fm_t1 = self.euclidean_dist_fm(fm_t1, squared=True)
        irg_edge_fm_s2 = self.euclidean_dist_fm(fm_s2, squared=True)
        irg_edge_fm_t2 = self.euclidean_dist_fm(fm_t2, squared=True)
        loss_irg_edge = (F.mse_loss(irg_edge_feat_s, irg_edge_feat_t) +
                         F.mse_loss(irg_edge_fm_s1,  irg_edge_fm_t1) +
                         F.mse_loss(irg_edge_fm_s2,  irg_edge_fm_t2)) / 3.0

        irg_tran_s = self.euclidean_dist_fms(fm_s1, fm_s2, squared=True)
        irg_tran_t = self.euclidean_dist_fms(fm_t1, fm_t2, squared=True)
        loss_irg_tran = F.mse_loss(irg_tran_s, irg_tran_t)

        # print(self.w_irg_vert * loss_irg_vert)
        # print(self.w_irg_edge * loss_irg_edge)
        # print(self.w_irg_tran * loss_irg_tran)
        # print()

        loss = (self.w_irg_vert * loss_irg_vert +
                self.w_irg_edge * loss_irg_edge +
                self.w_irg_tran * loss_irg_tran)

        return loss

    def euclidean_dist_fms(self, fm1, fm2, squared=False, eps=1e-12):
        '''
        Calculating the IRG Transformation, where fm1 precedes fm2 in the network.
        '''
        if fm1.size(2) > fm2.size(2):
            fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(2), fm2.size(3)))
        if fm1.size(1) < fm2.size(1):
            fm2 = (fm2[:,0::2,:,:] + fm2[:,1::2,:,:]) / 2.0

        fm1 = fm1.view(fm1.size(0), -1)
        fm2 = fm2.view(fm2.size(0), -1)
        fms_dist = torch.sum(torch.pow(fm1-fm2, 2), dim=-1).clamp(min=eps)

        if not squared:
            fms_dist = fms_dist.sqrt()

        fms_dist = fms_dist / fms_dist.max()

        return fms_dist

    def euclidean_dist_fm(self, fm, squared=False, eps=1e-12):
        '''
        Calculating the IRG edge of feature map.
        '''
        fm = fm.view(fm.size(0), -1)
        fm_square = fm.pow(2).sum(dim=1)
        fm_prod = torch.mm(fm, fm.t())
        fm_dist = (fm_square.unsqueeze(0) + fm_square.unsqueeze(1) - 2 * fm_prod).clamp(min=eps)

        if not squared:
            fm_dist = fm_dist.sqrt()

        fm_dist = fm_dist.clone()
        fm_dist[range(len(fm)), range(len(fm))] = 0
        fm_dist = fm_dist / fm_dist.max()

        return fm_dist

    def euclidean_dist_feat(self, feat, squared=False, eps=1e-12):
        '''
        Calculating the IRG edge of feat.
        '''
        feat_square = feat.pow(2).sum(dim=1)
        feat_prod = torch.mm(feat, feat.t())
        feat_dist = (feat_square.unsqueeze(0) + feat_square.unsqueeze(1) - 2 * feat_prod).clamp(min=eps)

        if not squared:
            feat_dist = feat_dist.sqrt()

        feat_dist = feat_dist.clone()
        feat_dist[range(len(feat)), range(len(feat))] = 0
        feat_dist = feat_dist / feat_dist.max()

        return feat_dist
17. VID
Paper link: https://openaccess.thecvf.com/content_CVPR_2019/papers/Ahn_Variational_Information_Distillation_for_Knowledge_Transfer_CVPR_2019_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def conv1x1(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=1,
                     padding=0, bias=False)

'''
Modified from https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/VID.py
'''
class VID(nn.Module):
    '''
    Variational Information Distillation for Knowledge Transfer
    https://zpascal.net/cvpr2019/Ahn_Variational_Information_Distillation_for_Knowledge_Transfer_CVPR_2019_paper.pdf
    '''
    def __init__(self, in_channels, mid_channels, out_channels, init_var, eps=1e-6):
        super(VID, self).__init__()
        self.eps = eps
        self.regressor = nn.Sequential(*[
            conv1x1(in_channels, mid_channels),
            # nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            conv1x1(mid_channels, mid_channels),
            # nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            conv1x1(mid_channels, out_channels),
        ])
        self.alpha = nn.Parameter(
            np.log(np.exp(init_var-eps)-1.0) * torch.ones(out_channels)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # elif isinstance(m, nn.BatchNorm2d):
            #     nn.init.constant_(m.weight, 1)
            #     nn.init.constant_(m.bias, 0)

    def forward(self, fm_s, fm_t):
        pred_mean = self.regressor(fm_s)
        pred_var = torch.log(1.0+torch.exp(self.alpha)) + self.eps
        pred_var = pred_var.view(1, -1, 1, 1)
        neg_log_prob = 0.5 * (torch.log(pred_var) + (pred_mean-fm_t)**2 / pred_var)
        loss = torch.mean(neg_log_prob)

        return loss
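Unlike most of the losses above, VID has trainable parameters of its own (the regressor and the variance parameter `alpha`), so the module must be placed on the same device and added to the optimizer alongside the student; the channel sizes and `init_var` below are assumed values:

vid = VID(in_channels=64, mid_channels=128, out_channels=128, init_var=5.0)
optimizer = torch.optim.SGD(list(student.parameters()) + list(vid.parameters()),
                            lr=0.1, momentum=0.9)
loss_vid = vid(fm_s, fm_t.detach())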
18. OFD
Paper link: http://openaccess.thecvf.com/content_ICCV_2019/papers/Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

'''
Modified from https://github.com/clovaai/overhaul-distillation/blob/master/CIFAR-100/distiller.py
'''
class OFD(nn.Module):
    '''
    A Comprehensive Overhaul of Feature Distillation
    http://openaccess.thecvf.com/content_ICCV_2019/papers/
    Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
    '''
    def __init__(self, in_channels, out_channels):
        super(OFD, self).__init__()
        self.connector = nn.Sequential(*[
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels)
        ])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, fm_s, fm_t):
        margin = self.get_margin(fm_t)
        fm_t = torch.max(fm_t, margin)
        fm_s = self.connector(fm_s)

        mask = 1.0 - ((fm_s <= fm_t) & (fm_t <= 0.0)).float()
        loss = torch.mean((fm_s - fm_t)**2 * mask)

        return loss

    def get_margin(self, fm, eps=1e-6):
        mask = (fm < 0.0).float()
        masked_fm = fm * mask
        margin = masked_fm.sum(dim=(0