語義分割網絡U-net詳解

U-net介紹

今天來介紹一個經典的語義分割網絡U-net, 它於2015年提出,最初應用在醫療影像分割任務上,因爲效果很好,以後被普遍應用在各類分割任務中。至今已衍生出許多基於U-net的分割模型。
U-net是典型的Encoder-Decoder結構,encoder進行特徵提取,decoder
進行上採樣。因爲數據的限制,U-net在訓練階段使用了大量的數據加強操做,最後獲得了不錯的效果。網絡

U-net網絡結構

U-net的網絡結構以下所示。左邊爲encoder部分,對輸入進行下采樣,下采樣經過最大池化實現;右邊爲decoder部分,對encoder的輸出進行上採樣,恢復分辨率,上採樣經過Upsample實現;中間爲跳躍鏈接(Skip-connect),進行特徵融合。因爲整個網絡形似一個"U",因此稱爲U-net。
網絡中除了最後的輸出層,其他全部卷積層均爲3 * 3卷積。
未命名圖片.pngide

U-net代碼實現

import torch as t
import torch.nn as nn

class  DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.dconv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            
            # inplace設爲True能夠節省顯存/內存
            nn.ReLU(inplace=True),
            
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, img):
        return self.dconv(img)
        
# 下采樣
class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.down = nn.Sequential(
            nn.MaxPool2d(2, 2),
            DoubleConv(in_channels, out_channels)
        )
    def forward(self, img):
        return self.down(img)

# 上採樣
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        # ConvTranspose2D 有可學習的參數, 會在訓練過程當中不斷調整參數。會增長模型的複雜度,可能會形成過擬合
        # Upsample 沒有可學習的參數
        # 和Conv2d和MaxPooling2d的區別同樣
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # pading 保證x1和x2的大小同樣
        dx = x2.shape[3] - x1.shape[3]
        dy = x2.shape[2] - x1.shape[2]
        x1 = nn.functional.pad(x1, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        # 通道合併
        x = t.cat([x1, x2], dim=1)
        return self.conv(x)


# 主網絡
class CrackUnet(nn.Module):
    def __init__(self, channels, classes, bilinear=True):
        super(CrackUnet, self).__init__()
        self.channels = channels
        self.classes = classes
        self.bilinear = bilinear
        # 
        self.inconv = DoubleConv(self.channels, 64)

        # 4個下采樣層
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # 4個上採樣層, 採用雙線性採樣
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outconv = nn.Conv2d(64, channels, 1)

    def forward(self, img):
        img = self.inconv(img)
        down1 = self.down1(img)
        down2 = self.down2(down1)
        down3 = self.down3(down2)
        down4 = self.down4(down3)
        x = self.up1(down4, down3)
        del down4
        del down3
        x = self.up2(x, down2)
        del down2
        x = self.up3(x, down1)
        del down1
        x = self.up5(x, img)
        del img
        return self.outconv(x)

總結

U-net結構簡單穩定,是典型的下采樣+上採樣的分割網絡結構。尤爲在數據集較小的時候,推薦使用。學習

相關文章
相關標籤/搜索