Lightweight CNN models: MobileNet v1

MobileNet v1

Paper walkthrough

Paper: https://arxiv.org/abs/1704.04861

The core idea is to replace standard convolutions with depthwise separable convolutions.

For background on depthwise convolutions, see http://www.javashuo.com/article/p-ncitmnlb-dx.html

Model structure:

The architecture is a VGG-like stack of layers.

Per-layer computational cost

Note that computational cost is not strictly proportional to parameter count, although as a general trend, models with fewer parameters run faster.
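To see concretely why the depthwise separable factorization is cheaper, the parameter and multiply-accumulate counts of a standard 3x3 convolution and its factorized form can be compared directly. The layer sizes below are illustrative, borrowed from the first conv_dw block of the model:

```python
# Compare a standard 3x3 conv with its depthwise-separable factorization.
# Illustrative sizes: 32 input channels, 64 output channels,
# 3x3 kernel, 112x112 feature map.
inp, oup, k, h, w = 32, 64, 3, 112, 112

# standard conv: one k x k filter per (input channel, output channel) pair
std_params = inp * oup * k * k
std_macs = std_params * h * w

# depthwise (one k x k filter per input channel) + pointwise (1x1, inp -> oup)
dw_params = inp * k * k + inp * oup
dw_macs = dw_params * h * w

print(std_params, dw_params)         # 18432 vs 2336
print(round(std_macs / dw_macs, 1))  # about 7.9x fewer multiply-adds
```

This matches the paper's cost-reduction factor of 1/oup + 1/k^2, which for 64 output channels and a 3x3 kernel is roughly 1/7.9.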

Code implementation

https://github.com/marvis/pytorch-mobilenet

Network structure:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # standard 3x3 conv + BN + ReLU
        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
            )

        # depthwise separable conv:
        # 3x3 depthwise (groups=inp) followed by 1x1 pointwise
        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),

                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        self.model = nn.Sequential(
            conv_bn(  3,  32, 2), 
            conv_dw( 32,  64, 1),
            conv_dw( 64, 128, 2),
            conv_dw(128, 128, 1),
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
            nn.AvgPool2d(7),
        )
        self.fc = nn.Linear(1024, 1000)

    def forward(self, x):
        x = self.model(x)
        x = x.view(-1, 1024)  # flatten the 1x1x1024 pooled features
        x = self.fc(x)
        return x
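A quick sanity check of why the final pooling window is 7: assuming a 224x224 input (the standard ImageNet resolution this model targets), only the stride-2 layers shrink the feature map, since every 3x3 conv uses padding=1. The spatial size can be traced through all 14 conv blocks:

```python
# Trace the spatial size through the network for a 224x224 input.
size = 224
# strides of the 14 conv blocks in self.model, in order
strides = [2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1]
for s in strides:
    size = size // s  # padding=1 with a 3x3 kernel preserves size except for stride
print(size)  # 7 -- so AvgPool2d(7) leaves a 1x1x1024 tensor, matching view(-1, 1024)
```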

Following the structure in the paper, the first layer is a standard convolution, and all subsequent layers are depthwise separable convolutions.

Note the usage of the groups parameter here. When groups equals the number of input channels, each channel is convolved separately. The default is groups=1, which gives a standard convolution.
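A minimal sketch of what groups does to the weight tensor (assumes PyTorch is installed; the channel counts are illustrative):

```python
import torch
import torch.nn as nn

# standard conv: each of the 64 output filters sees all 32 input channels
standard = nn.Conv2d(32, 64, 3, padding=1, bias=False)
print(standard.weight.shape)  # torch.Size([64, 32, 3, 3])

# depthwise conv: groups=32 means each filter sees exactly 1 input channel
depthwise = nn.Conv2d(32, 32, 3, padding=1, groups=32, bias=False)
print(depthwise.weight.shape)  # torch.Size([32, 1, 3, 3])

# both accept the same input; only the channel mixing differs
x = torch.randn(1, 32, 56, 56)
print(standard(x).shape, depthwise(x).shape)
```

The weight shape is (out_channels, in_channels / groups, kH, kW), which is where the parameter savings of the depthwise layer come from.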

Training pseudocode

# create the model
model = Net()

# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()

optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)


# load data (dataset and transforms omitted)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True)

# train
for epoch in range(args.epochs):
    for input, target in train_loader:
        # forward pass to get predictions
        output = model(input)

        # compute the loss
        loss = criterion(output, target)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()