Semantic Segmentation | PSPNet Source Code Walkthrough (Training Stage)

Introduction

A while back I was working on a semantic segmentation project, and now that I have some time, this is a good opportunity to summarize what I learned.

Code-wise, a semantic segmentation framework is considerably simpler than an object detection one, but it still involves plenty of details. In this article I use PSPNet as an example to walk through the code of a segmentation framework. Once you understand one framework thoroughly, the others look much the same.

The project comes from https://github.com/speedinghzl/pytorch-segmentation-toolbox

A very important part of the framework is evaluate.py, the testing stage. Since it is fairly long, I will cover the testing process in a separate article; this one focuses on the training process.

Overall structure

pytorch-segmentation-toolbox
    |— dataset      dataset-related code
        |— list         the dataset file lists
        |— datasets.py  dataset loading
    |— libs         PyTorch ops such as BN
    |— networks     network definitions
        |— deeplabv3.py
        |— pspnet.py
    |— utils        miscellaneous utilities
        |— criterion.py loss computation
        |— encoding.py  balancing GPU memory usage
        |— loss.py      OHEM hard-example mining
        |— utils.py     colormap conversion
    |— evaluate.py  network testing
    |— run_local.sh training script
    |— train.py     network training

train.py

The main training script. Its main steps are:

  1. Parse the training arguments, typically with the argparse library so they can be passed in from a shell script.
  2. Train the network: define the model, load weights, run the forward and backward passes, and save checkpoints.
  3. Visualize training by plotting the loss curve in TensorBoard.

import argparse

import torch
import torch.nn as nn
from torch.utils import data
import numpy as np
import pickle
import cv2
import torch.optim as optim
import scipy.misc
import torch.backends.cudnn as cudnn
import sys
import os
from tqdm import tqdm
import os.path as osp
from networks.pspnet import Res_Deeplab
from dataset.datasets import CSDataSet

import random
import timeit
import logging
from tensorboardX import SummaryWriter
from utils.utils import decode_labels, inv_preprocess, decode_predictions
from utils.criterion import CriterionDSN, CriterionOhemDSN
from utils.encoding import DataParallelModel, DataParallelCriterion

torch_ver = torch.__version__[:3]
if torch_ver == '0.3':
    from torch.autograd import Variable

start = timeit.default_timer()

#Since ImageNet-pretrained weights are used, the ImageNet channel means must be subtracted during preprocessing.
IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)

#These hyperparameters can be overridden in the shell script.
BATCH_SIZE = 8
DATA_DIRECTORY = 'cityscapes'
DATA_LIST_PATH = './dataset/list/cityscapes/train.lst'
IGNORE_LABEL = 255
INPUT_SIZE = '769,769'
LEARNING_RATE = 1e-2
MOMENTUM = 0.9
NUM_CLASSES = 19
NUM_STEPS = 40000
POWER = 0.9
RANDOM_SEED = 1234
RESTORE_FROM = './dataset/MS_DeepLab_resnet_pretrained_init.pth'
SAVE_NUM_IMAGES = 2
SAVE_PRED_EVERY = 10000
SNAPSHOT_DIR = './snapshots/'
WEIGHT_DECAY = 0.0005

def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def get_arguments():
    """Parse all the arguments provided from the CLI.
    
    Returns:
      The parsed arguments (an argparse.Namespace).
    """
    parser = argparse.ArgumentParser(description="DeepLab-ResNet Network")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,       #Batch Size
                        help="Number of images sent to the network in one step.")
    parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY,     #數據集地址
                        help="Path to the directory containing the PASCAL VOC dataset.")
    parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH,    #數據集清單
                        help="Path to the file listing the images in the dataset.")
    parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL,   #忽略類別(未使用)
                        help="The index of the label to ignore during the training.")
    parser.add_argument("--input-size", type=str, default=INPUT_SIZE,       #輸入尺寸
                        help="Comma-separated string with height and width of images.")
    parser.add_argument("--is-training", action="store_true",               #是否訓練   若不傳入爲false
                        help="Whether to updates the running means and variances during the training.")
    parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE,   #學習率
                        help="Base learning rate for training with polynomial decay.")
    parser.add_argument("--momentum", type=float, default=MOMENTUM,         #動量係數,用於優化參數
                        help="Momentum component of the optimiser.")
    parser.add_argument("--not-restore-last", action="store_true",          #是否存儲最後一層(未使用)
                        help="Whether to not restore last (FC) layers.")
    parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,     #類別數
                        help="Number of classes to predict (including background).")
    parser.add_argument("--start-iters", type=int, default=0,               #起始iter數
                        help="Number of classes to predict (including background).")
    parser.add_argument("--num-steps", type=int, default=NUM_STEPS,         #訓練步數   
                        help="Number of training steps.")
    parser.add_argument("--power", type=float, default=POWER,               #power係數,用於更新學習率
                        help="Decay parameter to compute the learning rate.")
    parser.add_argument("--random-mirror", action="store_true",             #數據加強 翻轉
                        help="Whether to randomly mirror the inputs during the training.")
    parser.add_argument("--random-scale", action="store_true",              #數據加強 多尺度
                        help="Whether to randomly scale the inputs during the training.")
    parser.add_argument("--random-seed", type=int, default=RANDOM_SEED,     #隨機種子
                        help="Random seed to have reproducible results.")
    parser.add_argument("--restore-from", type=str, default=RESTORE_FROM,   #模型斷點續跑
                        help="Where restore model parameters from.")
    parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES, #保存多少張圖片(未使用)
                        help="How many images to save.")
    parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY, #每多少次保存一次斷點
                        help="Save summaries and checkpoint every often.")
    parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR,       #模型保存位置
                        help="Where to save snapshots of the model.")
    parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY,     #權重衰減係數,用於正則化
                        help="Regularisation parameter for L2-loss.")
    parser.add_argument("--gpu", type=str, default='None',                      #使用哪些GPU
                        help="choose gpu device.")
    parser.add_argument("--recurrence", type=int, default=1,                #循環次數(未使用)
                        help="choose the number of recurrence.")
    parser.add_argument("--ft", type=bool, default=False,                   #微調模型(未使用)
                        help="fine-tune the model with large input size.")

    parser.add_argument("--ohem", type=str2bool, default='False',           #難例挖掘
                        help="use hard negative mining")
    parser.add_argument("--ohem-thres", type=float, default=0.6,
                        help="choose the samples with correct probability underthe threshold.")
    parser.add_argument("--ohem-keep", type=int, default=200000,
                        help="choose the samples with correct probability underthe threshold.")
    return parser.parse_args()

args = get_arguments()  #parse the command-line arguments

#poly learning rate policy
def lr_poly(base_lr, iter, max_iter, power):
    return base_lr*((1-float(iter)/max_iter)**(power))

#adjust the learning rate
def adjust_learning_rate(optimizer, i_iter):
    """Polynomially decays the learning rate with the iteration number (the poly schedule)."""
    lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power)
    optimizer.param_groups[0]['lr'] = lr
    return lr
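
# A few values the poly schedule produces under the defaults above
# (LEARNING_RATE=1e-2, NUM_STEPS=40000, POWER=0.9); numbers are rounded:
#   lr_poly(1e-2, 0,     40000, 0.9)  -> 1.0e-2  (the base learning rate)
#   lr_poly(1e-2, 20000, 40000, 0.9)  -> ~5.4e-3 (= 1e-2 * 0.5**0.9 at the halfway point)
#   lr_poly(1e-2, 39999, 40000, 0.9)  -> ~7.2e-7 (decays smoothly toward 0)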

#set BN layers to eval mode
def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()

#set the BN momentum
def set_bn_momentum(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1 or classname.find('InPlaceABN') != -1:
        m.momentum = 0.0003

#main training routine
def main():
    """Create the model and start the training."""
    writer = SummaryWriter(args.snapshot_dir)   #a SummaryWriter object for visualizing training in TensorBoard
    
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
    h, w = map(int, args.input_size.split(',')) #769, 769
    input_size = (h, w) #(769, 769)

    cudnn.enabled = True

    # Create network.
    deeplab = Res_Deeplab(num_classes=args.num_classes) #define the network
    print(deeplab)

    saved_state_dict = torch.load(args.restore_from)    #load the checkpoint; saved_state_dict['conv1.weight'] = {Tensor}
    new_params = deeplab.state_dict().copy()    #state dict mapping layer names to parameters; new_params['conv1.weight'] = {Tensor}
    for i in saved_state_dict:  #drop the fully connected layers of the pretrained model
        #Scale.layer5.conv2d_list.3.weight
        i_parts = i.split('.')  #['conv1', 'weight', '2']
        # print i_parts
        # if not i_parts[1]=='layer5':
        if not i_parts[0]=='fc':
            new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
    
    deeplab.load_state_dict(new_params) #load the filtered state dict to finish model loading
    #deeplab.load_state_dict(torch.load(args.restore_from)) #if nothing needs to be dropped

    model = DataParallelModel(deeplab)  #multi-GPU data parallelism
    model.train()   #training mode; evaluate.py uses model.eval() instead
    model.float()
    # model.apply(set_bn_momentum)
    model.cuda()    #puts the model on GPU 0 as the master GPU; devices can also be specified explicitly
    #model = model.cuda(device_ids[0])

    if args.ohem:   #whether to use hard-example mining
        criterion = CriterionOhemDSN(thresh=args.ohem_thres, min_kept=args.ohem_keep)
    else:
        criterion = CriterionDSN() #CriterionCrossEntropy()
    criterion = DataParallelCriterion(criterion)    #balances the load across multiple GPUs
    criterion.cuda()    #the criterion lives on the GPU as well
    
    cudnn.benchmark = True  #gives a small training speedup when input sizes are fixed; commonly enabled

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    #data loading; see datasets.py for this part
    trainloader = data.DataLoader(CSDataSet(args.data_dir, args.data_list, max_iters=args.num_steps*args.batch_size, crop_size=input_size, 
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN), 
                    batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    #optimizer
    optimizer = optim.SGD([{'params': filter(lambda p: p.requires_grad, deeplab.parameters()), 'lr': args.learning_rate }], 
                lr=args.learning_rate, momentum=args.momentum,weight_decay=args.weight_decay)
    optimizer.zero_grad()   #clear any residual gradients

    interp = nn.Upsample(size=input_size, mode='bilinear', align_corners=True)  #(unused)

    for i_iter, batch in enumerate(trainloader):
        i_iter += args.start_iters  
        images, labels, _, _ = batch
        images = images.cuda()
        labels = labels.long().cuda()
        if torch_ver == "0.3":
            images = Variable(images)
            labels = Variable(labels)

        optimizer.zero_grad()   #clear the gradients from the previous step
        lr = adjust_learning_rate(optimizer, i_iter)    #adjust the learning rate
        preds = model(images)   #[x, x_dsn]

        loss = criterion(preds, labels) #compute the loss
        loss.backward()     #backpropagate
        optimizer.step()    #update the parameters

        #use the SummaryWriter defined above to plot the lr and loss curves in TensorBoard
        if i_iter % 100 == 0:
            writer.add_scalar('learning_rate', lr, i_iter)
            writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

        #optionally visualize intermediate results during training
        # if i_iter % 5000 == 0:
        #     images_inv = inv_preprocess(images, args.save_num_images, IMG_MEAN)
        #     labels_colors = decode_labels(labels, args.save_num_images, args.num_classes)
        #     if isinstance(preds, list):
        #         preds = preds[0]
        #     preds_colors = decode_predictions(preds, args.save_num_images, args.num_classes)
        #     for index, (img, lab) in enumerate(zip(images_inv, labels_colors)):
        #         writer.add_image('Images/'+str(index), img, i_iter)
        #         writer.add_image('Labels/'+str(index), lab, i_iter)
        #         writer.add_image('preds/'+str(index), preds_colors[index], i_iter)

        print('iter = {} of {} completed, loss = {}'.format(i_iter, args.num_steps, loss.data.cpu().numpy()))

        if i_iter >= args.num_steps-1:  #保存最終模型
            print('save model ...')
            torch.save(deeplab.state_dict(),osp.join(args.snapshot_dir, 'CS_scenes_'+str(args.num_steps)+'.pth'))
            break

        if i_iter % args.save_pred_every == 0:  #save a checkpoint every save_pred_every steps
            print('taking snapshot ...')
            torch.save(deeplab.state_dict(),osp.join(args.snapshot_dir, 'CS_scenes_'+str(i_iter)+'.pth'))   #save only the learned parameters
            #torch.save(deeplab, PATH)  #save the entire model and its state

    end = timeit.default_timer()
    print(end-start,'seconds')

if __name__ == '__main__':
    main()

datasets.py

In PyTorch, loading data into a model follows these steps:

  1. Create a Dataset object, usually overriding the __len__ and __getitem__ methods. __len__ returns the dataset size; __getitem__ supports indexing, so that dataset[i] returns the i-th sample.
  2. Create a DataLoader object, passing the Dataset in as an argument.
  3. Iterate over the DataLoader, feeding img and label to the model for training.

Here is a simple example:

dataset = MyDataset()            # any Dataset implementing __len__ and __getitem__
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for img, label in dataloader:
        pass  # forward pass, loss, backward pass, optimizer step

We also need to define the data preprocessing inside the Dataset object. Here it consists of:

  1. Random scaling by a factor in [0.7, 2.1] (0.7 plus a random multiple of 0.1, up to 1.4)
  2. Subtracting the ImageNet mean from each channel
  3. Randomly cropping a 769x769 patch
  4. Random mirror flipping

Note: to keep the image and its label aligned, the label must go through the corresponding preprocessing; the details are in the code below.

import os
import os.path as osp
import numpy as np
import random
import collections
import torch
import torchvision
import cv2
from torch.utils import data

#Cityscapes dataset loader
#crop_size (769, 769); max_iters = num_steps * batch_size = 40000 * 8 = 320000
class CSDataSet(data.Dataset):
    def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321), mean=(128, 128, 128), scale=True, mirror=True, ignore_label=255):
        self.root = root    #dataset root
        self.list_path = list_path  #dataset list file
        self.crop_h, self.crop_w = crop_size    #crop size
        self.scale = scale  #random-scale flag
        self.ignore_label = ignore_label    #label id to ignore
        self.mean = mean    #per-channel means of the dataset
        self.is_mirror = mirror #mirror flag
        # self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
        self.img_ids = [i_id.strip().split() for i_id in open(list_path)]   #list holding the path of every image and its label within the dataset
        if not max_iters==None: #at training time the list is tiled to cover max_iters; e.g. max_iters=320000, len(trainset)=2975
            #each list entry is one training image, so work out how many passes over trainset max_iters requires
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))    # 2975 * ceil(320000/2975) = 2975 * 108 = 321300
        self.files = [] #list that will hold the data entries
        # for split in ["train", "trainval", "val"]:
        for item in self.img_ids:   #iterate over every training sample
            image_path, label_path = item   #image and label paths
            name = osp.splitext(osp.basename(label_path))[0]
            img_file = osp.join(self.root, image_path)
            label_file = osp.join(self.root, label_path)
            self.files.append({ #each list item is a dict
                "img": img_file,
                "label": label_file,
                "name": name            #aachen_000000_000019_leftImg8bit.png
            })
        #mapping between the official label ids and the 19 train classes
        self.id_to_trainid = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
                              3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
                              7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
                              14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
                              18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
                              28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}
        print('{} images are loaded!'.format(len(self.img_ids)))

    def __len__(self):  #dataset length
        return len(self.files)  #321300

    #generate a randomly scaled version of the sample and its label
    def generate_scale_label(self, image, label):
        f_scale = 0.7 + random.randint(0, 14) / 10.0    # f_scale in [0.7, 2.1]
        image = cv2.resize(image, None, fx=f_scale, fy=f_scale, interpolation = cv2.INTER_LINEAR)   #bilinear for the image
        label = cv2.resize(label, None, fx=f_scale, fy=f_scale, interpolation = cv2.INTER_NEAREST)  #nearest-neighbor for the label
        return image, label

    #convert between label ids and trainIds (e.g. id 33 corresponds to trainId 18)
    def id2trainId(self, label, reverse=False):
        label_copy = label.copy()
        if reverse: #trainId2id
            for v, k in self.id_to_trainid.items():
                label_copy[label == k] = v
        else:   #id2trainId
            for k, v in self.id_to_trainid.items():
                label_copy[label == k] = v
        return label_copy
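
    # Quick sanity check of the mapping above (hypothetical values):
    #   ds.id2trainId(np.array([[7, 26, 0]]))  ->  [[0, 13, 255]]
    #   (official ids road=7 -> trainId 0, car=26 -> 13, unlabeled=0 -> ignore_label 255)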

    #return one sample
    def __getitem__(self, index):
        datafiles = self.files[index]
        image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR)  #shape(1024,2048,3)
        label = cv2.imread(datafiles["label"], cv2.IMREAD_GRAYSCALE)    #shape(1024,2048)
        label = self.id2trainId(label)  #convert the label image (ids -1..33) to trainIds (0..18, with 255 ignored)
        size = image.shape  #[1024,2048,3]
        name = datafiles["name"]
        if self.scale:  #if multi-scale augmentation is enabled
            image, label = self.generate_scale_label(image, label)
        image = np.asarray(image, np.float32)
        image -= self.mean  #subtract the mean
        img_h, img_w = label.shape  #1024, 2048
        pad_h = max(self.crop_h - img_h, 0) #max(769-1024, 0)
        pad_w = max(self.crop_w - img_w, 0) #max(769-2048, 0)
        if pad_h > 0 or pad_w > 0:  #if the image after scaling is smaller than crop_size, pad the borders
            img_pad = cv2.copyMakeBorder(image, 0, pad_h, 0, 
                pad_w, cv2.BORDER_CONSTANT, 
                value=(0.0, 0.0, 0.0))
            label_pad = cv2.copyMakeBorder(label, 0, pad_h, 0, 
                pad_w, cv2.BORDER_CONSTANT,
                value=(self.ignore_label,))
        else:
            img_pad, label_pad = image, label

        img_h, img_w = label_pad.shape  #1024, 2048
        h_off = random.randint(0, img_h - self.crop_h)  #random offset, e.g. 100
        w_off = random.randint(0, img_w - self.crop_w)  #e.g. 20
        # roi = cv2.Rect(w_off, h_off, self.crop_w, self.crop_h);
        image = np.asarray(img_pad[h_off : h_off+self.crop_h, w_off : w_off+self.crop_w], np.float32)   #crop a random patch ([100:100+769, 20:20+769])
        label = np.asarray(label_pad[h_off : h_off+self.crop_h, w_off : w_off+self.crop_w], np.float32) #([100:100+769, 20:20+769])
        #image = image[:, :, ::-1]  # change to BGR
        image = image.transpose((2, 0, 1))  #shape(3, 769, 769)
        if self.is_mirror:  #random mirror flip
            flip = np.random.choice(2) * 2 - 1  #flip = 1 or -1
            image = image[:, :, ::flip]
            label = label[:, ::flip]

        return image.copy(), label.copy(), np.array(size), name #image.shape(3, 769, 769)、label.shape(769, 769)

The above defines the Dataset object CSDataSet. In train.py we then create the DataLoader object trainloader, passing the CSDataSet instance in as an argument.

trainloader = data.DataLoader(CSDataSet(args.data_dir, args.data_list, max_iters=args.num_steps*args.batch_size, crop_size=input_size, 
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN), 
                    batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

To be clearer about what these parameters mean, it helps to look at the definition of the DataLoader class.

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data
        batch_size (int, optional): how many samples per batch
        shuffle (bool, optional): reshuffle the data at the start of every epoch
        sampler (Sampler, optional): custom strategy for drawing samples from the dataset; if specified, shuffle must be False
        batch_sampler (Sampler, optional): like sampler, but returns the indices of a whole batch at a time; mutually exclusive with batch_size, shuffle, sampler, and drop_last
        num_workers (int, optional): how many subprocesses to use for data loading; 0 means all data is loaded in the main process (default: 0)
        collate_fn (callable, optional): merges a list of samples into a mini-batch
        pin_memory (bool, optional): if True, the data loader copies tensors into CUDA pinned memory before returning them

        drop_last (bool, optional): if True, drop the last incomplete batch; e.g. with batch_size=64 and 100 samples per epoch, the trailing 36 samples are discarded.
        If False (default), the last batch is simply smaller.

        timeout (numeric, optional): if positive, how long to wait for collecting a batch from the workers before giving up; should always be non-negative (default: 0)
        worker_init_fn (callable, optional): if not None, this will be called on each
        worker subprocess with the worker id (an int in [0, num_workers - 1]) as
        input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use ``torch.initial_seed()`` to access the PyTorch seed for each
              worker in :attr:`worker_init_fn`, and use it to set other seeds
              before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers  
        self.collate_fn = collate_fn    
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)  #shuffle the indices
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
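
The shuffle branch above is worth spelling out: a DataLoader with shuffle=True is roughly equivalent to composing the samplers by hand, as in this sketch.

from torch.utils.data.sampler import RandomSampler, BatchSampler

# Roughly what DataLoader(dataset, batch_size=8, shuffle=True) builds internally:
sampler = RandomSampler(dataset)                                      # a random permutation of indices
batch_sampler = BatchSampler(sampler, batch_size=8, drop_last=False)  # groups them 8 at a time
for indices in batch_sampler:                 # e.g. [4821, 17, ...], one list of 8 indices per batch
    batch = [dataset[i] for i in indices]     # the samples that collate_fn then merges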

pspnet.py

To define a custom network in PyTorch, inherit from the nn.Module class and override __init__(self) and forward, which define the network's layers and its forward pass respectively. Here is a simple example.

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
       x = F.relu(self.conv1(x))
       return F.relu(self.conv2(x))
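
Using the module is just calling the instance; nn.Module.__call__ dispatches to forward:

import torch

model = Model()
x = torch.randn(1, 1, 32, 32)   # (N, C, H, W), matching conv1's in_channels=1
out = model(x)                  # invokes forward; shape (1, 20, 24, 24) after two 5x5 convs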

Let's first look at the PSPNet paper. The architecture is very simple: a ResNet backbone followed by a pyramid pooling module (PPM).

(Figure: the PSPNet architecture, a ResNet backbone followed by the pyramid pooling module)

In addition, PSPNet uses an auxiliary loss branch.

(Figure: the auxiliary loss branch)

import torch.nn as nn
from torch.nn import functional as F
import math
import torch.utils.model_zoo as model_zoo
import torch
import numpy as np
from torch.autograd import Variable
affine_par = True
import functools

import sys, os

from libs import InPlaceABN, InPlaceABNSync
BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')

def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

#ResNet Bottleneck block
class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, fist_dilation=1, multi_grid=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=False)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual      
        out = self.relu_inplace(out)

        return out

#the PPM module
class PSPModule(nn.Module):
    """
    Reference: 
        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
    """
    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()

        self.stages = []
        self.stages = nn.ModuleList([self._make_stage(features, out_features, size) for size in sizes])
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features+len(sizes)*out_features, out_features, kernel_size=3, padding=1, dilation=1, bias=False),
            InPlaceABNSync(out_features),
            nn.Dropout2d(0.1)
            )

    def _make_stage(self, features, out_features, size):
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
        bn = InPlaceABNSync(out_features)
        return nn.Sequential(prior, conv, bn)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] + [feats]
        bottle = self.bottleneck(torch.cat(priors, 1))
        return bottle
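
# Shape walkthrough of PSPModule.forward with the settings used below
# (input is the 2048-channel, 97x97 feature map produced by layer4):
#   feats:   (1, 2048, 97, 97)
#   stages:  AdaptiveAvgPool2d -> (1, 2048, s, s) for s in (1, 2, 3, 6),
#            then 1x1 conv -> (1, 512, s, s), then upsample -> (1, 512, 97, 97)
#   concat:  (1, 2048 + 4*512, 97, 97) = (1, 4096, 97, 97)
#   bottleneck (3x3 conv): (1, 512, 97, 97)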

#the full PSPNet network
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes):
        self.inplanes = 128
        super(ResNet, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.relu = nn.ReLU(inplace=False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change; overrides the maxpool defined above
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, multi_grid=(1,1,1))

        
        self.head = nn.Sequential(PSPModule(2048, 512),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True))

        #auxiliary loss branch
        self.dsn = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
            InPlaceABNSync(512),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
            )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion,affine = affine_par))

        layers = []
        generate_multi_grid = lambda index, grids: grids[index%len(grids)] if isinstance(grids, tuple) else 1
        layers.append(block(self.inplanes, planes, stride,dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid)))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid)))

        return nn.Sequential(*layers)

    def forward(self, x):   #(1,3,769,769)
        x = self.relu1(self.bn1(self.conv1(x))) #(1,64,385,385)
        x = self.relu2(self.bn2(self.conv2(x))) #(1,64,385,385)
        x = self.relu3(self.bn3(self.conv3(x))) #(1,128,385,385)
        x = self.maxpool(x) #(1,128,193,193)
        x = self.layer1(x)  #(1,256,97,97)
        x = self.layer2(x)  #(1,512,97,97)
        x = self.layer3(x)  #(1,1024,97,97)
        x_dsn = self.dsn(x) #(1,19,97,97)
        x = self.layer4(x)  #(1,2048,97,97)
        x = self.head(x)    #(1,19,97,97)
        return [x, x_dsn]

def Res_Deeplab(num_classes=21):
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
    return model

PSPNet takes a 1x3x769x769 input (batch size 1, 3 RGB channels, crop size 769) and produces two outputs of shape 1x19x97x97: the main head x and the auxiliary head x_dsn, where 19 is the number of classes. Each position holds per-class scores. (Note that no softmax has been applied yet, so the values do not sum to 1; both outputs are bilinearly upsampled back to 769x769 inside the loss computation.)
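
To turn these scores into per-pixel probabilities or a prediction map, apply a softmax/argmax over the channel dimension. A sketch, applying the single-GPU deeplab module from train.py (DataParallelModel instead returns per-GPU lists):

logits, logits_dsn = deeplab(images.cuda())   # main head and auxiliary head, (1, 19, 97, 97) each
probs = F.softmax(logits, dim=1)              # per-pixel class probabilities; sum to 1 over dim 1
pred = probs.argmax(dim=1)                    # (1, 97, 97) map of predicted trainIds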

criterion.py

The loss for semantic segmentation is mainly pixel-wise cross-entropy. Because an auxiliary loss is used, the total loss consists of two parts:

\(total\_loss=\alpha \cdot loss_1+\beta \cdot loss_2\)

with \(\alpha = 1\) and \(\beta = 0.4\) in this implementation. This file also defines the OHEM version of the loss computation; the actual implementation is in loss.py.

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import numpy as np
from torch.nn import functional as F
from torch.autograd import Variable
from .loss import OhemCrossEntropy2d
import scipy.ndimage as nd

class CriterionDSN(nn.Module):
    '''
    DSN: the model has two supervision signals (the main head and the auxiliary
    head), so two losses need to be computed.
    '''
    def __init__(self, ignore_index=255, use_weight=True, reduce=True):
        super(CriterionDSN, self).__init__()
        self.ignore_index = ignore_index
        #cross-entropy loss; class 255 is ignored and the loss is averaged over pixels
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduce=reduce)
        if not reduce:
            print("disabled the reduce.")

    #criterion(preds, labels)
    def forward(self, preds, target):
        h, w = target.size(1), target.size(2)   #769, 769

        scale_pred = F.upsample(input=preds[0], size=(h, w), mode='bilinear', align_corners=True)
        loss1 = self.criterion(scale_pred, target)

        scale_pred = F.upsample(input=preds[1], size=(h, w), mode='bilinear', align_corners=True)
        loss2 = self.criterion(scale_pred, target)

        return loss1 + loss2*0.4

#with hard-example mining (OHEM)
class CriterionOhemDSN(nn.Module):
    '''
    DSN : We need to consider two supervision for the model.
    '''
    def __init__(self, ignore_index=255, thresh=0.7, min_kept=100000, use_weight=True, reduce=True):
        super(CriterionOhemDSN, self).__init__()
        self.ignore_index = ignore_index
        self.criterion1 = OhemCrossEntropy2d(ignore_index, thresh, min_kept)    #uses the OHEM computation
        self.criterion2 = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduce=reduce)

    def forward(self, preds, target):
        h, w = target.size(1), target.size(2)   #769, 769

        scale_pred = F.upsample(input=preds[0], size=(h, w), mode='bilinear', align_corners=True)
        loss1 = self.criterion1(scale_pred, target)

        scale_pred = F.upsample(input=preds[1], size=(h, w), mode='bilinear', align_corners=True)
        loss2 = self.criterion2(scale_pred, target)

        return loss1 + loss2*0.4

loss.py

The goal of OHEM is to mine hard examples to train the model on and thereby improve performance. It has two hyperparameters: \(\theta\) and \(K\).

Hard examples are defined as pixels whose predicted probability for the ground-truth class is below \(\theta\), and each minibatch is guaranteed at least \(K\) hard examples.


Concretely, the PSPNet output is passed through a softmax and then filtered twice, as sketched below. The first filter keeps only the valid region of the label (pixels not equal to 255); positions labeled 255 never enter the loss. After the first filter, label positions whose predicted probability for the true class exceeds the threshold are also set to 255. Only the remaining pixels take part in the loss computation.
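
In NumPy terms, the two filters amount to the following condensed sketch of generate_new_target below (label is the flattened target and prob the flattened softmax output):

valid = label != 255                             # first filter: drop the ignore pixels
lv = label[valid]                                # ground-truth classes of the valid pixels
p_true = prob[:, valid][lv, np.arange(lv.size)]  # probability assigned to the true class
keep = p_true <= threshold                       # second filter: keep only the hard pixels
new_label = np.full_like(label, 255)             # start from an all-ignore label
new_label[np.where(valid)[0][keep]] = lv[keep]   # restore only the kept hard pixels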

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import scipy.ndimage as nd


class OhemCrossEntropy2d(nn.Module):

    def __init__(self, ignore_label=255, thresh=0.7, min_kept=100000, factor=8):
        super(OhemCrossEntropy2d, self).__init__()
        self.ignore_label = ignore_label    #ignored label: 255
        self.thresh = float(thresh)         #threshold: 0.7
        # self.min_kept_ratio = float(min_kept_ratio)
        self.min_kept = int(min_kept)       #minimum number of pixels to keep
        self.factor = factor
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)

    #find the threshold
    #np_predict.shape(1, 19, 769, 769), np_target.shape(1, 769, 769)
    """
    The threshold is chosen mainly from min_kept: it is the predicted probability of
    the min_kept-th hardest pixel, and the returned threshold is never below thresh.
    """
    def find_threshold(self, np_predict, np_target):
        # downsample 1/8
        factor = self.factor    #8
        predict = nd.zoom(np_predict, (1.0, 1.0, 1.0/factor, 1.0/factor), order=1)  #bilinear interpolation, shape(1, 19, 96, 96)
        target = nd.zoom(np_target, (1.0, 1.0/factor, 1.0/factor), order=0) #nearest-neighbor interpolation, shape(1, 96, 96)

        n, c, h, w = predict.shape  #1, 19, 96, 96
        min_kept = self.min_kept // (factor*factor) #int(self.min_kept_ratio * n * h * w)   #100000/64 = 1562

        input_label = target.ravel().astype(np.int32)   #flatten to 1-D, shape(9216, )
        input_prob = np.rollaxis(predict, 1).reshape((c, -1))   #roll axis 1 to axis 0, shape(19, 9216)

        valid_flag = input_label != self.ignore_label   #valid positions in the label (9216, )
        valid_inds = np.where(valid_flag)[0]    #(9013, )
        label = input_label[valid_flag] #valid labels (9013, )
        num_valid = valid_flag.sum()    #9013
        if min_kept >= num_valid:   #1562 >= 9013
            threshold = 1.0
        elif num_valid > 0: #9013 > 0
            prob = input_prob[:,valid_flag] #(19, 9013) probabilities at the valid positions
            pred = prob[label, np.arange(len(label), dtype=np.int32)]   #probability of the ground-truth class for each valid pixel, shape(9013, )
            threshold = self.thresh     #0.7
            if min_kept > 0:    #1562>0
                k_th = min(len(pred), min_kept)-1   #min(9013, 1562)-1 = 1561
                new_array = np.partition(pred, k_th)    #partial sort: the k_th-smallest value ends up at index k_th, smaller values before it
                new_threshold = new_array[k_th]     #the pred at index 1561, e.g. 0.03323581
                if new_threshold > self.thresh:     #the returned threshold can only be >= 0.7
                    threshold = new_threshold
        return threshold
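
    # Tiny numeric illustration of the threshold rule (hypothetical values):
    #   pred = [0.9, 0.05, 0.3, 0.02, 0.6], min_kept = 3  ->  k_th = 2
    #   np.partition(pred, 2)[2] = 0.3 (the 3rd-smallest probability)
    #   0.3 < thresh = 0.7, so threshold stays 0.7 and 4 pixels (pred <= 0.7) are kept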

    #generate the new labels
    #predict.shape(1, 19, 769, 769), target.shape(1, 769, 769)
    """
    Main idea:
        1. Find a suitable threshold, e.g. 0.7, via find_threshold.
        2. First filter: keep the region whose label is not 255.
        3. Second filter: among those pixels, keep the ones whose predicted probability is below the threshold.
        4. Build a new label in which both the confidently predicted positions (prob > threshold) and the originally-255 positions are set to 255.
    """
    def generate_new_target(self, predict, target):
        np_predict = predict.data.cpu().numpy() #shape(1, 19, 769, 769)
        np_target = target.data.cpu().numpy()   #shape(1, 769, 769)
        n, c, h, w = np_predict.shape   #1, 19, 769, 769

        threshold = self.find_threshold(np_predict, np_target)  #find the threshold, e.g. 0.7

        input_label = np_target.ravel().astype(np.int32)    #shape(591361, )
        input_prob = np.rollaxis(np_predict, 1).reshape((c, -1))    #(19, 591361)

        valid_flag = input_label != self.ignore_label   #valid positions in the label (591361, )
        valid_inds = np.where(valid_flag)[0]    #(579029, )
        label = input_label[valid_flag] #first filter: labels that are not 255 (579029, )
        num_valid = valid_flag.sum()    #579029

        if num_valid > 0:
            prob = input_prob[:,valid_flag] #(19, 579029)
            pred = prob[label, np.arange(len(label), dtype=np.int32)]   #probability of the ground-truth class per valid pixel (579029, )
            kept_flag = pred <= threshold   #second filter: among the valid pixels, keep those with pred <= threshold
            valid_inds = valid_inds[kept_flag]  #indices of the kept (hard) pixels
            print('Labels: {} {}'.format(len(valid_inds), threshold))

        label = input_label[valid_inds].copy()  #labels of the kept pixels, taken from the original label
        input_label.fill(self.ignore_label) #shape(591361, ), every value set to 255
        input_label[valid_inds] = label #restore the labels of the kept pixels; everything else stays 255
        new_target = torch.from_numpy(input_label.reshape(target.size())).long().cuda(target.get_device())  #shape(1, 769, 769)

        return new_target   #shape(1, 769, 769)


    def forward(self, predict, target, weight=None):
        """
            Args:
                predict:(n, c, h, w)    (1, 19, 769, 769)
                target:(n, h, w)        (1, 769, 769)
                weight (Tensor, optional): a manual rescaling weight given to each class.
                                           If given, has to be a Tensor of size "nclasses"
        """
        assert not target.requires_grad

        input_prob = F.softmax(predict, 1)  #softmax over the channel dimension to get probabilities
        target = self.generate_new_target(input_prob, target)   #generate the new labels
        return self.criterion(predict, target)
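
A minimal usage sketch (assuming a CUDA device, since generate_new_target calls target.get_device(); the inputs are stand-ins for the upsampled network output and the labels):

criterion = OhemCrossEntropy2d(ignore_label=255, thresh=0.7, min_kept=100000)
logits = torch.randn(1, 19, 769, 769).cuda().requires_grad_()  # stand-in network output
target = torch.randint(0, 19, (1, 769, 769)).cuda()            # stand-in trainId labels
loss = criterion(logits, target)
loss.backward()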

References

Zhao H, Shi J, Qi X, et al. Pyramid scene parsing network[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2017: 2881-2890.

Yuan Y, Wang J. Ocnet: Object context network for scene parsing[J]. arXiv preprint arXiv:1809.00916, 2018.
