論文地址:https://arxiv.org/abs/1704.04861html
核心思想就是經過depthwise conv替代普通conv.
有關depthwise conv能夠參考http://www.javashuo.com/article/p-ncitmnlb-dx.htmlgit
模型結構:
相似於vgg這種堆疊的結構.github
每一層的運算量
能夠看到,運算量並非與參數數量絕對成正比,固然總體趨勢而言,參數量更少的模型會運算更快.網絡
https://github.com/marvis/pytorch-mobilenetide
網絡結構:3d
class Net(nn.Module): def __init__(self): super(Net, self).__init__() def conv_bn(inp, oup, stride): return nn.Sequential( nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), nn.ReLU(inplace=True) ) def conv_dw(inp, oup, stride): return nn.Sequential( nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), nn.BatchNorm2d(inp), nn.ReLU(inplace=True), nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.ReLU(inplace=True), ) self.model = nn.Sequential( conv_bn( 3, 32, 2), conv_dw( 32, 64, 1), conv_dw( 64, 128, 2), conv_dw(128, 128, 1), conv_dw(128, 256, 2), conv_dw(256, 256, 1), conv_dw(256, 512, 2), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 1024, 2), conv_dw(1024, 1024, 1), nn.AvgPool2d(7), ) self.fc = nn.Linear(1024, 1000) def forward(self, x): x = self.model(x) x = x.view(-1, 1024) x = self.fc(x) return x
參考論文中的結構,第一層是普通的卷積層,後面接的都是可分離卷積.code
這裏注意groups參數的用法. 當groups=輸入channel數目時,即對每一個channel分別作卷積.默認groups=1,此時即爲普通卷積.
orm
訓練僞代碼htm
# create model model = Net() # define loss function (criterion) and optimizer criterion = nn.CrossEntropyLoss().cuda() optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) # load data train_loader = torch.utils.data.DataLoader() # train for every epoch: input,target=get_from_data #前向傳播獲得預測值 output = model(input_var) #計算loss loss = criterion(output, target_var) #反向傳播更新網絡參數 optimizer.zero_grad() loss.backward() optimizer.step()