PyTorch | 項目結構解析

在學習和使用深度學習框架時,復現現有項目代碼是必經之路,也能加深對理論知識的理解,提升動手能力。本文參照相關博客整理項目經常使用組織方式,以及每部分功能,幫助更好的理解復現項目流程,文末提供分類示例項目。python

1 項目組織

在作深度學習實驗或項目時,爲了獲得最優的模型結果,中間每每須要不少次的嘗試和修改。通常項目都包含如下幾個部分:git

  • 模型定義
  • 數據處理和加載
  • 訓練模型(Train&Validate)
  • 訓練過程的可視化
  • 測試(Test/Inference)

另外程序在組織過程當中還應該知足如下幾個要求:github

  • 模型需具備高度可配置性,便於修改參數、修改模型,反覆實驗
  • 代碼應具備良好的組織結構,令人一目瞭然
  • 代碼應具備良好的說明,使其餘人可以理解

2 項目結構

- checkpoints/: 用於保存訓練好的模型,可以使程序在異常退出後仍能從新載入模型,恢復訓練
- data/:數據相關操做,包括數據預處理、dataset實現等
- models/:模型定義,能夠有多個模型,例如上面的AlexNet和ResNet34,一個模型對應一個文件
- utils/:可能用到的工具函數,在本次實驗中主要是封裝了可視化工具
- config.py:配置文件,全部可配置的變量都集中在此,並提供默認值
- main.py:主文件,訓練和測試程序的入口,可經過不一樣的命令來指定不一樣的操做和參數
- requirements.txt:程序依賴的第三方庫
- README.md:提供程序的必要說明網絡

3 解析

3.1 __init__
- __init__ 能夠爲空,也能夠定義包的屬性和方法,但必須存在,其餘程序才能從這個目錄中讀取模塊和函數框架

3.2 數據加載
使用Dataset提供數據集的封裝,再使用Dataloader實現數據並行加載。dom

- def __init__(self..)
獲取圖片地址,並根據訓練、驗證和測試劃分數據
- def __getitem__(self, index):
返回圖片的數據和label
- def __len__(self):
返回數據集數量ide

train_dataset = DogCat(opt.train_data_root, train=True)
trainloader = DataLoader(train_dataset,
batch_size = opt.batch_size,
shuffle = True,
num_workers = opt.num_workers)

for ii, (data, label) in enumerate(trainloader):
train()

3.3 模型定義
型的定義主要保存在models/目錄下,其中BasicModule是對nn.Module的簡易封裝,提供快速加載和保存模型的接口。
nn.Module主要包括save和load兩個方法函數

from models import AlexNet工具

關於模型定義:
- 儘可能使用nn.Sequential(好比AlexNet)
- 將常常使用的結構封裝成子Module(好比GoogLeNet的Inception結構,ResNet的Residual Block結構)
- 將重複且有規律性的結構,用函數生成(好比VGG的多種變體,ResNet多種變體都是由多個重複卷積層組成)學習

3.4 工具函數
可能會用到一些helper方法,這些方法能夠統一放在utils/文件夾下,須要使用時再引入。在本例中主要是封裝了可視化工具visdom的一些操做,

3.5 配置文件
可配置的參數主要包括:

數據集參數(文件路徑、batch_size等)
訓練參數(學習率、訓練epoch等)
模型參數

在實際使用時,並不須要每次都修改config.py,只須要經過命令行傳入所需參數,覆蓋默認配置便可。

3.6 main函數
提到了fire
main中包括train、val、test、help等

訓練的主要步驟以下:

  • 定義網絡
  • 定義數據
  • 定義損失函數和優化器
  • 計算重要指標
  • 開始訓練
  • 訓練網絡
  • 可視化各類指標
  • 計算在驗證集上的指標

4 示例分類代碼

#coding:utf8
from config import opt
import os
import torch as t
import models
from data.dataset import DogCat
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchnet import meter
from utils.visualize import Visualizer
from tqdm import tqdm

