专栏名称: AI开发者
AI研习社,雷锋网旗下关注AI开发技巧及技术教程订阅号。
目录
相关文章推荐
爱可可-爱生活  ·  【[157星]YOLOE:实时“看见”一切的 ... ·  23 小时前  
宝玉xp  ·  50位 a16z 合伙人对 2025 ... ·  2 天前  
爱可可-爱生活  ·  [CL]《HieroLM: ... ·  2 天前  
爱可可-爱生活  ·  【[2.3k星]Streamyfin:一个简 ... ·  3 天前  
宝玉xp  ·  Steve Yegge ... ·  3 天前  
51好读  ›  专栏  ›  AI开发者

专栏 | 深度学习算法优化系列八 | VGG,ResNet,DenseNe模型剪枝代码实战

AI开发者  · 公众号  · AI  · 2020-03-01 20:40

正文


点击上方“蓝字”关注“AI开发者”


本文来自 @BBuf 的社区专栏 GiantPandaCV ,文末扫码即可订阅专栏。

前言

具体原理已经讲过了,见上回的推文。 深度学习算法优化系列七 | ICCV 2017的一篇模型剪枝论文,也是2019年众多开源剪枝项目的理论基础 。这篇文章是从源码实战的角度来解释模型剪枝,源码来自:https://github.com/Eric-mingjie/network-slimming 。我这里主要是结合源码来分析每个模型的具体剪枝过程,希望能给你剪枝自己的模型一些启发。

稀疏训练

论文的想法是对于每一个通道都引入一个缩放因子 ,然后和通道的输出相乘。接着联合训练网络权重和这些缩放因子,最后将小缩放因子的通道直接移除,微调剪枝后的网络,特别地,目标函数被定义为:

其中 代表训练数据和标签, 是网络的可训练参数,第一项是CNN的训练损失函数。 是在缩放因子上的乘法项, 是两项的平衡因子。论文的实验过程中选择 ,即 正则化,这也被广泛的应用于稀疏化。次梯度下降法作为不平滑(不可导)的L1惩罚项的优化方法,另一个建议是使用平滑的L1正则项取代L1惩罚项,尽量避免在不平滑的点使用次梯度。

main.py 的实现中支持了稀疏训练,其中下面这行代码即添加了稀疏训练的惩罚系数 ,注意 是作用在BN层的缩放系数上的:

parser.add_argument('--s', type=float, default=0.0001,
help='scale sparse rate (default: 0.0001)')

因此BN层的更新也要相应的加上惩罚项,代码如下:

def updateBN():
   for m in model.modules():
       if isinstance(m, nn.BatchNorm2d):
      m.weight.grad.data.add_(args.s*torch.sign(m.weight.data)) # L1

最后训练,测试,保存Basline模型(包含VGG16,Resnet-164,DenseNet40)的代码如下,代码很常规就不过多解释这一节了:

def train(epoch):
   model.train()
   for batch_idx, (data, target) in enumerate(train_loader):
       if args.cuda:
           data, target = data.cuda(), target.cuda()
       data, target = Variable(data), Variable(target)
       optimizer.zero_grad()
       output = model(data)
       loss = F.cross_entropy(output, target)
       pred = output.data.max(1, keepdim=True)[1]
       loss.backward()
       if args.sr:
           updateBN()
       optimizer.step()
       if batch_idx % args.log_interval == 0:
           print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
               epoch, batch_idx * len(data), len(train_loader.dataset),
               100. * batch_idx / len(train_loader), loss.data[0]))

def test():
   model.eval()
   test_loss = 0
   correct = 0
   for data, target in test_loader:
       if args.cuda:
           data, target = data.cuda(), target.cuda()
       data, target = Variable(data, volatile=True), Variable(target)
       output = model(data)
       test_loss += F.cross_entropy(output, target, size_average=False).data[0] # sum up batch loss
       pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
       correct += pred.eq(target.data.view_as(pred)).cpu().sum()

   test_loss /= len(test_loader.dataset)
   print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
       test_loss, correct, len(test_loader.dataset),
       100. * correct / len(test_loader.dataset)))
   return correct / float(len(test_loader.dataset))

