點擊上方「AI公園」,關注公衆號,選擇加「星標「或「置頂」git
做者:Sergey Kolchenkoweb
編譯:ronghuaiyang
微信
在不一樣的任務上對比了UNet和UNet++以及使用不一樣的預訓練編碼器的效果。網絡
介紹
語義分割是計算機視覺的一個問題,咱們的任務是使用圖像做爲輸入,爲圖像中的每一個像素分配一個類。在語義分割的狀況下,咱們不關心是否有同一個類的多個實例(對象),咱們只是用它們的類別來標記它們。有多種關於不一樣計算機視覺問題的介紹課程,但用一張圖片能夠總結不一樣的計算機視覺問題:架構
語義分割在生物醫學圖像分析中有着普遍的應用:x射線、MRI掃描、數字病理、顯微鏡、內窺鏡等。https://grand-challenge.org/challenges上有許多不一樣的有趣和重要的問題有待探索。
app
從技術角度來看,若是咱們考慮語義分割問題,對於N×M×3(假設咱們有一個RGB圖像)的圖像,咱們但願生成對應的映射N×M×k(其中k是類的數量)。有不少架構能夠解決這個問題,但在這裏我想談談兩個特定的架構,Unet和Unet++。dom
有許多關於Unet的評論,它如何永遠地改變了這個領域。它是一個統一的很是清晰的架構,由一個編碼器和一個解碼器組成,前者生成圖像的表示,後者使用該表示來構建分割。每一個空間分辨率的兩個映射鏈接在一塊兒(灰色箭頭),所以能夠將圖像的兩種不一樣表示組合在一塊兒。而且它成功了!編輯器
接下來是使用一個訓練好的編碼器。考慮圖像分類的問題,咱們試圖創建一個圖像的特徵表示,這樣不一樣的類在該特徵空間能夠被分開。咱們能夠(幾乎)使用任何CNN,並將其做爲一個編碼器,從編碼器中獲取特徵,並將其提供給咱們的解碼器。據我所知,Iglovikov & Shvets 使用了VGG11和resnet34分別爲Unet解碼器以生成更好的特徵和提升其性能。
函數
Unet++是最近對Unet體系結構的改進,它有多個跳躍鏈接。性能
根據論文, Unet++的表現彷佛優於原來的Unet。就像在Unet中同樣,這裏可使用多個編碼器(骨幹)來爲輸入圖像生成強特徵。
我應該使用哪一個編碼器?
這裏我想重點介紹Unet和Unet++,並比較它們使用不一樣的預訓練編碼器的性能。爲此,我選擇使用胸部x光數據集來分割肺部。這是一個二值分割,因此咱們應該給每一個像素分配一個類爲「1」的機率,而後咱們能夠二值化來製做一個掩碼。首先,讓咱們看看數據。
這些是很是大的圖像,一般是2000×2000像素,有很大的mask,從視覺上看,找到肺不是問題。使用segmentation_models_pytorch庫,咱們爲Unet和Unet++使用100+個不一樣的預訓練編碼器。咱們作了一個快速的pipeline來訓練模型,使用Catalyst (pytorch的另外一個庫,這能夠幫助你訓練模型,而沒必要編寫不少無聊的代碼)和Albumentations(幫助你應用不一樣的圖像轉換)。
-
定義數據集和加強。咱們將調整圖像大小爲256×256,並對訓練數據集應用一些大的加強。
import albumentations as A
from torch.utils.data import Dataset, DataLoader
from collections import OrderedDict
class ChestXRayDataset(Dataset):
def __init__(
self,
images,
masks,
transforms):
self.images = images
self.masks = masks
self.transforms = transforms
def __len__(self):
return(len(self.images))
def __getitem__(self, idx):
"""Will load the mask, get random coordinates around/with the mask,
load the image by coordinates
"""
sample_image = imread(self.images[idx])
if len(sample_image.shape) == 3:
sample_image = sample_image[..., 0]
sample_image = np.expand_dims(sample_image, 2) / 255
sample_mask = imread(self.masks[idx]) / 255
if len(sample_mask.shape) == 3:
sample_mask = sample_mask[..., 0]
augmented = self.transforms(image=sample_image, mask=sample_mask)
sample_image = augmented['image']
sample_mask = augmented['mask']
sample_image = sample_image.transpose(2, 0, 1) # channels first
sample_mask = np.expand_dims(sample_mask, 0)
data = {'features': torch.from_numpy(sample_image.copy()).float(),
'mask': torch.from_numpy(sample_mask.copy()).float()}
return(data)
def get_valid_transforms(crop_size=256):
return A.Compose(
[
A.Resize(crop_size, crop_size),
],
p=1.0)
def light_training_transforms(crop_size=256):
return A.Compose([
A.RandomResizedCrop(height=crop_size, width=crop_size),
A.OneOf(
[
A.Transpose(),
A.VerticalFlip(),
A.HorizontalFlip(),
A.RandomRotate90(),
A.NoOp()
], p=1.0),
])
def medium_training_transforms(crop_size=256):
return A.Compose([
A.RandomResizedCrop(height=crop_size, width=crop_size),
A.OneOf(
[
A.Transpose(),
A.VerticalFlip(),
A.HorizontalFlip(),
A.RandomRotate90(),
A.NoOp()
], p=1.0),
A.OneOf(
[
A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
A.NoOp()
], p=1.0),
])
def heavy_training_transforms(crop_size=256):
return A.Compose([
A.RandomResizedCrop(height=crop_size, width=crop_size),
A.OneOf(
[
A.Transpose(),
A.VerticalFlip(),
A.HorizontalFlip(),
A.RandomRotate90(),
A.NoOp()
], p=1.0),
A.ShiftScaleRotate(p=0.75),
A.OneOf(
[
A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
A.NoOp()
], p=1.0),
])
def get_training_trasnforms(transforms_type):
if transforms_type == 'light':
return(light_training_transforms())
elif transforms_type == 'medium':
return(medium_training_transforms())
elif transforms_type == 'heavy':
return(heavy_training_transforms())
else:
raise NotImplementedError("Not implemented transformation configuration")
-
定義模型和損失函數。這裏咱們使用帶有regnety_004編碼器的Unet++,並使用RAdam + Lookahed優化器使用DICE + BCE損失之和進行訓練。
import torch
import segmentation_models_pytorch as smp
import numpy as np
import matplotlib.pyplot as plt
from catalyst import dl, metrics, core, contrib, utils
import torch.nn as nn
from skimage.io import imread
import os
from sklearn.model_selection import train_test_split
from catalyst.dl import CriterionCallback, MetricAggregationCallback
encoder = 'timm-regnety_004'
model = smp.UnetPlusPlus(encoder, classes=1, in_channels=1)
#model.cuda()
learning_rate = 5e-3
encoder_learning_rate = 5e-3 / 10
layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}
model_params = utils.process_model_params(model, layerwise_params=layerwise_params)
base_optimizer = contrib.nn.RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
optimizer = contrib.nn.Lookahead(base_optimizer)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=10)
criterion = {
"dice": DiceLoss(mode='binary'),
"bce": nn.BCEWithLogitsLoss()
}
-
定義回調函數並訓練!
callbacks = [
# Each criterion is calculated separately.
CriterionCallback(
input_key="mask",
prefix="loss_dice",
criterion_key="dice"
),
CriterionCallback(
input_key="mask",
prefix="loss_bce",
criterion_key="bce"
),
# And only then we aggregate everything into one loss.
MetricAggregationCallback(
prefix="loss",
mode="weighted_sum",
metrics={
"loss_dice": 1.0,
"loss_bce": 0.8
},
),
# metrics
IoUMetricsCallback(
mode='binary',
input_key='mask',
)
]
runner = dl.SupervisedRunner(input_key="features", input_target_key="mask")
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
loaders=loaders,
callbacks=callbacks,
logdir='../logs/xray_test_log',
num_epochs=100,
main_metric="loss",
minimize_metric=True,
verbose=True,
)
若是咱們用不一樣的編碼器對Unet和Unet++進行驗證,咱們能夠看到每一個訓練模型的驗證質量,並總結以下:
咱們注意到的第一件事是,在全部編碼器中,Unet++的性能彷佛都比Unet好。固然,有時這種差別並非很大,咱們不能說它們在統計上是否徹底不一樣 —— 咱們須要在多個folds上訓練,看看分數分佈,單點不能證實任何事情。第二,resnest200e顯示了最高的質量,同時仍然有合理的參數數量。有趣的是,若是咱們看看https://paperswithcode.com/task/semantic-segmentation,咱們會發現resnest200在一些基準測試中也是SOTA。
好的,可是讓咱們用Unet++和Unet使用resnest200e編碼器來比較不一樣的預測。
在某些個別狀況下,Unet++實際上比Unet更糟糕。但總的來講彷佛更好一些。
通常來講,對於分割網絡來講,這個數據集看起來是一個容易的任務。讓咱們在一個更難的任務上測試Unet++。爲此,我使用PanNuke數據集,這是一個帶標註的組織學數據集(205,343個標記核,19種不一樣的組織類型,5個核類)。數據已經被分割成3個folds。
咱們可使用相似的代碼在這個數據集上訓練Unet++模型,以下所示:
咱們在這裏看到了相同的模式 - resnest200e編碼器彷佛比其餘的性能更好。咱們能夠用兩個不一樣的模型(最好的是resnest200e編碼器,最差的是regnety_002)來可視化一些例子。
咱們能夠確定地說,這個數據集是一項更難的任務 —— 不只mask不夠精確,並且個別的核被分配到錯誤的類別。然而,使用resnest200e編碼器的Unet++仍然表現很好。
總結
這不是一個全面語義分割的指導,這更多的是一個想法,使用什麼來得到一個堅實的基線。有不少模型、FPN,DeepLabV3, Linknet與Unet有很大的不一樣,有許多Unet-like架構,例如,使用雙編碼器的Unet,MAnet,PraNet,U²-net — 有不少的型號供你選擇,其中一些可能在你的任務上表現的比較好,可是,一個堅實的基線能夠幫助你從正確的方向上開始。
![](http://static.javashuo.com/static/loading.gif)
英文原文:https://towardsdatascience.com/the-best-approach-to-semantic-segmentation-of-biomedical-images-bbe4fd78733f
請長按或掃描二維碼關注本公衆號
喜歡的話,請給我個在看吧!
本文分享自微信公衆號 - AI公園(AI_Paradise)。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。