Image Classification Dataset (Fashion-MNIST)

Introduction

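  The most commonly used image classification dataset is the handwritten digit dataset MNIST (1), but most models already exceed 95% classification accuracy on it. To make the differences between algorithms easier to observe, we will use Fashion-MNIST (2), a dataset whose images are more complex.
  The following sections use the torchvision package. It is built for computer vision models and consists mainly of the following 4 parts: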

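Component               Function
torchvision.datasets    Functions for loading data and interfaces to commonly used datasets
torchvision.models      Common model architectures (including pretrained models)
torchvision.transforms  Common image transformations, e.g. cropping and rotation
torchvision…utils       Other utilities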
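  The transforms row can be made more concrete with a minimal NumPy sketch of what cropping and rotation do to an image array. This is only an illustration. torchvision.transforms implements these operations as classes such as CenterCrop and RandomRotation, and the helper names center_crop and rotate90 below are hypothetical.

```python
import numpy as np


def center_crop(img: np.ndarray, size: int) -> np.ndarray:
    """Return the central size x size window of an H x W (x C) array."""
    h, w = img.shape[:2]
    if size > h or size > w:
        raise ValueError("crop size larger than image")
    top = (h - size) // 2    # integer offsets; odd margins round down
    left = (w - size) // 2
    return img[top:top + size, left:left + size]


def rotate90(img: np.ndarray, k: int = 1) -> np.ndarray:
    """Rotate counter-clockwise by k * 90 degrees in the height-width plane."""
    return np.rot90(img, k=k, axes=(0, 1))


img = np.arange(28 * 28).reshape(28, 28)   # fake 28 x 28 grayscale image
crop = center_crop(img, 24)
print(crop.shape)   # (24, 24)
```

Real augmentation pipelines usually randomize the crop position and rotation angle. The fixed versions above just keep the array manipulation easy to follow.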

  The code has been uploaded to GitHub:
  https://github.com/InkiInki/Python/blob/master/Python1/deepLearning/ImageMnist.py

1 Obtaining the Dataset

  The required imports are:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
from IPython import display

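  Next, we download the dataset through torchvision.datasets. The first call fetches the data from the internet automatically (if the download is slow, see the notes below). The parameter train selects either the training set or the test set. Passing transform=transforms.ToTensor() converts each sample into a Tensor; without it, PIL images are returned.
  transforms.ToTensor() takes a PIL image of size $H \times W \times C$ with values in [0, 255], or a NumPy array of type np.uint8. It converts this into a Tensor of size $C \times H \times W$, of type torch.float32, with values in [0.0, 1.0].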
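  The layout and scaling rule above can be sketched in NumPy. The function to_tensor_like is a hypothetical stand-in that returns a NumPy array rather than a torch Tensor. It assumes, based on our understanding of torchvision's implementation, that only uint8 input is divided by 255 and other dtypes pass through unscaled. This is one likely reason the notes below recommend storing pixel data as uint8.

```python
import numpy as np


def to_tensor_like(img: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy stand-in for the conversion transforms.ToTensor() performs.

    Assumption: mirrors torchvision's behaviour for NumPy input as we understand it;
    the result is a NumPy array, not a torch.Tensor.
    """
    if img.ndim == 2:                      # H x W grayscale -> H x W x 1
        img = img[:, :, None]
    chw = np.transpose(img, (2, 0, 1))     # H x W x C -> C x H x W
    if img.dtype == np.uint8:              # only uint8 pixels are rescaled to [0.0, 1.0]
        return chw.astype(np.float32) / 255.0
    return chw                             # other dtypes: layout changes, values do not


pixels = np.full((28, 28), 255, dtype=np.uint8)
out = to_tensor_like(pixels)
print(out.shape, out.dtype, out.max())   # (1, 28, 28) float32 1.0

as_float = pixels.astype(np.float32)     # same pixels, but no longer uint8
print(to_tensor_like(as_float).max())    # 255.0, i.e. not rescaled
```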

  The code is as follows:

class ImageMnist():
    
    def __init__(self):
        # Download (on first use) and load both splits, converting every image to a Tensor
        self.mnist_train = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',
            train=True, download=True, transform=transforms.ToTensor())
        self.mnist_test = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',
            train=False, download=True, transform=transforms.ToTensor())

if __name__ == "__main__":
    test = ImageMnist()    # the constructor already runs __init__; calling it again would reload the data
    print(test.mnist_train)
    print(len(test.mnist_train), len(test.mnist_test))

  Output:

Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: C:\Users\Administrator/DataSets/FashionMNIST
    Split: Train
    StandardTransform
Transform: ToTensor()
60000 10000

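  Notes:
  1) If image data is represented by pixel values, always store it as uint8 to avoid unnecessary bugs.
  2) The first download may be very slow. It is recommended to run the following code in a Python session started from cmd, then copy the http links it prints and download the files manually: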

import torchvision
import torchvision.transforms as transforms
torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

2 Basic Operations

  Any sample can be accessed by its index:

if __name__ == "__main__":
    test = ImageMnist()
    data, label = test.mnist_train[0]    # each sample is an (image, label) pair
    print(data.shape)
    print(label)

  Output:

torch.Size([1, 28, 28])    # number of channels, image height, image width
9
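  Indexing works because FashionMNIST is a map-style dataset. It implements __len__ and __getitem__, and __getitem__ returns an (image, label) pair. The sketch below shows the same protocol on a hypothetical in-memory ToyImageDataset built from plain nested lists, so it runs without torchvision. In practice DataLoader can consume any object that provides these two methods.

```python
from typing import List, Tuple

Image = List[List[float]]  # a 28 x 28 grid of pixel values


class ToyImageDataset:
    """Hypothetical in-memory dataset following the map-style protocol."""

    def __init__(self, images: List[Image], labels: List[int]):
        if len(images) != len(labels):
            raise ValueError("images and labels must have the same length")
        self.images = images
        self.labels = labels

    def __len__(self) -> int:
        # len(dataset) -> number of samples
        return len(self.labels)

    def __getitem__(self, idx: int) -> Tuple[Image, int]:
        # dataset[idx] -> (image, label), the same kind of result as mnist_train[0]
        return self.images[idx], self.labels[idx]


blank = [[0.0] * 28 for _ in range(28)]
ds = ToyImageDataset([blank, blank, blank], [9, 0, 3])
data, label = ds[0]
print(len(ds), len(data), len(data[0]), label)   # 3 28 28 9
```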

  Fashion-MNIST has 10 classes: t-shirt, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag and ankle boot. The following function converts numeric labels into the corresponding text labels:

...
    def get_text_labels(self, labels):
        text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
        return [text_labels[int(i)] for i in labels]
        
if __name__ == "__main__":
    test = ImageMnist()
    data, label = test.mnist_train[0]
    print(test.get_text_labels([label]))

  Output:

['ankle boot']
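  Because get_text_labels calls int() on each element, it accepts plain Python ints as well as other integer-like values such as NumPy integers or one-element tensors. Below is a standalone sketch of the same lookup. The inverse helper get_numeric_labels is hypothetical and not part of the original class.

```python
from typing import Iterable, List

TEXT_LABELS = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']


def get_text_labels(labels: Iterable) -> List[str]:
    # int() turns ints, NumPy integers or one-element tensors into a list index
    return [TEXT_LABELS[int(i)] for i in labels]


def get_numeric_labels(names: Iterable[str]) -> List[int]:
    # hypothetical inverse helper: text label -> class index 0..9
    return [TEXT_LABELS.index(n) for n in names]


print(get_text_labels([9, 0, 3]))               # ['ankle boot', 't-shirt', 'dress']
print(get_numeric_labels(['bag', 'sneaker']))  # [8, 7]
```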

  Now define a function that draws several images and their labels in a single row:

...
    def show_mnist(self, images, labels):
        display.set_matplotlib_formats('svg')
        _, figs = plt.subplots(1, len(images), figsize=(12, 12))
        # zip() takes several iterables and packs their corresponding elements into tuples; in Python 3 it returns an iterator over these tuples
        for f, img, lbl in zip(figs, images, labels):
            f.imshow(img.view((28, 28)).numpy())
            f.set_title(lbl)
            f.axis('off')
        plt.show()
        
if __name__ == "__main__":
    test = ImageMnist()
    x, y = [], []
    for i in range(10):
        x.append(test.mnist_train[i][0])
        y.append(test.mnist_train[i][1])
    test.show_mnist(x, test.get_text_labels(y))

  Output:
[Figure: the first 10 training images with their text labels]
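  show_mnist has one caveat. When only one image is passed, plt.subplots(1, 1) returns a single Axes object instead of an array, so zip() fails. The variant sketched below avoids this by passing squeeze=False, which always returns a 2-D array of axes. It also scales the figure width with the number of images and returns the figure instead of calling plt.show(). It assumes a headless Agg backend and NumPy-convertible images. The name show_images and the grayscale colormap are illustrative choices, not part of the original code.

```python
import matplotlib
matplotlib.use("Agg")           # assumption: headless backend, no window is opened
import matplotlib.pyplot as plt
import numpy as np


def show_images(images, labels, scale=1.5):
    """Hypothetical variant of show_mnist that also works for a single image."""
    n = len(images)
    # squeeze=False: axes is always a 2-D array of shape (1, n), even when n == 1
    fig, axes = plt.subplots(1, n, figsize=(scale * n, scale), squeeze=False)
    for ax, img, lbl in zip(axes[0], images, labels):
        # reshape accepts (1, 28, 28) tensors/arrays as well as flat (784,) vectors
        ax.imshow(np.asarray(img).reshape(28, 28), cmap="gray")
        ax.set_title(lbl)
        ax.axis("off")
    return fig  # caller decides whether to plt.show() or fig.savefig()


fig = show_images([np.zeros((1, 28, 28))], ["bag"])   # single image: no TypeError
print(len(fig.axes))    # 1
```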

3 Reading Minibatches

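  A convenient feature of PyTorch's DataLoader is that it can use multiple processes to speed up data loading. Here the parameter num_workers sets 4 worker processes; on Windows the code falls back to 0.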

...
    def data_iter(self, batch_size=256):
        if sys.platform.startswith('win'):
            num_workers = 0    # 0 means no extra processes: data is loaded in the main process
        else:
            num_workers = 4
        train_iter = torch.utils.data.DataLoader(self.mnist_train, 
            batch_size=batch_size, shuffle=True, num_workers=num_workers)
        test_iter = torch.utils.data.DataLoader(self.mnist_test, 
            batch_size=batch_size, shuffle=False, num_workers=num_workers)
        return train_iter, test_iter
        
if __name__ == "__main__":
    start = time.time()
    test = ImageMnist()
    train_iter, test_iter = test.data_iter()
    for x, y in train_iter:
        continue
    print("%.2f sec" % (time.time() - start))

  Output:

6.65 sec
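  Conceptually, shuffle=True with batch_size=256 works in two steps. First the sample indices are permuted once per epoch. Then they are cut into consecutive chunks, and the final partial chunk is kept (DataLoader's default is drop_last=False). The pure-Python generator below is a minimal sketch of that index logic only. It is not how DataLoader is implemented internally, and it ignores worker processes.

```python
import random
from typing import Iterator, List, Optional


def minibatch_indices(n: int, batch_size: int, shuffle: bool = True,
                      seed: Optional[int] = None) -> Iterator[List[int]]:
    """Yield lists of sample indices, one list per minibatch (sketch only)."""
    indices = list(range(n))
    if shuffle:
        random.Random(seed).shuffle(indices)   # new random order for this epoch
    for start in range(0, n, batch_size):
        # the final slice may be shorter: the partial batch is kept, not dropped
        yield indices[start:start + batch_size]


train_batches = list(minibatch_indices(60000, 256, shuffle=True, seed=0))
test_batches = list(minibatch_indices(10000, 256, shuffle=False))
print(len(train_batches), len(train_batches[-1]))   # 235 96
print(len(test_batches), len(test_batches[-1]))     # 40 16
```

For the 60,000 training images this gives 234 full minibatches plus a final one of 96 samples, since 60000 = 234 × 256 + 96.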

4 Complete Code

'''
@(#)test.py
The class of test.
Author: Yu-Xuan Zhang
Email: inki.yinji@qq.com
Created on May 05, 2020
Last Modified on May 05, 2020
@author: inki
'''
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
from IPython import display

class ImageMnist():
    
    def __init__(self):
        self.mnist_train = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',
            train=True, download=True, transform=transforms.ToTensor())
        self.mnist_test = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',
            train=False, download=True, transform=transforms.ToTensor())
        
    def get_text_labels(self, labels):
        text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
        return [text_labels[int(i)] for i in labels]
    
    def show_mnist(self, images, labels):
        display.set_matplotlib_formats('svg')
        _, figs = plt.subplots(1, len(images), figsize=(12, 12))
        for f, img, lbl in zip(figs, images, labels):
            f.imshow(img.view((28, 28)).numpy())
            f.set_title(lbl)
            f.axis('off')
        plt.show()
        
    def data_iter(self, batch_size=256):
        if sys.platform.startswith('win'):
            num_workers = 0
        else:
            num_workers = 4
        train_iter = torch.utils.data.DataLoader(self.mnist_train, 
            batch_size=batch_size, shuffle=True, num_workers=num_workers)
        test_iter = torch.utils.data.DataLoader(self.mnist_test, 
            batch_size=batch_size, shuffle=False, num_workers=num_workers)
        return train_iter, test_iter
        
if __name__ == "__main__":
    start = time.time()
    test = ImageMnist()
    train_iter, test_iter = test.data_iter()
    for x, y in train_iter:
        continue
    print("%.2f sec" % (time.time() - start))

Acknowledgements

  Special thanks to Mu Li, Aston Zhang and the other authors of the book Dive into Deep Learning.

This article is shared from CSDN - 因吉.