Learning Efficient Convolutional Networks through Network Slimming
Introduction
This is the first model-compression paper I've read, and probably one of the better-known ones. I've been interested in model compression for a long time, so I set aside some time to read it and implemented the code myself; it turned out to be fairly easy. For the model-compression problem, the paper proposes a pruning method that targets BN layers: the authors use the BN layers' scaling weights to score the input channels, apply a threshold to filter out low-scoring channels so that the neurons of those channels no longer take part in the connections, and then prune layer by layer to achieve compression.
Personally, I think the attention mechanisms that are so widely used nowadays could also be used to score channels; there is something worth exploring there, though any such scoring would surely be task-specific. I plan to run my own experiments later on using attention for pruning; a rough sketch of the idea follows.
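To make that idea slightly more concrete, here is a rough sketch of what attention-based channel scoring might look like: a squeeze-and-excitation-style block whose per-channel attention weights, averaged over data, would play the role of the BN scaling factors used below. This is only a sketch of the idea; `SEBlock` and its `score` attribute are hypothetical names of mine, not anything from the paper.

```python
import torch
import torch.nn as nn

class SEBlock(nn.Module):
    """Hypothetical SE-style attention whose weights double as channel scores."""
    def __init__(self, channels, reduction=4):
        super(SEBlock, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        s = x.mean(dim=(2, 3))           # squeeze: global average pool -> (N, C)
        w = self.fc(s)                   # excite: per-channel weights in (0, 1)
        self.score = w.detach().mean(0)  # batch-averaged scores, usable for pruning
        return x * w.view(x.size(0), -1, 1, 1)
```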
Method
The paper's method, illustrated with a figure in the original, works as follows:
- Given the ratio of channels to keep, gather the weights of all BN layers and determine which ones exceed the threshold implied by that ratio (see the sketch after this list).
- Prune the BN layers first, i.e. drop the parameters whose weights fall below that threshold.
- Prune the convolutional layers (a conv layer is normally followed by a BN layer, so the conv weight size to keep can be read off from the BN layers before and after it): drop the conv weight entries along the input and output channel dimensions so they match the channels kept by the surrounding BN layers.
- Prune the FC layer.
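As a concrete illustration of the first two steps, here is a minimal sketch of my own (`bn_masks` is not a name from the paper's code) showing how the global threshold and the per-layer keep-masks can be computed from the BN scaling factors, assuming a model whose convs are each followed by `BatchNorm2d`:

```python
import torch
import torch.nn as nn

def bn_masks(model, keep_ratio=0.7):
    """Global threshold over all BN scaling factors; returns one 0/1 keep-mask per BN layer."""
    gammas = torch.cat([m.weight.data.abs().view(-1)
                        for m in model.modules() if isinstance(m, nn.BatchNorm2d)])
    n_prune = int(gammas.numel() * (1 - keep_ratio))
    threshold = torch.sort(gammas)[0][n_prune]  # everything below this value is pruned
    return [m.weight.data.abs().gt(threshold).float()
            for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
```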
感受說不太清楚,可是一看代碼就全懂了。。app
Code
Here is my own implementation:
```python
import torch
import torch.nn as nn
from torchsummary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.convnet = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        # With 224x224 input, four 3x3 convs (no padding) leave a 216x216 map.
        self.maxpool = nn.MaxPool2d(216)
        self.fc = nn.Linear(128, 3)

    def forward(self, x):
        x = self.convnet(x)
        x = self.maxpool(x)
        x = x.view(-1, x.size(1))
        return self.fc(x)

if __name__ == "__main__":
    net = Net()
    net_new = Net()

    # idxs[k] holds the channel indices kept after the k-th BN layer;
    # idxs[0] covers the 3 input (RGB) channels, which are never pruned.
    idxs = [torch.arange(3)]
    for module in net.modules():
        if type(module) is nn.BatchNorm2d:
            weight = module.weight.data
            n = int(0.8 * weight.size(0))  # keep 80% of the channels
            # Keep the channels with the largest BN scaling factors
            # (my first version kept the smallest -- a bug).
            _, idx = torch.sort(weight.abs(), descending=True)
            idxs.append(idx[:n])

    i = 1
    for module in net_new.modules():
        if type(module) is nn.Conv2d:
            weight = module.weight.data.clone()
            weight = weight[idxs[i], :, :, :]      # prune output channels
            weight = weight[:, idxs[i - 1], :, :]  # prune input channels
            module.bias.data = module.bias.data[idxs[i]]
            module.weight.data = weight
        elif type(module) is nn.BatchNorm2d:
            module.weight.data = module.weight.data[idxs[i]].clone()
            module.bias.data = module.bias.data[idxs[i]].clone()
            module.running_mean.data = module.running_mean.data[idxs[i]].clone()
            module.running_var.data = module.running_var.data[idxs[i]].clone()
            i += 1
        elif type(module) is nn.Linear:
            # The FC layer only needs its input features pruned.
            module.weight.data = module.weight.data[:, idxs[-1]]

    summary(net_new, (3, 224, 224), device="cpu")
```
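A quick sanity check I would append to the end of the `__main__` block above (my own addition, not part of the original snippet): push a random batch through the pruned network and confirm the forward pass still yields one 3-way logit vector per image.

```python
# Appended after summary(...) in the __main__ block above.
x = torch.randn(2, 3, 224, 224)
out = net_new(x)
assert out.shape == (2, 3)  # batch of 2, 3 output classes
print("pruned forward OK:", tuple(out.shape))
```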
```python
'''
Pruning example for VGG; the paper also describes slimming for other
network architectures.
'''
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms

from models import *  # provides the vgg constructor used below

# Prune settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
parser.add_argument('--dataset', type=str, default='cifar100',
                    help='training dataset (default: cifar10)')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                    help='input batch size for testing (default: 256)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--depth', type=int, default=19,
                    help='depth of the vgg')
parser.add_argument('--percent', type=float, default=0.5,
                    help='scale sparse rate (default: 0.5)')
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to the model (default: none)')
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save pruned model (default: none)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if not os.path.exists(args.save):
    os.makedirs(args.save)

# Note: this must be the vgg from models.py, which accepts dataset/depth;
# torchvision's vgg19 (used in an earlier version of this post) does not.
model = vgg(dataset=args.dataset, depth=args.depth)

if args.cuda:
    model.cuda()

if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.model))

print(model)

total = 0
for m in model.modules():  # walk every module of the vgg
    if isinstance(m, nn.BatchNorm2d):
        # total accumulates the channel count over all BN layers
        total += m.weight.data.shape[0]

bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        # clone every BN layer's scaling factors into one flat vector
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size

y, i = torch.sort(bn)                   # sort the scaling factors (ascending)
thre_index = int(total * args.percent)  # number of channels to prune
thre = y[thre_index]                    # global threshold: channels below it are pruned

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.gt(thre).float()  # 1 for weights above the threshold, 0 below
        if args.cuda:
            mask = mask.cuda()
        pruned = pruned + mask.shape[0] - torch.sum(mask)  # running count of pruned channels
        m.weight.data.mul_(mask)  # zero out the pruned scaling factors
        m.bias.data.mul_(mask)    # and the corresponding biases
        cfg.append(int(torch.sum(mask)))  # how many channels this BN layer keeps
        cfg_mask.append(mask.clone())     # which channels this BN layer keeps
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
              format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned/total  # fraction of channels pruned

print('Pre-processing Successful!')

# simple test model after Pre-processing prune (simply set BN scales to zero)
def test(model):
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
    elif args.dataset == 'cifar100':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
    else:
        raise ValueError("No valid dataset is given.")
    model.eval()
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

acc = test(model)

# Make real prune
print(cfg)
newmodel = vgg(dataset=args.dataset, cfg=cfg)
if args.cuda:
    newmodel.cuda()

# nelement() counts the entries of a tensor: a (20,3,3,3) weight has
# 20*27 = 540 entries, so summing over parameters gives the parameter count.
num_parameters = sum([param.nelement() for param in newmodel.parameters()])
savepath = os.path.join(args.save, "prune.txt")
with open(savepath, "w") as fp:
    fp.write("Configuration: \n"+str(cfg)+"\n")
    fp.write("Number of parameters: \n"+str(num_parameters)+"\n")
    fp.write("Test accuracy: \n"+str(acc))

layer_id_in_cfg = 0  # index of the BN layer currently being copied
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
for [m0, m1] in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        # np.argwhere returns the absolute indices of all nonzero entries;
        # squeeze flattens the result into a vector of kept-channel indices.
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:  # a single index must stay a 1-element array
            idx1 = np.resize(idx1, (1,))
        m1.weight.data = m0.weight.data[idx1.tolist()].clone()  # copy only the kept channels
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()
        layer_id_in_cfg += 1  # move on to the next BN layer
        start_mask = end_mask.clone()  # this layer's output mask is the next conv's input mask
        if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        # prune the conv layer; each conv here is followed by a BN layer
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()  # keep the surviving input channels
        w1 = w1[idx1.tolist(), :, :, :].clone()              # then the surviving output channels
        m1.weight.data = w1.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()},
           os.path.join(args.save, 'pruned.pth.tar'))

print(newmodel)
model = newmodel
test(model)
```
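Since the checkpoint stores both the cfg and the state dict, the slimmed model can be rebuilt later without re-running the pruning. A minimal sketch, assuming the same `models.py` `vgg` constructor as in the script (the path and dataset name here are placeholders):

```python
import torch
from models import vgg  # same constructor used by the pruning script

# Rebuild the slimmed architecture from the saved cfg, then restore its weights.
checkpoint = torch.load('pruned.pth.tar')  # placeholder path; in practice under args.save
model = vgg(dataset='cifar100', cfg=checkpoint['cfg'])
model.load_state_dict(checkpoint['state_dict'])
model.eval()
```

The paper then fine-tunes the pruned model to recover the accuracy lost when the channels were removed.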