動手深度學習6-認識Fashion_MNIST圖像數據集

本節將使用torchvision包,它是服務於pytorch深度學習框架的,主要用來構建計算機視覺模型。
torchvision主要由如下幾個部分構成:python

  1. torchvision.datasets:一些加載數據的函數以及經常使用的數據集的接口
  2. torchvision.models: 包含經常使用的模型結構(含預訓練模型),例如AlexNet,VGG,ResNet;
  3. torchvision.transforms:經常使用的圖片變換,例如裁剪,旋轉等;
  4. torchvision.utils: 其餘的一些有用的方法
獲取數據集

導入本節須要的包或者模塊算法

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append('..')  # 爲了導入上層目錄的d2lzh_pytorch
import d2lzh_pytorch as d2l

經過調用torchvision的torchvision.datasets來下載這個數據集
能夠經過train參數獲取指定的訓練集或者測試集、
測試集只用了評估模型,並不用來訓練模型數組

同時指定了參數transform = transform.ToTensor()使全部數據轉化爲Tensor,若是不進行轉化,則返回的是PIL照片。
transform.ToTensor()將尺寸爲(H,W,C)且數據位於[0,255]的PIL圖片或者數據類型爲np.unit8的Numpy數組轉化爲(CxHxW)且數據類型爲torch.float32且位於[0.0,1.0]的Tensor。app

  • 若是用像素值(0,255)表示圖片數據,一概將其類型設置爲unit8,避免出問題
mnist_train= torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',download=True,train=True,transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',download=True,train=False,transform=transforms.ToTensor())
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))
<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000
feature,label = mnist_train[0]
print(feature.shape,label)  # channel * height* width
torch.Size([1, 28, 28]) tensor(9)

feature對應的高和寬均爲28像素的圖像,因爲咱們使用了transforms.ToTensor(),因此每一個像素的數值爲[0,1]的32位浮點數。須要注意的是,feature的尺寸是(CxHxW)的,而不是(HxWxC)。第一維是通道數,由於數據集中是灰度圖像,因此通道數爲1,後面兩維分別是圖像的高和寬。框架

Fashion_MNIST中一共包括了10個類別,分別是t-shirt(T恤),trouser(褲子),pullover(套衫),dress(連衣裙),coat(外套),sandal(涼鞋),shirt(襯衫),sneaker(運動鞋),bag(包)和ankle boot(短靴)svg

import d2lzh_pytorch as d2l
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt','trouser','pullover','dress','coat','sandal',
                  'shirt','sneaker','bag','ankle boost'
                  ]
    return [text_labels[int(i)] for i in labels]



def show_fashion_mnist(images,labels):
    d2l.use_svg_display()
    _,figs = plt.subplots(1,len(images),figsize=(12,12))  # 1行10列
    for f ,img,lbl in zip(figs,images,labels):
        f.imshow(img.view((28,28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()
X,y = [],[]
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_test[i][1])
show_fashion_mnist(X,get_fashion_mnist_labels(y))

讀取小批量樣本

咱們將在訓練集上訓練模型,並將訓練好的模型預測測試集上評估模型的表現。
能夠用torch.utils.data.Dataloader來建立一個讀取小批量樣本的DataLoader實例。函數

在實際中,數據讀取常常是訓練的性能瓶頸,特別是當模型較爲簡單或者計算硬件性能較高時,pytorch的DataLoader中一個很方便的功能是容許使用多進程來加速數據讀取。這裏咱們經過參數num_workers來設置進程數來加速讀取數據性能

batch_size= 256

if sys.platform.startswith('win'):
    num_worker=0   # 表示不用額外的進程來加速讀取數據
    
else:
    num_worker=4
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle=True,num_workers=num_worker)
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size,shuffle=False,num_workers=num_worker)
start = time.time()
for X,y in train_iter:
    continue
print('%.2f sec' % (time.time()-start))
1.28 sec
小結
  • Fashion_MNIST 是一個10類服飾的分類數據集,以後章節後使用它來驗證不一樣算法的表現
  • 咱們將高和寬分別是H和W像素的圖像的形狀記爲HxW或(h,w)
相關文章
相關標籤/搜索