DataWhale Street View Character Recognition Project - Model Building

Import the required libraries

import os
from glob import glob
import torch as t
# fix the random seeds so results are reproducible
t.random.manual_seed(0)
t.cuda.manual_seed_all(0)
# benchmark mode speeds up computation, but the autotuned kernels introduce
# randomness, so each forward pass can differ slightly
t.backends.cudnn.benchmark = True
# counteract the nondeterminism introduced by the line above
t.backends.cudnn.deterministic = True

from PIL import Image
import torch.nn as nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, MultiStepLR, CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patch 
import torch.nn.functional as F
import json
from torchvision.models.mobilenet import mobilenet_v2
from torchvision.models.resnet import resnet18, resnet34
from torchsummary import summary
%matplotlib inline

Set the network configuration parameters

class Config:
    
    batch_size = 16
    
    # initial learning rate
    lr = 1e-2
    
    # momentum
    momentum = 0.9
    
    # weight decay coefficient
    weights_decay = 1e-5
    
    class_num = 11
    
    # evaluate the network every eval_interval epochs
    eval_interval = 1
    
    # save a checkpoint every checkpoint_interval epochs
    checkpoint_interval = 1
    
    # update the progress bar / print a log line every print_interval iterations
    print_interval = 50
    
    # where to save checkpoints
    checkpoints = 'drive/My Drive/Data/Datawhale-DigitsRecognition/checkpoints/'
    
    # path of the pretrained model to load (None to train from scratch)
    pretrained = '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/checkpoints/epoch-32_acc-0.67.pth'
    
    # epoch to start training from
    start_epoch = 0
    
    # total number of training epochs
    epoches = 50
    
    # label smoothing factor; 0 disables label smoothing
    smooth = 0.1
    
    # probability of random erasing; 0 disables erasing
    erase_prob = 0.5
    
config = Config()

Build the network model

Generally speaking, when building a baseline you want a lightweight backbone with as few parameters and as little model complexity as possible. Only once it is shown to work do you swap in a more complex backbone.

Here MobileNet V2 is chosen as the backbone of the classification network.

class DigitsMobilenet(nn.Module):

    def __init__(self, class_num=11):
        super(DigitsMobilenet, self).__init__()
    
        self.net = mobilenet_v2(pretrained=True)
        # replace the classification head with a pass-through layer;
        # mobilenet_v2's forward already pools the features down to shape [N, 1280]
        self.net.classifier = nn.Identity()
        self.fc1 = nn.Linear(1280, class_num)
        self.fc2 = nn.Linear(1280, class_num)
        self.fc3 = nn.Linear(1280, class_num)
        self.fc4 = nn.Linear(1280, class_num)
        self.fc5 = nn.Linear(1280, class_num)
    
    def forward(self, img):
        """
        Params:
            img(tensor): shape [N, C, H, W]
            
        Returns:
            fc1(tensor): logits for the 1st character
            fc2(tensor): logits for the 2nd character
            fc3(tensor): logits for the 3rd character
            fc4(tensor): logits for the 4th character
            fc5(tensor): logits for the 5th character
        
        """
        features = self.net(img).view(-1, 1280)
    
        fc1 = self.fc1(features)
        fc2 = self.fc2(features)
        fc3 = self.fc3(features)
        fc4 = self.fc4(features)
        fc5 = self.fc5(features)
    
        return fc1, fc2, fc3, fc4, fc5
    
    
class DigitsResnet18(nn.Module):

    def __init__(self, class_num=11):
        super(DigitsResnet18, self).__init__()
        self.net = resnet18(pretrained=True)
        
        # nn.Identity is a pass-through layer: output equals input
        self.net.fc = nn.Identity()

        self.fc1 = nn.Linear(512, class_num)
        self.fc2 = nn.Linear(512, class_num)
        self.fc3 = nn.Linear(512, class_num)
        self.fc4 = nn.Linear(512, class_num)
        self.fc5 = nn.Linear(512, class_num)
    
    def forward(self, img):
        features = self.net(img).view(-1, 512)  # view instead of squeeze keeps the batch dim when N == 1
    
        fc1 = self.fc1(features)
        fc2 = self.fc2(features)
        fc3 = self.fc3(features)
        fc4 = self.fc4(features)
        fc5 = self.fc5(features)
    
        return fc1, fc2, fc3, fc4, fc5
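
As a quick sanity check, each of the five heads should produce logits of shape [N, class_num]. This is just a sketch: both classes download ImageNet-pretrained weights on first use, and the input size below is illustrative.

# sanity check: push a dummy batch through the MobileNet variant
model = DigitsMobilenet(class_num=11)
dummy = t.randn(2, 3, 64, 128)            # hypothetical input: 2 RGB images of 64x128
outputs = model(dummy)
print([tuple(o.shape) for o in outputs])  # -> [(2, 11), (2, 11), (2, 11), (2, 11), (2, 11)]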

Build the training module

A few tricks are used here:

  • Label smoothing
    Label smoothing is a regularization technique that mitigates overfitting caused by small datasets.
    The formula is given below, where ε is the smoothing factor (set to 0.1 in the experiments), C is the number of classes, and P_i is the softened label probability:
    $$P_i=\begin{cases} 1-\epsilon & \text{if } i=y \\ \frac{\epsilon}{C-1} & \text{if } i\neq y \end{cases}$$

    For example, a one-hot label vector [0, 1, 0, 0] becomes [0.033, 0.9, 0.033, 0.033] after label smoothing.

  • Cosine annealing + warmup
    Gradients tend to be very unstable at the start of training, so it helps to train for a few iterations at a small learning rate first, then raise the learning rate to its initial value and train normally.
    Warmup linearly increases the learning rate to the initial value over the first n iterations (n = 10 here). It stabilizes early training to some extent and helps the network converge to a better minimum.
    Cosine annealing, in turn, is good at escaping local minima, giving a better chance of landing in a superior one.
    The figure referenced here showed the learning-rate curves under warmup and under cosine annealing; a minimal schedule sketch follows below.
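
Below is a minimal sketch of such a warmup-plus-cosine schedule built with LambdaLR. The warmup_steps and total_steps values are illustrative, not from the original experiments; the Trainer further down uses CosineAnnealingWarmRestarts instead.

import math
from torch.optim.lr_scheduler import LambdaLR

warmup_steps, total_steps = 10, 100       # illustrative values

def warmup_cosine(step):
    if step < warmup_steps:
        # linear warmup up to the initial learning rate
        return (step + 1) / warmup_steps
    # cosine decay from the initial learning rate down to 0
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

toy_model = nn.Linear(10, 2)              # stand-in model for the sketch
toy_opt = SGD(toy_model.parameters(), lr=config.lr)
toy_sched = LambdaLR(toy_opt, warmup_cosine)

for step in range(total_steps):
    toy_opt.step()                        # a real training step would go here
    toy_sched.step()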
# ----------------------------------- LabelSmoothEntropy ----------------------------------- #
class LabelSmoothEntropy(nn.Module):
    def __init__(self, smooth=0.1, class_weights=None, size_average='mean'):
        super(LabelSmoothEntropy, self).__init__()
        self.size_average = size_average
        self.smooth = smooth
    
        self.class_weights = class_weights
        
    
    def forward(self, preds, targets):
    
        # preds has shape [N, C]; the negative-class probability is ε / (C - 1)
        lb_pos, lb_neg = 1 - self.smooth, self.smooth / (preds.shape[1] - 1)
    
        smoothed_lb = t.zeros_like(preds).fill_(lb_neg).scatter_(1, targets[:, None], lb_pos)
    
        log_soft = F.log_softmax(preds, dim=1)
    
        if self.class_weights is not None:
            loss = -log_soft * smoothed_lb * self.class_weights[None, :]
    
        else:
            loss = -log_soft * smoothed_lb
    
        loss = loss.sum(1)
        if self.size_average == 'mean':
            return loss.mean()
    
        elif self.size_average == 'sum':
            return loss.sum()
        else:
            raise NotImplementedError
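
As a quick correctness check (a sketch): with smooth=0 the smoothed labels reduce to one-hot vectors, so the loss should agree with nn.CrossEntropyLoss.

criterion = LabelSmoothEntropy(smooth=0)
preds = t.randn(4, 11)                    # dummy logits for a batch of 4
targets = t.randint(0, 11, (4,))
print(criterion(preds, targets).item())
print(nn.CrossEntropyLoss()(preds, targets).item())  # should print the same value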
    
  
  
class Trainer:

    def __init__(self):
    
        self.device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
    
        self.train_set = DigitsDataset(data_dir['train_data'], data_dir['train_label'])
        # shuffle the training set each epoch
        self.train_loader = DataLoader(self.train_set, batch_size=config.batch_size, shuffle=True, num_workers=8, pin_memory=True, drop_last=True)
        
    
        self.val_loader = DataLoader(DigitsDataset(data_dir['val_data'], data_dir['val_label'], aug=False), batch_size=config.batch_size,\
                        num_workers=8, pin_memory=True, drop_last=True)
    
        self.model = DigitsMobilenet(config.class_num).to(self.device)
    
        # use the label-smoothing loss
        self.criterion = LabelSmoothEntropy().to(self.device)
        
        self.optimizer = SGD(self.model.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weights_decay, nesterov=True)
        
        # use a cosine annealing (warm restarts) learning-rate schedule
        self.lr_scheduler = CosineAnnealingWarmRestarts(self.optimizer, 10, 2, eta_min=1e-3)
        # self.lr_scheduler = MultiStepLR(self.optimizer, [10, 20, 30], 0.5)
        self.best_acc = 0
    
        if config.pretrained is not None:
            self.load_model(config.pretrained)
            # print('Load model from %s'%config.pretrained)
            acc = self.eval()
            self.best_acc = acc
            print('Load model from %s, Eval Acc: %.2f'%(config.pretrained, acc * 100))
    

    def train(self):
        for epoch in range(config.start_epoch, config.epoches):
            self.train_epoch(epoch)
            if (epoch + 1) % config.eval_interval == 0:
                print('Start Evaluation')
                acc = self.eval()
    
                if acc > self.best_acc:
                    os.makedirs(config.checkpoints, exist_ok=True)
                    save_path = config.checkpoints + 'epoch-%d_acc-%.2f.pth'%(epoch + 1, acc)
                    self.save_model(save_path)
                    print('%s saved successfully...'%save_path)
                    self.best_acc = acc
    
    def train_epoch(self, epoch):
        total_loss = 0
        corrects = 0
        tbar = tqdm(self.train_loader)
        self.model.train()
        for i, (img, label) in enumerate(tbar):
            img = img.to(self.device)
            label = label.to(self.device)
            self.optimizer.zero_grad()
            pred = self.model(img)
            loss = self.criterion(pred[0], label[:, 0]) + \
                self.criterion(pred[1], label[:, 1]) + \
                self.criterion(pred[2], label[:, 2]) + \
                self.criterion(pred[3], label[:, 3]) + \
                self.criterion(pred[4], label[:, 4])
            total_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            temp = t.stack([
                pred[0].argmax(1) == label[:, 0],
                pred[1].argmax(1) == label[:, 1],
                pred[2].argmax(1) == label[:, 2],
                pred[3].argmax(1) == label[:, 3],
                pred[4].argmax(1) == label[:, 4],
            ], dim=1)

            # a sample counts as correct only if all of its digits are predicted correctly
            corrects += t.all(temp, dim=1).sum().item()
            if (i + 1) % config.print_interval == 0:
                self.lr_scheduler.step()
                tbar.set_description('loss: %.3f, acc: %.3f'%(total_loss / (i + 1), corrects * 100 / ((i + 1) * config.batch_size)))
    
    def eval(self):
        self.model.eval()
        corrects = 0
        with t.no_grad():
            tbar = tqdm(self.val_loader)
            for i, (img, label) in enumerate(tbar):
                img = img.to(self.device)
                label = label.to(self.device)
                pred = self.model(img)
    
                temp = t.stack([
                    pred[0].argmax(1) == label[:, 0],
                    pred[1].argmax(1) == label[:, 1],
                    pred[2].argmax(1) == label[:, 2],
                    pred[3].argmax(1) == label[:, 3],
                    pred[4].argmax(1) == label[:, 4],
                ], dim=1)
    
                corrects += t.all(temp, dim=1).sum().item()
                tbar.set_description('Val Acc: %.2f'%(corrects * 100 / ((i + 1) * config.batch_size)))
        self.model.train()
        return corrects / (len(self.val_loader) * config.batch_size)
    
    def save_model(self, save_path, save_opt=False, save_config=False):
        # save the model weights (optionally also the optimizer state and config)
        dicts = {}
        dicts['model'] = self.model.state_dict()
        if save_opt:
            dicts['opt'] = self.optimizer.state_dict()
    
        if save_config:
            dicts['config'] = {s: config.__getattribute__(s) for s in dir(config) if not s.startswith('_')}
    
        t.save(dicts, save_path)
    
    def load_model(self, load_path, load_opt=False, load_config=False):
        # load the model weights (optionally also the optimizer state and config)
        dicts = t.load(load_path)
        self.model.load_state_dict(dicts['model'])
    
        if load_opt:
            self.optimizer.load_state_dict(dicts['opt'])
    
        if load_config:
            for k, v in dicts['config'].items():
                config.__setattr__(k, v)
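
Putting it together, training can be launched as below. This sketch assumes DigitsDataset and data_dir are defined as in the earlier data-loading post, and that config.pretrained points to an existing checkpoint (set it to None to train from scratch).

trainer = Trainer()
trainer.train()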

Summary

All in all, I find the classification framing quite novel; at first I had not even considered solving this task with classification. If a classification model can do the job, there is no need to reach for object detection. That said, for the competition itself, object detection would likely perform better.

This part is highly dependent on the previous posts and reuses code from them.
The code lives in my GitHub repository; Stars are welcome.

All the data is also shared via cloud drive at this address.

OK, that's it for now.
