訓練一個模型須要有一個數據庫,一個網絡,一個優化函數。數據讀取是訓練的第一步,如下是pytorch數據輸入框架。html
假設咱們已經定義了一個FaceLandmarksDataset數據庫,此數據庫將在如下創建。數據庫
import FaceLandmarksDataset face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/', transform=transforms.Compose([ Rescale(256), RandomCrop(224), ToTensor()]) )
或者使用torchvision.datasets裏封裝的數據集(MNIST、Fashion-MNIST、KMNIST、EMNIST、COCO、LSUN、ImageFolder、DatasetFolder、Imagenet-十二、CIFAR、STL十、SVHN、PhotoTour、SBU、Flickr、VOC、Cityscapes)網絡
import torchvision.datasets imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
import torch.utils.data.DataLoader imagenet_loader = torch.utils.data.DataLoader(imagenet_data, batch_size=4, shuffle=True, num_workers=4) #or facelandmark_loader = torch.utils.data.DataLoader(face_dataset, batch_size=4, shuffle=True, num_workers=4)
可見,數據加載器是通用的,只有數據庫實例不同,其它的都參數都同樣,參數值能夠根據任務須要本身調。框架
數據加載器可迭代的,咱們能夠使用數據庫:dom
for item in facelandmark_loader: images,labels = item
do_somethi
固然, 咱們也能夠直接對數據庫實例face_dataset進行下標操做,但這樣只可以每次獲取一條數據。函數
sample = face_dataset[index]