圖片數據通常有兩種狀況:app
一、全部圖片放在一個文件夾內,另外有一個txt文件顯示標籤。spa
二、不一樣類別的圖片放在不一樣的文件夾內,文件夾就是圖片的類別。code
針對這兩種不一樣的狀況,數據集的準備也不相同,第一種狀況能夠自定義一個Dataset,第二種狀況直接調用torchvision.datasets.ImageFolder來處理。下面分別進行說明:orm
1、全部圖片放在一個文件夾內blog
這裏以mnist數據集的10000個test爲例, 我先把test集的10000個圖片保存出來,並生着對應的txt標籤文件。繼承
先在當前目錄建立一個空文件夾mnist_test, 用於保存10000張圖片,接着運行代碼:圖片
import torch import torchvision import matplotlib.pyplot as plt from skimage import io mnist_test= torchvision.datasets.MNIST( './mnist', train=False, download=True ) print('test set:', len(mnist_test)) f=open('mnist_test.txt','w') for i,(img,label) in enumerate(mnist_test): img_path="./mnist_test/"+str(i)+".jpg" io.imsave(img_path,img) f.write(img_path+' '+str(label)+'\n') f.close()
通過上面的操做,10000張圖片就保存在mnist_test文件夾裏了,並在當前目錄下生成了一個mnist_test.txt的文件,大體以下:ip
前期工做就裝備好了,接着就進入正題了:get
from torchvision import transforms, utils from torch.utils.data import Dataset, DataLoader import matplotlib.pyplot as plt from PIL import Image def default_loader(path): return Image.open(path).convert('RGB') class MyDataset(Dataset): def __init__(self, txt, transform=None, target_transform=None, loader=default_loader): fh = open(txt, 'r') imgs = [] for line in fh: line = line.strip('\n') line = line.rstrip() words = line.split() imgs.append((words[0],int(words[1]))) self.imgs = imgs self.transform = transform self.target_transform = target_transform self.loader = loader def __getitem__(self, index): fn, label = self.imgs[index] img = self.loader(fn) if self.transform is not None: img = self.transform(img) return img,label def __len__(self): return len(self.imgs) train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor()) data_loader = DataLoader(train_data, batch_size=100,shuffle=True) print(len(data_loader)) def show_batch(imgs): grid = utils.make_grid(imgs) plt.imshow(grid.numpy().transpose((1, 2, 0))) plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader): if(i<4): print(i, batch_x.size(),batch_y.size()) show_batch(batch_x) plt.axis('off') plt.show()
自定義了一個MyDataset, 繼承自torch.utils.data.Dataset。而後利用torch.utils.data.DataLoader將整個數據集分紅多個批次。it
2、不一樣類別的圖片放在不一樣的文件夾內
一樣先準備數據,這裏以flowers數據集爲例,下載:
http://download.tensorflow.org/example_images/flower_photos.tgz
花總共有五類,分別放在5個文件夾下。大體以下圖:
個人路徑是d:/flowers/.
數據準備好了,就開始準備Dataset吧,這裏直接調用torchvision裏面的ImageFolder
import torch import torchvision from torchvision import transforms, utils import matplotlib.pyplot as plt img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower', transform=transforms.Compose([ transforms.Scale(256), transforms.CenterCrop(224), transforms.ToTensor()]) ) print(len(img_data)) data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True) print(len(data_loader)) def show_batch(imgs): grid = utils.make_grid(imgs,nrow=5) plt.imshow(grid.numpy().transpose((1, 2, 0))) plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader): if(i<4): print(i, batch_x.size(), batch_y.size()) show_batch(batch_x) plt.axis('off') plt.show()
就是這樣。