def test(**kwargs):
    opt.parse(kwargs)
    import ipdb;
    ipdb.set_trace()
    # configure model
    model = getattr(models, opt.model)().eval()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu: model.cuda()

    # data
    train_data = DogCat(opt.test_data_root,test=True)
    test_dataloader = DataLoader(train_data,batch_size=opt.batch_size,shuffle=False,num_workers=opt.num_workers)
    results = []
    for ii,(data,path) in enumerate(test_dataloader):
        input = t.autograd.Variable(data,volatile = True)
        if opt.use_gpu: input = input.cuda()
        score = model(input)
        probability = t.nn.functional.softmax(score)[:,0].data.tolist()
        # label = score.max(dim = 1)[1].data.tolist()
        
        batch_results = [(path_,probability_) for path_,probability_ in zip(path,probability) ]

        results += batch_results
    write_csv(results,opt.result_file)

    return results

def write_csv(results,file_name):
    import csv
    with open(file_name,'w') as f:
        writer = csv.writer(f)
        writer.writerow(['id','label'])
        writer.writerows(results)
    
def train(**kwargs):
    opt.parse(kwargs)
    vis = Visualizer(opt.env)

    # step1: configure model
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu: model.cuda()

    # step2: data
    train_data = DogCat(opt.train_data_root,train=True)
    val_data = DogCat(opt.train_data_root,train=False)
    train_dataloader = DataLoader(train_data,opt.batch_size,
                        shuffle=True,num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,opt.batch_size,
                        shuffle=False,num_workers=opt.num_workers)
    
    # step3: criterion and optimizer
    criterion = t.nn.CrossEntropyLoss()
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(),lr = lr,weight_decay = opt.weight_decay)
        
    # step4: meters
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(2)
    previous_loss = 1e100

    # train
    for epoch in range(opt.max_epoch):
        
        loss_meter.reset()
        confusion_matrix.reset()

        for ii,(data,label) in tqdm(enumerate(train_dataloader),total=len(train_data)):

            # train model 
            input = Variable(data)
            target = Variable(label)
            if opt.use_gpu:
                input = input.cuda()
                target = target.cuda()

            optimizer.zero_grad()
            score = model(input)
            loss = criterion(score,target)
            loss.backward()
            optimizer.step()
            
            
            # meters update and visualize
            loss_meter.add(loss.data[0])
            confusion_matrix.add(score.data, target.data)

            if ii%opt.print_freq==opt.print_freq-1:
                vis.plot('loss', loss_meter.value()[0])
                
                # 進入debug模式
                if os.path.exists(opt.debug_file):
                    import ipdb;
                    ipdb.set_trace()


        model.save()

        # validate and visualize
        val_cm,val_accuracy = val(model,val_dataloader)

        vis.plot('val_accuracy',val_accuracy)
        vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format(
                    epoch = epoch,loss = loss_meter.value()[0],val_cm = str(val_cm.value()),train_cm=str(confusion_matrix.value()),lr=lr))
        
        # update learning rate
        if loss_meter.value()[0] > previous_loss:          
            lr = lr * opt.lr_decay
            # 第二種下降學習率的方法:不會有moment等信息的丟失
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        

        previous_loss = loss_meter.value()[0]

def val(model,dataloader):
    '''
    計算模型在驗證集上的準確率等信息
    '''
    model.eval()
    confusion_matrix = meter.ConfusionMeter(2)
    for ii, data in enumerate(dataloader):
        input, label = data
        val_input = Variable(input, volatile=True)
        val_label = Variable(label.type(t.LongTensor), volatile=True)
        if opt.use_gpu:
            val_input = val_input.cuda()
            val_label = val_label.cuda()
        score = model(val_input)
        confusion_matrix.add(score.data.squeeze(), label.type(t.LongTensor))

    model.train()
    cm_value = confusion_matrix.value()
    accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) / (cm_value.sum())
    return confusion_matrix, accuracy

def help():
    '''
    打印幫助的信息: python file.py help
    '''
    
    print('''
    usage : python file.py <function> [--args=value]
    <function> := train | test | help
    example: 
            python {0} train --env='env0701' --lr=0.01
            python {0} test --dataset='path/to/dataset/root/'
            python {0} help
    avaiable args:'''.format(__file__))

    from inspect import getsource
    source = (getsource(opt.__class__))
    print(source)

if __name__=='__main__':
    import fire
    fire.Fire()
View Code

 

參考:https://github.com/chenyuntc/pytorch-best-practice/blob/master/PyTorch%E5%AE%9E%E6%88%98%E6%8C%87%E5%8D%97.md 

相關文章
相關標籤/搜索