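Today we look at U-net, a classic semantic segmentation network. It was proposed in 2015 and first applied to medical image segmentation. Because it performed well, it was later widely adopted for all kinds of segmentation tasks, and many segmentation models derived from U-net have appeared since.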
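U-net is a typical encoder-decoder architecture: the encoder extracts features and the decoder upsamples them. Because training data was limited, U-net relied on heavy data augmentation during training and ultimately achieved good results.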
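The specific augmentations are not listed here. As a minimal sketch, assuming simple random flips as a stand-in, the key point for segmentation is that the image and its mask must receive the same transform. The helper name `random_flip` and its parameters are hypothetical and chosen only for illustration.

```python
import torch


def random_flip(image, mask, p=0.5, generator=None):
    """Randomly flip an image and its mask together (hypothetical helper).

    image: (C, H, W) float tensor; mask: (H, W) integer label tensor.
    """
    # Horizontal flip: reverse the width axis of both tensors.
    if torch.rand(1, generator=generator).item() < p:
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    # Vertical flip: reverse the height axis of both tensors.
    if torch.rand(1, generator=generator).item() < p:
        image = torch.flip(image, dims=[-2])
        mask = torch.flip(mask, dims=[-2])
    return image, mask
```

Applying the flip to both tensors inside one function keeps the pixel labels aligned with the pixels they describe.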
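The U-net architecture is as follows. The left side is the encoder, which downsamples the input using max pooling. The right side is the decoder, which upsamples the encoder output to restore resolution, using Upsample. In between are the skip connections, which fuse encoder and decoder features. Because the whole network is shaped like a "U", it is called U-net.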
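To make the three building blocks concrete, here is a small sketch that traces tensor shapes through one pooling step, one upsampling step, and one skip connection. The input size and channel count are arbitrary assumptions for illustration.

```python
import torch
import torch.nn as nn

# Arbitrary example input: (batch, channels, height, width)
x = torch.randn(1, 64, 128, 128)

pool = nn.MaxPool2d(kernel_size=2, stride=2)
up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

down = pool(x)                           # encoder step: halves height and width
restored = up(down)                      # decoder step: doubles height and width
fused = torch.cat([restored, x], dim=1)  # skip connection: stacks channels

print(down.shape)      # torch.Size([1, 64, 64, 64])
print(restored.shape)  # torch.Size([1, 64, 128, 128])
print(fused.shape)     # torch.Size([1, 128, 128, 128])
```

Concatenation doubles the channel count, which is why each decoder convolution takes more input channels than the upsampled tensor alone provides.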
Except for the final output layer, every convolutional layer in the network is a 3×3 convolution.
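As a rough illustration, the sketch below compares a 3×3 convolution with padding 1 (which keeps the spatial size), a 3×3 convolution without padding (which shrinks each side by 2), and a 1×1 convolution of the kind typically used as an output layer. The tensor sizes and the choice of 2 output classes are assumptions.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 96, 96)

conv_same = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)  # padded 3x3
conv_valid = nn.Conv2d(64, 64, kernel_size=3)                      # unpadded 3x3
head = nn.Conv2d(64, 2, kernel_size=1)                             # 1x1, 2 classes assumed

same_out = conv_same(x)
valid_out = conv_valid(x)
head_out = head(x)

print(same_out.shape)   # torch.Size([1, 64, 96, 96])
print(valid_out.shape)  # torch.Size([1, 64, 94, 94])
print(head_out.shape)   # torch.Size([1, 2, 96, 96])
```

With padding 1, encoder and decoder feature maps stay aligned for even input sizes, so the skip connections need little or no cropping or padding.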
import torch as t
import torch.nn as nn


class DoubleConv(nn.Module):
    """Two consecutive 3x3 conv + BatchNorm + ReLU layers."""

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.dconv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            # inplace=True saves GPU/CPU memory
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, img):
        return self.dconv(img)


# Downsampling
class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.down = nn.Sequential(
            nn.MaxPool2d(2, 2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, img):
        return self.down(img)


# Upsampling
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        # ConvTranspose2d has learnable parameters that are adjusted during
        # training; this increases model complexity and may cause overfitting.
        # Upsample has no learnable parameters.
        # The difference is analogous to that between Conv2d and MaxPool2d.
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so that it has the same spatial size as x2
        dx = x2.shape[3] - x1.shape[3]
        dy = x2.shape[2] - x1.shape[2]
        x1 = nn.functional.pad(x1, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        # Concatenate along the channel dimension
        x = t.cat([x1, x2], dim=1)
        return self.conv(x)


# Main network
class CrackUnet(nn.Module):
    def __init__(self, channels, classes, bilinear=True):
        super(CrackUnet, self).__init__()
        self.channels = channels
        self.classes = classes
        self.bilinear = bilinear
        # Input convolution: maps the input channels to 64 feature channels
        self.inconv = DoubleConv(self.channels, 64)
        # 4 downsampling layers
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # 4 upsampling layers, using bilinear upsampling by default
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # 1x1 convolution: maps the 64 feature channels to per-class scores
        self.outconv = nn.Conv2d(64, self.classes, 1)

    def forward(self, img):
        img = self.inconv(img)
        down1 = self.down1(img)
        down2 = self.down2(down1)
        down3 = self.down3(down2)
        down4 = self.down4(down3)
        # Each Up layer fuses the decoder output with the matching encoder feature map;
        # tensors are deleted once they are no longer needed.
        x = self.up1(down4, down3)
        del down4
        del down3
        x = self.up2(x, down2)
        del down2
        x = self.up3(x, down1)
        del down1
        x = self.up4(x, img)
        del img
        return self.outconv(x)
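The padding step in `Up.forward` matters when the input height or width is odd, because pooling rounds down and upsampling by 2 then lands one pixel short. The sketch below reproduces that arithmetic in isolation with the functional API. The 37×45 feature-map size is an arbitrary assumption.

```python
import torch
import torch.nn.functional as F

# Encoder feature map kept for the skip connection (odd height and width)
x2 = torch.randn(1, 64, 37, 45)

# Pool then upsample, as the encoder/decoder pair would: 37 -> 18 -> 36, 45 -> 22 -> 44
x1 = F.max_pool2d(x2, kernel_size=2, stride=2)
x1 = F.interpolate(x1, scale_factor=2, mode='bilinear', align_corners=True)

# Same padding arithmetic as in Up.forward
dy = x2.shape[2] - x1.shape[2]   # missing rows
dx = x2.shape[3] - x1.shape[3]   # missing columns
# F.pad takes (left, right, top, bottom) for the last two dimensions
x1 = F.pad(x1, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])

fused = torch.cat([x1, x2], dim=1)
print(dy, dx)       # 1 1
print(fused.shape)  # torch.Size([1, 128, 37, 45])
```

Without the padding, `torch.cat` would raise an error because the two tensors differ in height and width.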
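U-net has a simple, stable structure and is a typical downsample-then-upsample segmentation network. It is especially recommended when the dataset is small.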