【深度學習框架】使用PyTorch進行數據處理

  在深度學習中,數據的處理對於神經網絡的訓練來講十分重要,良好的數據(包括圖像、文本、語音等)處理不只能夠加速模型的訓練,同時也直接關係到模型的效果。本文以處理圖像數據爲例,記錄一些使用PyTorch進行圖像預處理和數據加載的方法網絡


1、數據的加載

  在PyTorch中,數據加載須要自定義數據集類,並用此類來實例化數據對象,實現自定義的數據集須要繼承torch.utils.data包中的Dataset類
  在繼承Dataset實現本身的類時,須要實現如下兩個Python魔法方法:dom

  • __getitem__(index): 返回一個樣本數據,當使用obj[index]時實際就是在調用obj.__getitem__(index)
  • __len__():返回樣本的數量,當使用len(obj)時實際就是在調用obj.__len__()

  例如,以貓狗大戰的二分類數據集爲例,其加載過程以下:
工具

import os
import torch as t
from torch.utils import data
from PIL import Image
import numpy as np

class dogCat(data.Dataset):
    def __init__(self,root): # root爲數據存放目錄
        imgs = os.listdir(root) #列出當前路徑下全部的文件
        self.imgs = [os.path.join(root,img) for img in imgs] # 全部圖片的路徑
        #print(self.imgs)

    """返回一個樣本數據"""
    def __getitem__(self, item): 
        img_path = self.imgs[item] # 第item張圖片的路徑
        #dog 1 cat 0
        label = 1 if 'dog' in img_path.split('\\')[-1] else 0 # 獲取標籤信息
        #print(label)
        pil_img = Image.open(img_path) #讀入圖片
        print(type(pil_img))
        array = np.asarray(pil_img) # 轉爲numpy.array類型
        data = t.from_numpy(array) # 轉爲tensor類型
        return data,label #返回圖片對應的tensor及其標籤

    """樣本的數量"""
    def __len__(self):
        return len(self.imgs)

if __name__ == '__main__':
    dogcat = dogCat('D:\pycode\dogsVScats\data\catvsdog\\train') #數據集對象
    data,label = dogcat[0] # 返回第0張圖片的信息
    print(data.size())
    print(label)
    print(len(dogcat))

2、計算機視覺工具包:torchvision

  對於圖像數據來講,以上的數據加載時不完善的,由於只是將圖片讀入,而沒有進行相關的處理,如每張圖片的大小和形狀,樣本的數值歸一化等等。
  爲了解決這一問題,PyTorch開發了一個視覺工具包torchvision,這個包獨立於torch,須要經過pip install torchvision來單獨安裝。
  torchvision有三個部分組成:學習

  • models提供各類經典的網絡結構和預訓練好的模型,如AlexNet、VGG、ResNet、Inception等
from torchvision import models
from torch import nn
resnet34 = models.resnet34(pretrained=True,num_classes=1000) # 加載預訓練模型
resnet34.fc=nn.Linear(512,10) # 修改全鏈接層爲10分類
  • datasets提供了經常使用的數據集,如MNIST、CIFAR10/100、ImageNet、COCO等
from torchvision import datasets
dataset = datasets.MNIST('data/',download=True,train=False,transform=transform)

  除了經常使用數據集外,須要特別注意的是ImageFolder,ImageFolder假設全部的文件按文件夾存放,每一個文件夾下面存儲同一類的圖片,文件夾的名字爲這一類別的名字。這是咱們常常用到的一種數據組織形式。code

# 使用方法:
ImageFolder(root,transform=None,target_transform=None,loader=default_loader)
# 參數:文件夾路徑,對圖像作什麼樣的轉換,對標籤作什麼樣的轉換,如何加載圖片

from torchvision.datasets import ImageFolder
dataset = ImageFolder('data\\')
print(dataset.class_to_idx) # class_to_idx ,label和id的對應關係,從0開始
print(dataset.imgs) # 數據和標籤對應
  • transforms: 提供經常使用的數據預處理操做,主要是對Tensor和PIL Image對象的處理操做

  對PIL Image的操做:Resize、CenterCrop、RandomCrop、RandomsizedCrop、Pad、ToTensor等。orm

  對Tensor的操做:Normalize、ToPILImage等。對象

  若是要進行多個操做,能夠經過transforms.Compose([])將操做拼接起來。可是須要注意的是須要首先構建轉換操做,而後再執行轉換操做。繼承

import os
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms as T

transform = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])])  # 構建轉換操做

class dogCat(data.Dataset):
    def __init__(self,root,transforms):
        imgs = os.listdir(root)
        #print(imgs)
        self.imgs = [os.path.join(root,img) for img in imgs]
        #print(self.imgs)
        self.transforms = transforms

    def __getitem__(self, item):
        img_path = self.imgs[item]
        #dog 1 cat 0
        label = 1 if 'dog' in img_path.split('\\')[-1] else 0
        #print(label)
        pil_img = Image.open(img_path)
        if self.transforms:
            pil_img = self.transforms(pil_img)  #執行準換操做
        return pil_img,label,item

    def __len__(self):
        return len(self.imgs)

3、使用DataLoader進行數據再處理

  經過上述描述,咱們經過自定義數據集類,使用視覺工具包進行圖像的轉換等操做,最終獲得的是一個dataset的數據集對象,使用此對象能夠一次返回一個樣本。
  可是,咱們應該清楚:訓練神經網絡時,通常採用的是小批量的梯度降低,所以咱們是對一批數據進行處理,也就是一個batch,同時,數據還須要進行打亂(shuffle)和並行加速等。PyTorch提供了DataLoader來實現這些功能。
  DataLoader定義以下:進程

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,num_workers=0,collate_fn=default_collate,pin_memory=False,drop_last=False)

  參數含義以下:圖片

  • dataset:加載的數據集
  • batch_zize: 批大小
  • shuffle: 是否將數據打亂
  • sampler:樣本抽樣,經常使用的有隨機採樣RandomSampler,shuffle=True時自動調用隨機採樣,默認是順序採樣,還有一個經常使用的是:WeightedRandomSampler,按照樣本的權重進行採樣。
  • num_workers: 使用的進程數,0表明不使用多進程。
  • collate_fn: 拼接方式。
  • pin_memory: 是否將數據保存在pin memory區。
  • drop_last: 是否將多出來的不足一個batch的丟棄。

  調用DataLoader獲得的結果是一個可迭代的對象,能夠和使用迭代器同樣使用它。

from torchvision import transforms as T
from torch.utils.data import DataLoader

transform = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])])

if __name__ == '__main__':
    dogcat = dogCat('D:\pycode\dogsVScats\data\catvsdog\\train', transform)
    data, label, index = dogcat[0]
    
    dataloader = DataLoader(dogcat,batch_size=3,shuffle=False,num_workers=0,drop_last=False)
    for batchDatas,batchLabels in dataloader: 
        train()

總結

  本文記錄了使用PyTorch進行數據預處理的相關操做流程,重點是掌握Dataset和DataLoader兩個類的使用,另外,視覺工具包torchvision的三個模塊靈活運用,會對數據處理過程有很好的幫助。

相關文章
相關標籤/搜索