雖然torchvision.datasets中已經封裝了好多通用的數據集,可是咱們在使用Pytorch作深度學習任務的時候,會面臨着自定義數據庫來知足本身的任務須要。如咱們要訓練一我的臉關鍵點檢測算法,提供的訓練數據標註以下形式,存在CSV文件中:算法
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y 0805personali01.jpg,27,83,27,98, ... 84,134 1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
在本次教程中,咱們須要用到兩個額外的包:數據庫
首先學習如何使用pandas庫解析csv文件學習
import pandas as pd
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv') n = 65 img_name = landmarks_frame.iloc[n, 0] landmarks = landmarks_frame.iloc[n, 1:].as_matrix() landmarks = landmarks.astype('float').reshape(-1, 2) print('Image name: {}'.format(img_name)) print('Landmarks shape: {}'.format(landmarks.shape)) print('First 4 Landmarks: {}'.format(landmarks[:4]))
torch.utils.data.Dataset
是一個表示數據庫的抽象類,自定義數據庫須要繼承這個類,而且重寫其如下方法:spa
__len__ :返回數據庫的大小. __getitem__ :支持使用下標的方式 如dataset[i] 來獲取第i個樣本
如下建立人臉特徵點檢測的數據庫。咱們將在__init__中解析csv文件,而在__getitem__中讀取圖片。這樣能夠在須要圖片是才加載,內存效率高。此外,咱們還能夠先將數據集封裝成lmdb數據庫,讀取速度更快。code
import torch.utils.data.Dataset as Dataset class FaceLandmarksDataset(Dataset): """Face Landmarks dataset.""" def __init__(self, csv_file, root_dir, transform=None): """ Args: csv_file (string): 到達標註文件cvs的路徑. root_dir (string): 全部圖片的根目錄. transform (callable, optional): (可選參數)對每個樣本進行轉換. """ self.landmarks_frame = pd.read_csv(csv_file) self.root_dir = root_dir self.transform = transform def __len__(self): return len(self.landmarks_frame) def __getitem__(self, idx): img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0]) #第idx條數據的第一個字段,即文件名稱 image = io.imread(img_name) #讀取圖像數據 landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix() #讀取第idx條數據的第二個字段及其以後的全部字段,即全部關鍵點的座標。而後轉成矩陣形式 landmarks = landmarks.astype('float').reshape(-1, 2) #將矩陣reshape成n行兩列矩陣 sample = {'image': image, 'landmarks': landmarks} #封裝數據 if self.transform: sample = self.transform(sample) #數據轉換 return sample #返回數據
注:__getitem__每次只返回一個條數據,至於batch的封裝能夠在DataLoader中設置batchsize,至於讀取速度能夠設置num_worker。orm