def save_checkpoint(state, is_best, filepath):
   torch.save(state, os.path.join(filepath, 'checkpoint.pth.tar'))
   if is_best:
       shutil.copyfile(os.path.join(filepath, 'checkpoint.pth.tar'), os.path.join(filepath, 'model_best.pth.tar'))

best_prec1 = 0.
for epoch in range(args.start_epoch, args.epochs):
   if epoch in [args.epochs*0.5, args.epochs*0.75]:
       for param_group in optimizer.param_groups:
           param_group['lr'] *= 0.1
   train(epoch)
   prec1 = test()
   is_best = prec1 > best_prec1
   best_prec1 = max(prec1, best_prec1)
   save_checkpoint({
       'epoch': epoch + 1,
       'state_dict': model.state_dict(),
       'best_prec1': best_prec1,
       'optimizer': optimizer.state_dict(),
   }, is_best, filepath=args.save)

print("Best accuracy: "+str(best_prec1))

VGG16的剪枝

代码为工程目录下的 vggprune.py 。剪枝的具体步骤如下:

模型加载

加载需要剪枝的模型,也即是稀疏训练得到的BaseLine模型,代码如下,其中 args.depth 用于指定VGG模型的深度,一般为 16 19

model = vgg(dataset=args.dataset, depth=args.depth)
if args.cuda:
   model.cuda()

if args.model:
   if os.path.isfile(args.model):
       print("=> loading checkpoint '{}'".format(args.model))
       checkpoint = torch.load(args.model)
       args.start_epoch = checkpoint['epoch']
       best_prec1 = checkpoint['best_prec1']
       model.load_state_dict(checkpoint['state_dict'])
       print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
             .format(args.model, checkpoint['epoch'], best_prec1))
   else:
       print("=> no checkpoint found at '{}'".format(args.resume))

print(model)

预剪枝

首先确定剪枝的全局阈值,然后根据阈值得到剪枝后的网络每层的通道数 cfg_mask ,这个 cfg_mask 就可以确定我们剪枝后的模型的结构了,注意这个过程只是确定每一层那一些索引的通道要被剪枝掉并获得 cfg_mask ,还没有真正的执行剪枝操作。我给代码加了部分注释,应该不难懂。

# 计算需要剪枝的变量个数total
total = 0
for m in model.modules():
   if isinstance(m, nn.BatchNorm2d):
       total += m.weight.data.shape[0]

# 确定剪枝的全局阈值
bn = torch.zeros(total)
index = 0
for m in model.modules():
   if isinstance(m, nn.BatchNorm2d):
       size = m.weight.data.shape[0]
       bn[index:(index+size)] = m.weight.data.abs().clone()
       index += size
# 按照权值大小排序
y, i = torch.sort(bn)
thre_index = int(total * args.percent)
# 确定要剪枝的阈值
thre = y[thre_index]
#********************************预剪枝*********************************#
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
   if isinstance(m, nn.BatchNorm2d):
       weight_copy = m.weight.data.abs().clone()
       # 要保留的通道标记Mask图
       mask = weight_copy.gt(thre).float().cuda()
       # 剪枝掉的通道数个数
       pruned = pruned + mask.shape[0] - torch.sum(mask)
       m.weight.data.mul_(mask)
       m.bias.data.mul_(mask)
       cfg.append(int(torch.sum(mask)))
       cfg_mask.append(mask.clone())
       print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
           format(k, mask.shape[0], int(torch.sum(mask))))
   elif isinstance(m, nn.MaxPool2d):
       cfg.append('M')

pruned_ratio = pruned/total

print('Pre-processing Successful!')

对预剪枝后的模型进行测试

没什么好说的,看一下我的代码注释好啦。

# simple test model after Pre-processing prune (simple set BN scales to zeros)
#********************************预剪枝后model测试*********************************#
def test(model):
   kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
   # 加载测试数据
   if args.dataset == 'cifar10':
       test_loader = torch.utils.data.DataLoader(
           datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
               transforms.ToTensor(),
               # 对R, G,B通道应该减的均值
               transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
           batch_size=args.test_batch_size, shuffle=True, **kwargs)
   elif args.dataset == 'cifar100':
       test_loader = torch.utils.data.DataLoader(
           datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
               transforms.ToTensor(),






请到「今天看啥」查看全文