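Introduction
The most commonly used image classification dataset is the handwritten digit dataset MNIST (1), but most models already exceed 95% classification accuracy on it. To make the differences between algorithms easier to see, this post uses a dataset with more complex image content, Fashion-MNIST (2).
The following sections use the torchvision package, which is mainly used to build computer vision models and consists of the following four parts: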
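Component | Function |
---|---|
torchvision.datasets | Functions for loading data and interfaces to commonly used datasets |
torchvision.models | Commonly used model architectures (including pretrained models) |
torchvision.transforms | Commonly used image transformations, such as cropping and rotation |
torchvision…utils | Other utilities |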
The code has been uploaded to GitHub:
https://github.com/InkiInki/Python/blob/master/Python1/deepLearning/ImageMnist.py
1 Getting the dataset
The required imports are as follows:

    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import time
    import sys
    from IPython import display

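Next, the dataset is downloaded through torchvision.datasets. The first call fetches the data from the internet automatically (if this is slow, see the notes further down). The train parameter selects the training set or the test set, and transform=transforms.ToTensor() converts the data to a Tensor. Without this transform, PIL images are returned.
transforms.ToTensor() converts a PIL image of shape (H × W × C) with values in [0, 255], or a NumPy array of dtype np.uint8, into a Tensor of shape (C × H × W) with dtype torch.float32 and values in [0.0, 1.0].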
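To make the layout change and the rescaling concrete, here is a minimal pure-Python sketch of the same conversion on nested lists. It is not torchvision's implementation: `to_chw_float` is a hypothetical helper introduced only for illustration, and it assumes the input holds integer pixel values in [0, 255].

```python
def to_chw_float(image_hwc):
    """Convert an H x W x C nested list of 0..255 ints to C x H x W floats in [0.0, 1.0]."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [
        [
            # Move the channel axis to the front and divide by 255 to rescale
            [image_hwc[row][col][ch] / 255.0 for col in range(width)]
            for row in range(height)
        ]
        for ch in range(channels)
    ]


# A tiny image with H=1, W=2, C=3 (hypothetical pixel values)
image = [[[0, 255, 51], [255, 0, 102]]]
converted = to_chw_float(image)

# Shape after conversion, reported as (C, H, W)
print(len(converted), len(converted[0]), len(converted[0][0]))
print(converted[0])  # the first channel of both pixels
```

The real ToTensor() performs the equivalent step on PIL images and NumPy arrays and returns a torch.float32 Tensor rather than nested lists.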
The code is as follows:

    class ImageMnist:
        def __init__(self):
            # Download (on first use) and load the training split as Tensors
            self.mnist_train = torchvision.datasets.FashionMNIST(
                root='~/DataSets/FashionMNIST', train=True, download=True,
                transform=transforms.ToTensor())
            # Same for the test split
            self.mnist_test = torchvision.datasets.FashionMNIST(
                root='~/DataSets/FashionMNIST', train=False, download=True,
                transform=transforms.ToTensor())

    if __name__ == "__main__":
        test = ImageMnist()  # the constructor already runs __init__
        print(test.mnist_train)
        print(len(test.mnist_train), len(test.mnist_test))

Output:

    Dataset FashionMNIST
        Number of datapoints: 60000
        Root location: C:\Users\Administrator/DataSets/FashionMNIST
        Split: Train
        StandardTransform
    Transform: ToTensor()
    60000 10000

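Notes:
1) If image data is represented by pixel values, always set its type to uint8 to avoid unnecessary bugs. torchvision's ToTensor rescales values to [0.0, 1.0] only for uint8 input.
2) The first download may be very slow. In that case, run the following code in a Python session started from cmd, then copy the HTTP links it prints and download the files manually:

    import torchvision
    import torchvision.transforms as transforms
    torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
    torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())
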
2 Basic operations
Any sample can be accessed by index:

    if __name__ == "__main__":
        test = ImageMnist()
        # Each sample is an (image, label) pair
        data, label = test.mnist_train[0]
        print(data.shape)
        print(label)

Output:

    torch.Size([1, 28, 28])  # number of channels, image height, image width
    9

Fashion-MNIST has 10 classes: t-shirt, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag, and ankle boot. The following method converts numeric labels into the corresponding text labels:

    class ImageMnist:
        ...  # __init__ as defined above

        def get_text_labels(self, labels):
            # The position in this list is the numeric class label
            text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                           'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
            return [text_labels[int(i)] for i in labels]

    if __name__ == "__main__":
        test = ImageMnist()
        data, label = test.mnist_train[0]
        print(test.get_text_labels([label]))

Output:

    ['ankle boot']

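Now define a method that draws several images and their labels in a single row:

    class ImageMnist:
        ...  # __init__ and get_text_labels as defined above

        def show_mnist(self, images, labels):
            # Note: deprecated in newer IPython versions in favor of
            # matplotlib_inline.backend_inline.set_matplotlib_formats
            display.set_matplotlib_formats('svg')
            _, figs = plt.subplots(1, len(images), figsize=(12, 12))
            # zip() takes several iterables, packs their corresponding elements
            # into tuples, and returns an iterator over those tuples
            for f, img, lbl in zip(figs, images, labels):
                # Drop the channel dimension: (1, 28, 28) -> (28, 28)
                f.imshow(img.view((28, 28)).numpy())
                f.set_title(lbl)
                f.axis('off')
            plt.show()

    if __name__ == "__main__":
        test = ImageMnist()
        x, y = [], []
        # Collect the first 10 training images and their numeric labels
        for i in range(10):
            x.append(test.mnist_train[i][0])
            y.append(test.mnist_train[i][1])
        test.show_mnist(x, test.get_text_labels(y))
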
Output: a single row showing the first 10 training images, each titled with its text label.
3 Reading minibatches
A convenient feature of PyTorch's DataLoader is that it can use multiple processes to speed up data loading. Here the num_workers parameter is set to 4 worker processes, except on Windows, where it is set to 0.

    class ImageMnist:
        ...  # methods as defined above

        def data_iter(self, batch_size=256):
            if sys.platform.startswith('win'):
                num_workers = 0  # 0 means no extra processes are used to load data
            else:
                num_workers = 4
            # Shuffle only the training data
            train_iter = torch.utils.data.DataLoader(
                self.mnist_train, batch_size=batch_size, shuffle=True,
                num_workers=num_workers)
            test_iter = torch.utils.data.DataLoader(
                self.mnist_test, batch_size=batch_size, shuffle=False,
                num_workers=num_workers)
            return train_iter, test_iter

    if __name__ == "__main__":
        start = time.time()
        test = ImageMnist()
        train_iter, test_iter = test.data_iter()
        # Time one full pass over the training set (includes dataset loading)
        for x, y in train_iter:
            continue
        print("%.2f sec" % (time.time() - start))

Output:

    6.65 sec

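To see what one pass over the training set involves, the following pure-Python sketch models how a batch size of 256 partitions 60000 samples. It is a simplified model and not DataLoader's implementation: `batch_indices` is a hypothetical helper, and it assumes the final, shorter minibatch is kept rather than dropped.

```python
import math
import random


def batch_indices(num_samples, batch_size, shuffle=False, seed=None):
    """Yield lists of sample indices, one list per minibatch."""
    indices = list(range(num_samples))
    if shuffle:
        # Shuffle once per pass so each minibatch mixes samples
        random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        # The final slice may be shorter than batch_size
        yield indices[start:start + batch_size]


batches = list(batch_indices(60000, 256, shuffle=True, seed=0))
print(len(batches), math.ceil(60000 / 256))  # minibatches per pass, two ways
print(len(batches[0]), len(batches[-1]))     # sizes of the first and last minibatch
```

Under these assumptions one pass consists of 235 minibatches: 234 full ones of 256 samples and a final one of 96, since 234 × 256 = 59904 and 60000 − 59904 = 96.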
4 Complete code

    '''
    @(#)test.py
    The class of test.
    Author: Yu-Xuan Zhang
    Email: inki.yinji@qq.com
    Created on May 05, 2020
    Last Modified on May 05, 2020

    @author: inki
    '''
    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import time
    import sys
    from IPython import display

    class ImageMnist:
        def __init__(self):
            # Download (on first use) and load both splits as Tensors
            self.mnist_train = torchvision.datasets.FashionMNIST(
                root='~/DataSets/FashionMNIST', train=True, download=True,
                transform=transforms.ToTensor())
            self.mnist_test = torchvision.datasets.FashionMNIST(
                root='~/DataSets/FashionMNIST', train=False, download=True,
                transform=transforms.ToTensor())

        def get_text_labels(self, labels):
            # The position in this list is the numeric class label
            text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                           'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
            return [text_labels[int(i)] for i in labels]

        def show_mnist(self, images, labels):
            # Note: deprecated in newer IPython versions in favor of
            # matplotlib_inline.backend_inline.set_matplotlib_formats
            display.set_matplotlib_formats('svg')
            _, figs = plt.subplots(1, len(images), figsize=(12, 12))
            for f, img, lbl in zip(figs, images, labels):
                # Drop the channel dimension: (1, 28, 28) -> (28, 28)
                f.imshow(img.view((28, 28)).numpy())
                f.set_title(lbl)
                f.axis('off')
            plt.show()

        def data_iter(self, batch_size=256):
            if sys.platform.startswith('win'):
                num_workers = 0  # no extra loader processes on Windows
            else:
                num_workers = 4
            train_iter = torch.utils.data.DataLoader(
                self.mnist_train, batch_size=batch_size, shuffle=True,
                num_workers=num_workers)
            test_iter = torch.utils.data.DataLoader(
                self.mnist_test, batch_size=batch_size, shuffle=False,
                num_workers=num_workers)
            return train_iter, test_iter

    if __name__ == "__main__":
        start = time.time()
        test = ImageMnist()  # the constructor already runs __init__
        train_iter, test_iter = test.data_iter()
        for x, y in train_iter:
            continue
        print("%.2f sec" % (time.time() - start))

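Acknowledgments
Special thanks to Mu Li (李沐), Aston Zhang, and the other authors for the book Dive into Deep Learning (《動手學深度學習》).
This article was shared from CSDN - 因吉.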