[Paper Notes] Learning Efficient Convolutional Networks through Network Slimming

Introduction

This is the first model-compression paper I have read, and it is probably one of the better-known ones. I have been interested in model compression for a while, so I took some time to read it and implemented the code myself; it turned out to be fairly straightforward. For model compression, the paper proposes a pruning method that targets the BN layers: the authors use each BN layer's scaling weights to score its channels, filter out the low-scoring channels with a threshold so that their neurons no longer participate in the following connections, and prune layer by layer, which yields the compression.

Personally, I think the attention mechanisms that are so common now could also be used to score channels, and there is probably something worth exploring there, although it would certainly be task-specific. I plan to run some experiments of my own later, using attention for model pruning.

Method

The paper's method works as follows:

  1. Given the fraction of channels to keep, pool the scaling weights of all BN layers and record, per layer, which weights exceed the corresponding global threshold (see the note after this list).
  2. Prune the BN layers first, i.e. drop the parameters of channels whose weights fall below that threshold.
  3. Prune the convolution layers. Since a convolution is usually followed by a BN layer, the BN layers before and after a convolution determine the pruned kernel size: drop the kernel's input and output channels that match the channels dropped in the preceding and following BN layers.
  4. Prune the FC layer in the same way.
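
For reference, a BatchNorm layer computes z_out = γ · (z_in − μ) / √(σ² + ε) + β per channel, so the learned scale γ multiplies every activation of its channel and works as a natural channel-importance score. Thresholding on γ is only meaningful if the network was trained with the paper's sparsity regularization: an L1 penalty λ Σ|γ| over all BN scaling factors, which drives the γ of unimportant channels toward zero. Below is a minimal sketch of that training-time step, in the spirit of the official network-slimming code (the helper name updateBN and the default s are my own choices):

import torch
import torch.nn as nn

def updateBN(model, s=1e-4):
    # add the subgradient of s * |gamma| to each BN scale's gradient;
    # call this after loss.backward() and before optimizer.step()
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))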

感受說不太清楚,可是一看代碼就全懂了。。app

Code

Here is my own implementation.

import torch
import torch.nn as nn
from torchsummary import summary


class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.convnet = nn.Sequential(
            nn.Conv2d(3,16,kernel_size = 3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16,32,kernel_size = 3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32,64,kernel_size = 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64,128,kernel_size = 3),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.maxpool = nn.MaxPool2d(216)  # global pooling: a 224x224 input shrinks to 216x216 after four unpadded 3x3 convs
        self.fc = nn.Linear(128,3)

    def forward(self,x):
        x = self.convnet(x)
        x = self.maxpool(x)
        x = x.view(-1,x.size(1))
        return self.fc(x)

if __name__ == "__main__":
    net = Net()
    net_new = Net()
    idxs = []
    idxs.append(torch.arange(3))  # the input image keeps all 3 channels
    for module in net.modules():
        if type(module) is nn.BatchNorm2d:
            weight = module.weight.data.abs()  # score channels by |gamma|
            n = int(0.8 * weight.size(0))  # keep 80% of the channels per layer
            _, idx = torch.sort(weight, descending=True)
            idxs.append(idx[:n].sort()[0])  # keep the highest-scoring channels, in channel order
    i = 1  # idxs[0] is the input mask; each conv/BN stage uses idxs[i]
    for module in net_new.modules():
        if type(module) is nn.Conv2d:
            weight = module.weight.data.clone()
            weight = weight[idxs[i],:,:,:]    # prune output channels kept by the following BN
            weight = weight[:,idxs[i-1],:,:]  # prune input channels kept by the preceding stage
            module.bias.data = module.bias.data[idxs[i]]
            module.weight.data = weight
        elif type(module) is nn.BatchNorm2d:
            weight = module.weight.data.clone()
            bias = module.bias.data.clone()
            running_mean = module.running_mean.data.clone()
            running_var = module.running_var.data.clone()
            
            weight = weight[idxs[i]]
            bias = bias[idxs[i]]
            running_mean = running_mean[idxs[i]]
            running_var = running_var[idxs[i]]

            module.weight.data = weight
            module.bias.data = bias
            module.running_var.data = running_var
            module.running_mean.data = running_mean
            i += 1
        elif type(module) is nn.Linear:
            module.weight.data = module.weight.data[:,idxs[-1]]  # FC inputs follow the last BN's kept channels
            
    summary(net_new,(3,224,224),device = "cpu")
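    # quick sanity check (my own addition): the slimmed network should still run end to end.
    # Note this demo only prunes the shapes of a freshly initialized copy; for a real model
    # you would copy the surviving weights from a trained, sparsity-regularized net and fine-tune.
    x = torch.randn(2, 3, 224, 224)
    print(net_new(x).shape)  # expected: torch.Size([2, 3])
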
'''
The above prunes my toy network; below is a pruning example for VGG. The paper also discusses slimming other architectures.
'''
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from models import *  # provides the CIFAR-style vgg constructor used below


# Prune settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
parser.add_argument('--dataset', type=str, default='cifar100',
                    help='training dataset (default: cifar100)')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                    help='input batch size for testing (default: 256)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--depth', type=int, default=19,
                    help='depth of the vgg')
parser.add_argument('--percent', type=float, default=0.5,
                    help='scale sparse rate (default: 0.5)')
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to the model (default: none)')
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save pruned model (default: none)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if not os.path.exists(args.save):
    os.makedirs(args.save)

model = vgg(dataset=args.dataset, depth=args.depth)  # CIFAR-style vgg from models, not torchvision's vgg19
if args.cuda:
    model.cuda()

if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

print(model)
total = 0
for m in model.modules():  # walk every module of the VGG
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]  # total = number of BN channels across the whole network

bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size  # copy every BN layer's |gamma| into one flat vector

y, i = torch.sort(bn)  # sort all scaling factors ascending
thre_index = int(total * args.percent)  # number of channels to prune
thre = y[thre_index]  # global threshold: channels with |gamma| below this get pruned
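# e.g. with percent=0.5, thre is the median |gamma| of the whole network, so roughly
# half of all BN channels (those with the smaller scales) get masked out below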

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.gt(thre).float().to(weight_copy.device)  # 1 where |gamma| > thre, 0 elsewhere
        pruned = pruned + mask.shape[0] - torch.sum(mask)  # running count of pruned channels
        m.weight.data.mul_(mask)  # zero out the pruned scales...
        m.bias.data.mul_(mask)    # ...and their biases
        cfg.append(int(torch.sum(mask)))  # channels kept in this BN layer
        cfg_mask.append(mask.clone())     # which channels this BN layer keeps
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')  # record maxpool positions so the new cfg rebuilds the same topology
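# cfg now mirrors the VGG layout, e.g. [34, 60, 'M', 118, ...]: the kept-channel count per
# conv/BN stage, with 'M' marking maxpools (the exact numbers depend on the trained weights)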

pruned_ratio = pruned/total  # fraction of all BN channels that were pruned

print('Pre-processing Successful!')

# quick test after the preprocessing prune (so far BN scales are only zeroed, channels not yet removed)
def test(model):
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
    elif args.dataset == 'cifar100':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
    else:
        raise ValueError("No valid dataset is given.")
    model.eval()
    correct = 0
    with torch.no_grad():  # replaces the deprecated volatile=True Variables
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

acc = test(model)

# Make real prune
print(cfg)
newmodel = vgg(dataset=args.dataset, cfg=cfg)
if args.cuda:
    newmodel.cuda()
# param.nelement() counts a tensor's elements, e.g. a (20, 3, 3, 3) weight has 20*27 = 540,
# so summing over all parameters gives the pruned model's total parameter count
num_parameters = sum(param.nelement() for param in newmodel.parameters())
savepath = os.path.join(args.save, "prune.txt")
with open(savepath, "w") as fp:
    fp.write("Configuration: \n"+str(cfg)+"\n")
    fp.write("Number of parameters: \n"+str(num_parameters)+"\n")
    fp.write("Test accuracy: \n"+str(acc))

layer_id_in_cfg = 0  # index of the BN stage currently being copied
start_mask = torch.ones(3)  # input mask: the image keeps all 3 channels
end_mask = cfg_mask[layer_id_in_cfg]  # output mask of the current stage
for m0, m1 in zip(model.modules(), newmodel.modules()):  # m0: masked old model, m1: slim new model
    if isinstance(m0, nn.BatchNorm2d):
        # np.argwhere returns the indices of all nonzero (kept) entries of the mask
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))  # kept-channel indices as a vector
        if idx1.size == 1:  # squeezing a single match yields a 0-d array; make it 1-d
            idx1 = np.resize(idx1,(1,))
        m1.weight.data = m0.weight.data[idx1.tolist()].clone()  # copy only the surviving channels
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()
        layer_id_in_cfg += 1  # move on to the next stage
        start_mask = end_mask.clone()  # this stage's output mask is the next stage's input mask
        if layer_id_in_cfg < len(cfg_mask):  # do not change for the final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):  # prune the convolution (each conv here is followed by a BN)
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()  # keep the surviving input channels
        w1 = w1[idx1.tolist(), :, :, :].clone()  # then the surviving output channels: the final kernel
        m1.weight.data = w1.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0.tolist()].clone()  # FC inputs follow the last kept channels
        m1.bias.data = m0.bias.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth.tar'))

print(newmodel)
model = newmodel
test(model)
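
The paper's full pipeline is sparsity training, then pruning, then fine-tuning; the accuracy printed above typically drops until the slimmed model has been fine-tuned. A minimal fine-tuning sketch, assuming a train_loader built like the test loaders above but with train=True (the loader and the optimizer hyperparameters are my own placeholders, not part of this script):

optimizer = torch.optim.SGD(newmodel.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
newmodel.train()
for data, target in train_loader:  # train_loader is assumed, not defined in this script
    if args.cuda:
        data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()
    loss = criterion(newmodel(data), target)
    loss.backward()
    optimizer.step()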