前言
具体原理已经讲过了,见上回的推文。
深度学习算法优化系列七 | ICCV 2017的一篇模型剪枝论文,也是2019年众多开源剪枝项目的理论基础
。这篇文章是从源码实战的角度来解释模型剪枝,源码来自:https://github.com/Eric-mingjie/network-slimming 。我这里主要是结合源码来分析每个模型的具体剪枝过程,希望能给你剪枝自己的模型一些启发。
稀疏训练
论文的想法是对于每一个通道都引入一个缩放因子
,然后和通道的输出相乘。接着联合训练网络权重和这些缩放因子,最后将小缩放因子的通道直接移除,微调剪枝后的网络,特别地,目标函数被定义为:
其中
代表训练数据和标签,
是网络的可训练参数,第一项是CNN的训练损失函数。
是在缩放因子上的乘法项,
是两项的平衡因子。论文的实验过程中选择
,即
正则化,这也被广泛的应用于稀疏化。次梯度下降法作为不平滑(不可导)的L1惩罚项的优化方法,另一个建议是使用平滑的L1正则项取代L1惩罚项,尽量避免在不平滑的点使用次梯度。
在
main.py
的实现中支持了稀疏训练,其中下面这行代码即添加了稀疏训练的惩罚系数
,注意
是作用在BN层的缩放系数上的:
parser.add_argument('--s', type=float, default=0.0001, help='scale sparse rate (default: 0.0001)')
因此BN层的更新也要相应的加上惩罚项,代码如下:
def updateBN(): for m in model.modules(): if isinstance(m, nn.BatchNorm2d): m.weight.grad.data.add_(args.s*torch.sign(m.weight.data)) # L1
最后训练,测试,保存Basline模型(包含VGG16,Resnet-164,DenseNet40)的代码如下,代码很常规就不过多解释这一节了:
def train(epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): if args.cuda: data, target = data.cuda(), target.cuda() data, target = Variable(data), Variable(target) optimizer.zero_grad() output = model(data) loss = F.cross_entropy(output, target) pred = output.data.max(1, keepdim=True)[1] loss.backward() if args.sr: updateBN() optimizer.step() if batch_idx % args.log_interval == 0: print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.data[0])) def test(): model.eval() test_loss = 0 correct = 0 for data, target in test_loader: if args.cuda: data, target = data.cuda(), target.cuda() data, target = Variable(data, volatile=True), Variable(target) output = model(data) test_loss += F.cross_entropy(output, target, size_average=False).data[0] # sum up batch loss pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability correct += pred.eq(target.data.view_as(pred)).cpu().sum() test_loss /= len(test_loader.dataset) print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format( test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))) return correct / float(len(test_loader.dataset)) def save_checkpoint(state, is_best, filepath): torch.save(state, os.path.join(filepath, 'checkpoint.pth.tar')) if is_best: shutil.copyfile(os.path.join(filepath, 'checkpoint.pth.tar'), os.path.join(filepath, 'model_best.pth.tar')) best_prec1 = 0. for epoch in range(args.start_epoch, args.epochs): if epoch in [args.epochs*0.5, args.epochs*0.75]: for param_group in optimizer.param_groups: param_group['lr'] *= 0.1 train(epoch) prec1 = test() is_best = prec1 > best_prec1 best_prec1 = max(prec1, best_prec1) save_checkpoint({ 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_prec1': best_prec1, 'optimizer': optimizer.state_dict(), }, is_best, filepath=args.save) print("Best accuracy: "+str(best_prec1))
VGG16的剪枝
代码为工程目录下的
vggprune.py
。剪枝的具体步骤如下:
模型加载
加载需要剪枝的模型,也即是稀疏训练得到的BaseLine模型,代码如下,其中
args.depth
用于指定VGG模型的深度,一般为
16
和
19
:
model = vgg(dataset=args.dataset, depth=args.depth) if args.cuda: model.cuda() if args.model: if os.path.isfile(args.model): print("=> loading checkpoint '{}'".format(args.model)) checkpoint = torch.load(args.model) args.start_epoch = checkpoint['epoch'] best_prec1 = checkpoint['best_prec1'] model.load_state_dict(checkpoint['state_dict']) print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}" .format(args.model, checkpoint['epoch'], best_prec1)) else: print("=> no checkpoint found at '{}'".format(args.resume)) print(model)
预剪枝
首先确定剪枝的全局阈值,然后根据阈值得到剪枝后的网络每层的通道数
cfg_mask
,这个
cfg_mask
就可以确定我们剪枝后的模型的结构了,注意这个过程只是确定每一层那一些索引的通道要被剪枝掉并获得
cfg_mask
,还没有真正的执行剪枝操作。我给代码加了部分注释,应该不难懂。
# 计算需要剪枝的变量个数total total = 0 for m in model.modules(): if isinstance(m, nn.BatchNorm2d): total += m.weight.data.shape[0] # 确定剪枝的全局阈值 bn = torch.zeros(total) index = 0 for m in model.modules(): if isinstance(m, nn.BatchNorm2d): size = m.weight.data.shape[0] bn[index:(index+size)] = m.weight.data.abs().clone() index += size # 按照权值大小排序 y, i = torch.sort(bn) thre_index = int(total * args.percent) # 确定要剪枝的阈值 thre = y[thre_index] #********************************预剪枝*********************************# pruned = 0 cfg = [] cfg_mask = [] for k, m in enumerate(model.modules()): if isinstance(m, nn.BatchNorm2d): weight_copy = m.weight.data.abs().clone() # 要保留的通道标记Mask图 mask = weight_copy.gt(thre).float().cuda() # 剪枝掉的通道数个数 pruned = pruned + mask.shape[0] - torch.sum(mask) m.weight.data.mul_(mask) m.bias.data.mul_(mask) cfg.append(int(torch.sum(mask))) cfg_mask.append(mask.clone()) print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'. format(k, mask.shape[0], int(torch.sum(mask)))) elif isinstance(m, nn.MaxPool2d): cfg.append('M') pruned_ratio = pruned/total print('Pre-processing Successful!')
对预剪枝后的模型进行测试
没什么好说的,看一下我的代码注释好啦。
# simple test model after Pre-processing prune (simple set BN scales to zeros) #********************************预剪枝后model测试*********************************# def test(model): kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} # 加载测试数据 if args.dataset == 'cifar10': test_loader = torch.utils.data.DataLoader( datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([ transforms.ToTensor(), # 对R, G,B通道应该减的均值 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])), batch_size=args.test_batch_size, shuffle=True, **kwargs) elif args.dataset == 'cifar100': test_loader = torch.utils.data.DataLoader( datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([ transforms.ToTensor(),