深度學習經常使用數據集 API(包括 Fashion MNIST)

基準數據集

深度學習中常常會使用一些基準數據集進行一些測試。其中 MNIST, Cifar 10, cifar100, Fashion-MNIST 數據集經常被人們拿來看成練手的數據集。爲了方便,諸如 KerasMXNetTensorflow 都封裝了本身的基礎數據集,如 MNISTcifar 等。若是咱們要在不一樣平臺使用這些數據集,還須要瞭解那些框架是如何組織這些數據集的,須要花費一些沒必要要的時間學習它們的 API。爲此,咱們爲什麼不建立屬於本身的數據集呢?下面我僅僅使用了 Numpy 來實現數據集 MNISTFashion MNISTCifa 10Cifar 100 的操做,並封裝爲 HDF5,這樣該數據集的可擴展性就會大大的加強,而且還能夠被其餘的編程語言 (如 Matlab) 來獲取和使用。下面主要介紹如何經過建立的 API 來實現數據集的封裝。html

環境搭建

我使用了 Anaconda3 這個十分好用的包管理工具, 來減小管理和安裝一些必須的包。下面咱們載入該 API 必備的包:node

import struct   # 處理二進制文件
import numpy as np   # 對矩陣運算很友好
import gzip, tarfile  # 對壓縮文件進行處理
import os          # 管理本地文件
import pickle      # 序列化和反序列化
import time       # 記時

我是在 Jupyter Notebook 交互環境中運行代碼的。python

Bunch 結構

爲了更好的使用該 API, 我利用了 Bunch 結構。在 Python 中,咱們能夠定義 Bunch Pattern, 字面意思大概是指鏈式的束式結構。主要用於存儲鬆散的數據結構。git

它能讓咱們以命令行參數的形式建立相關對象,並設置任何屬性。下面咱們來看看 Bunch 的魅力!Bunch 的定義利用了 dict 的特性。github

class Bunch(dict):
    
    def __init__(self, *args, **kwds):
        super().__init__(*args, **kwds)
        self.__dict__ = self

下面咱們構建一個 Bunch 的實例 Tom, 它表明一個住在北京的 54 歲的人。算法

Tom = Bunch(age="54", address="Beijing")

咱們能夠查看 Tom 的一些信息:數據庫

print('Tom 的年齡是 {},他住在 {}.'.format(Tom.age, Tom.address))
Tom 的年齡是 54,他住在 Beijing.

咱們還能夠直接對 Tom 增長屬性,好比:編程

Tom.sex = 'male'
print(Tom)
{'age': '54', 'address': 'Beijing', 'sex': 'male'}

你也許會奇怪,Bunch 結構與 dict 結構好像沒有太大的的區別,只不過是多了一個點號運算,那麼,Bunch 到底有什麼神奇之處呢?咱們先看一個例子:json

T = Bunch
t = T(left=T(left='a',right='b'), right=T(left='c'))

for first in t:
    print('第一層的節點:', first)
    for second in t[first]:
        print('\t第二層的節點:', second)
        for node in t[first][second]:
            print('\t\t第三層的節點:', node)
第一層的節點: left
    第二層的節點: left
        第三層的節點: a
    第二層的節點: right
        第三層的節點: b
第一層的節點: right
    第二層的節點: left
        第三層的節點: c

從上面的輸出咱們能夠看出,t 即是一個簡單的二叉樹結構。這樣,咱們即可使用 Bunch 構建許多具備分層結構的數據類型。api

下載數據集

連接:

咱們將上述數據集均下載到同一個目錄下,好比:'E:/Data/Zip/',下面咱們將逐一介紹上述數據集。

MNIST & Fashion MNIST

MNIST 數據集能夠說是深度學習中的 hello world 級別的數據集,不少教程都是把它做爲入門級的數據集。不過有些人可能對它還不是很瞭解, 下面咱們簡單的瞭解一下!

MNIST 數據集來自美國國家標準與技術研究所(National Institute of Standards and Technology, NIST). 訓練集 (training set) 由來自 250 個不一樣人手寫的數字構成, 其中 \(50\%\) 是高中學生, \(50\%\) 來自人口普查局 (the Census Bureau) 的工做人員. 測試集(test set) 也是一樣比例的手寫數字數據.

MNIST 有一組 \(60\, 000\) 個樣本的訓練集和一組 \(10\, 000\) 個樣本的測試集。它是 NIST 的子集。數字圖像已被大小規範化, 並以固定大小的圖像居中。

MNIST 數據集可在 http://yann.lecun.com/exdb/mnist/ 獲取, 它包含了四個部分:

圖像分類數據集中最經常使用的是手寫數字識別數據集 MNIST1。但大部分模型在 MNIST 上的分類精度都超過了 \(95\%\)。爲了更直觀地觀察算法之間的差別,咱們可使用一個圖像內容更加複雜的數據集 Fashion-MNIST2。Fashion-MNIST 和 MNIST 同樣,也包括了 \(10\) 個類別,分別爲:t-shirt(T 恤)、trouser(褲子)、pullover(套衫)、dress(連衣裙)、coat(外套)、sandal(涼鞋)、shirt(襯衫)、sneaker(運動鞋)、bag(包)和 ankle boot(短靴)。

Fashion-MNIST 的存儲方式和 MNIST 是同樣的,故而,咱們可使用相同的方式對其進行處理。

MNIST 的使用

下面我以 MNIST 類來處理 MNIST 和 Fashion MNIST:

class MNIST:
    def __init__(self, root, namespace, train=True, transform=None):
        """
        (MNIST handwritten digits dataset from http://yann.lecun.com/exdb/mnist)
        (A dataset of Zalando's article images consisting of fashion products,
        a drop-in replacement of the original MNIST dataset
        from https://github.com/zalandoresearch/fashion-mnist)

        Each sample is an image (in 3D NDArray) with shape (28, 28, 1).

        Parameters
        ----------
        root : 數據根目錄,如 'E:/Data/Zip/'
        namespace : 'mnist' or 'fashion_mnist'
        train : bool, default True
            Whether to load the training or testing set.
        transform : function, default None
            A user defined callback that transforms each sample. For example:
        ::

            transform=lambda data, label: (data.astype(np.float32)/255, label)
        """
        self._train = train
        self.namespace = namespace
        root = root + namespace
        self._train_data = f'{root}/train-images-idx3-ubyte.gz'
        self._train_label = f'{root}/train-labels-idx1-ubyte.gz'
        self._test_data = f'{root}/t10k-images-idx3-ubyte.gz'
        self._test_label = f'{root}/t10k-labels-idx1-ubyte.gz'
        self._get_data()

    def _get_data(self):
        '''
        官方網站的數據是以 `[offset][type][value][description]` 的格式封裝的,
        於是 `struct.unpack` 時須要注意
        '''
        if self._train:
            data, label = self._train_data, self._train_label
        else:
            data, label = self._test_data, self._test_label

        with gzip.open(label, 'rb') as fin:
            struct.unpack(">II", fin.read(8))
            self.label = np.frombuffer(fin.read(), dtype=np.uint8)

        with gzip.open(data, 'rb') as fin:
            Y = struct.unpack(">IIII", fin.read(16))
            data = np.frombuffer(fin.read(), dtype=np.uint8)
            self.data = data.reshape(Y[1:])

下面,咱們來看看如何載入這兩個數據集?

MNIST

考慮到代碼的可複用性,我將上述代碼封裝在個人 GitHub3
。將其下載到本地,你即可以直接使用。下面我將展現如何使用該 API。

首先,須要找到你下載的 API 目錄,好比:D:\GitHub\basedataset\loader,而後載入到你當前的 Python 環境變量中。

import sys

sys.path.append('D:/GitHub/basedataset/loader/')

from zdata import MNIST

下面你即可以自如的調用 MNIST 類了。

root = 'E:/Data/Zip/'
namespace = 'mnist'
train_mnist = MNIST(root, namespace, train=True, transform=None)  # 獲取訓練集
test_mnist = MNIST(root, namespace, train=False, transform=None)  # 獲取測試集

print('MNIST 的訓練集規模:{}'.format((train_mnist.data.shape)))
print('MNIST 的測試集規模:{}'.format((test_mnist.data.shape)))
MNIST 的訓練集規模:(60000, 28, 28)
MNIST 的測試集規模:(10000, 28, 28)

下面咱們以 MNIST 的測試集爲例,來看看 MNIST 具體長什麼樣吧!

from matplotlib import pyplot as plt


def show_imgs(imgs):
    '''
    展現 多張圖片
    '''
    n = imgs.shape[0]
    h, w = 4, int(n / 4)
    _, figs = plt.subplots(h, w, figsize=(5, 5))
    K = np.arange(n).reshape((h, w))
    for i in range(h):
        for j in range(w):
            img = imgs[K[i, j]]
            figs[i][j].imshow(img)
            figs[i][j].axes.get_xaxis().set_visible(False)
            figs[i][j].axes.get_yaxis().set_visible(False)
    plt.show()
imgs = test_mnist.data[:16]
show_imgs(imgs)

Fashion MNIST
namespace = 'fashion_mnist'
train_mnist_f = MNIST(root, namespace, train=True, transform=None)
test_mnist_f = MNIST(root, namespace, train=False, transform=None)

print('Fashion MNIST 的訓練集規模:{}'.format((train_mnist_f.data.shape)))
print('Fashion MNIST 的測試集規模:{}'.format((test_mnist_f.data.shape)))
Fashion MNIST 的訓練集規模:(60000, 28, 28)
Fashion MNIST 的測試集規模:(10000, 28, 28)

再看看 Fashion MNIST 具體長什麼樣吧!

imgs_f = test_mnist_f.data[:16]
show_imgs(imgs_f)

MNIST 和 Fashion MNIST 數據集仍是太簡單了,爲了知足更多的需求,下面咱們將進入 Cifar 數據集的 API 開發和使用環節。

Cifar API

class Bunch(dict):
    def __init__(self, *args, **kwds):
        super().__init__(*args, **kwds)
        self.__dict__ = self


class Cifar(Bunch):
    def __init__(self, root, namespace, transform=None, *args, **kwds):
        """CIFAR image classification dataset
         from https://www.cs.toronto.edu/~kriz/cifar.html

        Each sample is an image (in 3D NDArray) with shape (32, 32, 3).

        Parameters
        ----------
        meta : 保存了類別信息
        root : str, 數據根目錄
        namespace : 'cifar-10' 或 'cifar-100'
        transform : function, default None
            A user defined callback that transforms each sample. For example:
        ::

            transform=lambda data, label: (data.astype(np.float32)/255, label)

        """
        super().__init__(*args, **kwds)
        self.url = 'https://www.cs.toronto.edu/~kriz/cifar.html'
        self.namespace = namespace
        self._extract(root)
        self._read_batch()

    def _extract(self, root):
        tar_name = f'{root}{self.namespace}-python.tar.gz'
        names = extractall(tar_name, root)
        # print('載入數據的字典信息:')
        #start = time.time()
        for name in names:
            path = f'{root}{name}'

            if os.path.isfile(path):
                if not (path.endswith('.html') or path.endswith('.txt~')):
                    k = name.split('/')[-1]
                    if path.endswith('meta'):
                        with open(path, 'rb') as fp:
                            self['meta'] = pickle.load(fp)
                    else:
                        with open(path, 'rb') as fp:
                            self[k] = pickle.load(fp, encoding='bytes')

        #     #time.sleep(0.2)
        #     t = int(time.time() - start) * '-'
        #     print(t, end='')
        # print('\n載入數據的字典信息完畢!')

    def _read_batch(self):
        if self.namespace == 'cifar-10':
            self.trainX = np.concatenate([
                self[f'data_batch_{str(i)}'][b'data'] for i in range(1, 6)
            ]).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
            self.trainY = np.concatenate([
                np.asanyarray(self[f'data_batch_{str(i)}'][b'labels'])
                for i in range(1, 6)
            ])
            self.testX = self.test_batch[b'data'].reshape(
                -1, 3, 32, 32).transpose((0, 2, 3, 1))
            self.testY = np.asanyarray(self.test_batch[b'labels'])
        elif self.namespace == 'cifar-100':
            self.trainX = self.train[b'data'].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
            self.train_fine_labels = np.asanyarray(
                self.train[b'fine_labels'])  # 子類標籤
            self.train_coarse_labels = np.asanyarray(
                self.train[b'coarse_labels'])  # 超類標籤
            self.testX = self.test[b'data'].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
            self.test_fine_labels = np.asanyarray(
                self.test[b'fine_labels'])  # 子類標籤
            self.test_coarse_labels = np.asanyarray(
                self.test[b'coarse_labels'])  # 超類標籤

爲了方便管理和調用數據集,我定義了一個 DataBunch 類:

class DataBunch(Bunch):
    '''
    將數據集轉換爲 Bunch
    '''
    def __init__(self, root, *args, **kwds):
        super().__init__(*args, **kwds)
        B = Bunch
        self.mnist = B(MNIST(root, 'mnist'))
        self.fashion_mnist = B(MNIST(root, 'fashion_mnist'))
        self.cifar10 = B(Cifar(root, 'cifar-10'))
        self.cifar100 = B(Cifar(root, 'cifar-100'))

一樣將上述代碼放入 zdata 模塊中。

Cifar 10 數據集

下面咱們即可以直接利用 DataBunch 類來調用上述介紹的數據集:

import sys

sys.path.append('D:/GitHub/basedataset/loader/')

from zdata import DataBunch, show_imgs
root = 'E:/Data/Zip/'
db = DataBunch(root)

咱們能夠查看,咱們封裝的數據集:

db.keys()
dict_keys(['mnist', 'fashion_mnist', 'cifar10', 'cifar100'])

因爲前面已經展現過 'mnist', 'fashion_mnist',下面咱們將展現 Cifar API 的使用。更多詳細內容參考個人博文 關於 『AI 專屬數據庫的定製』的改進4

cifar-10 和 CIFAR-10 標記爲 \(8000\) 萬個 微小圖像數據集5的子集。它們是由 Alex Krizhevsky, Vinod Nair, 和 Geoffrey Hinton 收集的。

cifar-10 數據集由 \(10\)\(32\times 32\) 彩色圖像組成, 每類有 \(6\,000\) 張圖像。被劃分爲 \(50\,000\) 張訓練圖像和 \(10\,000\) 張測試圖像。

cifar10 = db.cifar10

imgs = cifar10.trainX[:16]
show_imgs(imgs)

爲了方便數據的使用,咱們能夠將 db 寫入到本地磁盤:

序列化

import pickle

def write_bunch(path):
    '''
    path:: 寫入數據集的文件路徑
    '''
    with open(path, 'wb') as fp:
        pickle.dump(db, fp)
root = 'E:/Data/Zip/'
path = f'{root}X.json'   # 寫入數據集的文件路徑
write_bunch(path)

這樣之後咱們就能夠直接複製 f'{root}X.datf'{root}X.json' 到你能夠放置的任何地方,而後你就能夠經過 load 函數來調用 MNISTFashion MNISTCifa 10Cifar 100 這些數據集。即:

反序列化

def read_bunch(path):
    with open(path, 'rb') as fp:
        bunch = pickle.load(fp)  # 即爲上面的 DataBunch 的實例
    return bunch
db = read_bunch(path)   # path 即你的數據集所在的路徑

考慮到 JSON 對於其餘編程語言的不友好,下面咱們將介紹如何將 Bunch 數據集存儲爲 HDF5 格式的數據。

Bunch 轉換爲 HDF5 文件:高效存儲 Cifar 等數據集

PyTables6Python 與 HDF5 數據庫/文件標準的結合7。它專門爲優化 I/O 操做的性能、最大限度地利用可用硬件而設計,而且它還支持壓縮功能。

下面的代碼均是在 Jupyter NoteBook 下完成的:

import tables as tb
import numpy as np


def bunch2hdf5(root):
    '''
    這裏我僅僅封裝了 Cifar十、Cifar100、MNIST、Fashion MNIST 數據集,
    使用者還能夠本身追加數據集。
    '''
    db = DataBunch(root)
    filters = tb.Filters(complevel=7, shuffle=False)
    # 這裏我採用了壓縮表,於是保存爲 `.h5c` 但也能夠保存爲 `.h5`
    with tb.open_file(f'{root}X.h5c', 'w', filters=filters, title='Xinet\'s dataset') as h5:
        for name in db.keys():
            h5.create_group('/', name, title=f'{db[name].url}')
            if name != 'cifar100':
                h5.create_array(h5.root[name], 'trainX', db[name].trainX, title='訓練數據')
                h5.create_array(h5.root[name], 'trainY', db[name].trainY, title='訓練標籤')
                h5.create_array(h5.root[name], 'testX', db[name].testX, title='測試數據')
                h5.create_array(h5.root[name], 'testY', db[name].testY, title='測試標籤')
            else:
                h5.create_array(h5.root[name], 'trainX', db[name].trainX, title='訓練數據')
                h5.create_array(h5.root[name], 'testX', db[name].testX, title='測試數據')
                h5.create_array(h5.root[name], 'train_coarse_labels', db[name].train_coarse_labels, title='超類訓練標籤')
                h5.create_array(h5.root[name], 'test_coarse_labels', db[name].test_coarse_labels, title='超類測試標籤')
                h5.create_array(h5.root[name], 'train_fine_labels', db[name].train_fine_labels, title='子類訓練標籤')
                h5.create_array(h5.root[name], 'test_fine_labels', db[name].test_fine_labels, title='子類測試標籤')

        for k in ['cifar10', 'cifar100']:
            for name in db[k].meta.keys():
                name = name.decode()
                if name.endswith('names'):
                    label_names = np.asanyarray([label_name.decode() for label_name in db[k].meta[name.encode()]])
                    h5.create_array(h5.root[k], name, label_names, title='標籤名稱')

完成 BunchHDF5 的轉換

root = 'E:/Data/Zip/'
bunch2hdf5(root)
h5c = tb.open_file('E:/Data/Zip/X.h5c')
h5c
File(filename=E:/Data/Zip/X.h5c, title="Xinet's dataset", mode='r', root_uep='/', filters=Filters(complevel=7, complib='zlib', shuffle=False, bitshuffle=False, fletcher32=False, least_significant_digit=None))
/ (RootGroup) "Xinet's dataset"
/cifar10 (Group) 'https://www.cs.toronto.edu/~kriz/cifar.html'
/cifar10/label_names (Array(10,)) '標籤名稱'
  atom := StringAtom(itemsize=10, shape=(), dflt=b'')
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar10/testX (Array(10000, 32, 32, 3)) '測試數據'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar10/testY (Array(10000,)) '測試標籤'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/cifar10/trainX (Array(50000, 32, 32, 3)) '訓練數據'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar10/trainY (Array(50000,)) '訓練標籤'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/cifar100 (Group) 'https://www.cs.toronto.edu/~kriz/cifar.html'
/cifar100/coarse_label_names (Array(20,)) '標籤名稱'
  atom := StringAtom(itemsize=30, shape=(), dflt=b'')
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar100/fine_label_names (Array(100,)) '標籤名稱'
  atom := StringAtom(itemsize=13, shape=(), dflt=b'')
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar100/testX (Array(10000, 32, 32, 3)) '測試數據'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar100/test_coarse_labels (Array(10000,)) '超類測試標籤'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/cifar100/test_fine_labels (Array(10000,)) '子類測試標籤'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/cifar100/trainX (Array(50000, 32, 32, 3)) '訓練數據'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar100/train_coarse_labels (Array(50000,)) '超類訓練標籤'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/cifar100/train_fine_labels (Array(50000,)) '子類訓練標籤'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/fashion_mnist (Group) 'https://github.com/zalandoresearch/fashion-mnist'
/fashion_mnist/testX (Array(10000, 28, 28, 1)) '測試數據'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/fashion_mnist/testY (Array(10000,)) '測試標籤'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/fashion_mnist/trainX (Array(60000, 28, 28, 1)) '訓練數據'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/fashion_mnist/trainY (Array(60000,)) '訓練標籤'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/mnist (Group) 'http://yann.lecun.com/exdb/mnist'
/mnist/testX (Array(10000, 28, 28, 1)) '測試數據'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/mnist/testY (Array(10000,)) '測試標籤'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/mnist/trainX (Array(60000, 28, 28, 1)) '訓練數據'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/mnist/trainY (Array(60000,)) '訓練標籤'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None

從上面的結構可看出我將 Cifar10Cifar100MNISTFashion MNIST 進行了封裝,而且還附帶了它們各類的數據集信息。好比標籤名,數字特徵(以數組的形式進行封裝)等。

%%time
arr = h5c.root.cifar100.trainX.read() # 讀取數據十分快速
Wall time: 125 ms
arr.shape
(50000, 32, 32, 3)
h5c.root
/ (RootGroup) "Xinet's dataset"
  children := ['cifar10' (Group), 'cifar100' (Group), 'fashion_mnist' (Group), 'mnist' (Group)]

X.h5c 使用說明

下面咱們以 Cifar100 爲例來展現咱們自創的數據集 X.h5c(我將其上傳到了百度雲盤「連接:https://pan.baidu.com/s/12jzaJ2d2kvHCXbQa_HO6YQ 提取碼:2clg」能夠下載直接使用;亦可你本身生成,不過我推薦本身生成,能夠對數據集加深理解)

cifar100 = h5c.root.cifar100
cifar100
/cifar100 (Group) 'https://www.cs.toronto.edu/~kriz/cifar.html'
  children := ['coarse_label_names' (Array), 'fine_label_names' (Array), 'testX' (Array), 'test_coarse_labels' (Array), 'test_fine_labels' (Array), 'trainX' (Array), 'train_coarse_labels' (Array), 'train_fine_labels' (Array)]

'coarse_label_names' 指的是粗粒度或超類標籤名,'fine_label_names' 則是細粒度標籤名。

可使用 read() 方法直接獲取信息,也可使用索引的方式獲取。

coarse_label_names = cifar100.coarse_label_names[:]
# 或者
coarse_label_names = cifar100.coarse_label_names.read()
coarse_label_names.astype('str')
array(['aquatic_mammals', 'fish', 'flowers', 'food_containers',
       'fruit_and_vegetables', 'household_electrical_devices',
       'household_furniture', 'insects', 'large_carnivores',
       'large_man-made_outdoor_things', 'large_natural_outdoor_scenes',
       'large_omnivores_and_herbivores', 'medium_mammals',
       'non-insect_invertebrates', 'people', 'reptiles', 'small_mammals',
       'trees', 'vehicles_1', 'vehicles_2'], dtype='<U30')
fine_label_names = cifar100.fine_label_names[:].astype('str')
fine_label_names
array(['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee',
       'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus',
       'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
       'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch',
       'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant',
       'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house',
       'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
       'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
       'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter',
       'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate',
       'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road',
       'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk',
       'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar',
       'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone',
       'television', 'tiger', 'tractor', 'train', 'trout', 'tulip',
       'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
       'worm'], dtype='<U13')

'testX''trainX' 分別表明數據的測試數據和訓練數據,而其餘的節點所表明的含義也是相似的。

例如,咱們能夠看看訓練集的數據和標籤:

trainX = cifar100.trainX
train_coarse_labels = cifar100.train_coarse_labels
 
array([11, 15,  4, ...,  8,  7,  1])

shape(50000, 32, 32, 3),數據的獲取,咱們同樣能夠採用索引的形式或者使用 read()

train_data = trainX[:]
print(train_data[0].shape)
print(train_data.dtype)
(32, 32, 3)
uint8

固然,咱們也能夠直接使用 trainX 作運算。

for x in cifar100.trainX:
    y = x * 2
    break

print(y.shape)
(32, 32, 3)
h5c.get_node(h5c.root.cifar100, 'trainX')
/cifar100/trainX (Array(50000, 32, 32, 3)) '訓練數據'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None

更甚者,咱們能夠直接定義迭代器來獲取數據:

trainX = cifar100.trainX
train_coarse_labels = cifar100.train_coarse_labels
def data_iter(X, Y, batch_size):
    n = X.nrows
    idx = np.arange(n)
    if X.name.startswith('train'):
        np.random.shuffle(idx)
    for i in range(0, n ,batch_size):
        k = idx[i: min(n, i + batch_size)].tolist()
        yield np.take(X, k, 0), np.take(Y, k, 0)
for x, y in data_iter(trainX, train_coarse_labels, 8):
    print(x.shape, y)
    break
(8, 32, 32, 3) [ 7  7  0 15  4  8  8  3]

更多使用詳情見:使用 迭代器 獲取 Cifar 等經常使用數據集8

爲了更加形象的說明該數據集,咱們將其可視化:

from pylab import plt, mpl


mpl.rcParams['font.sans-serif'] = ['SimHei']  # 指定默認字體
mpl.rcParams['axes.unicode_minus'] = False  # 解決保存圖像是負號 '-' 顯示爲方塊的問題


def show_imgs(imgs, labels):
    '''
    展現 多張圖片
    '''
    imgs = np.transpose(imgs, (0, 2, 3, 1))
    n = imgs.shape[0]
    h, w = 5, int(n / 5)
    fig, ax = plt.subplots(h, w, figsize=(7, 7))
    K = np.arange(n).reshape((h, w))
    names = np.asanyarray([cifar.fine_label_names[label] for label in labels], dtype='U')
    names = names.reshape((h, w))
    for i in range(h):
        for j in range(w):
            img = imgs[K[i, j]]
            ax[i][j].imshow(img)
            ax[i][j].axes.get_yaxis().set_visible(False)
            ax[i][j].axes.set_xlabel(names[i][j])
            ax[i][j].set_xticks([])
    plt.show()

爲了高效使用數據集 X.h5,咱們使用迭代器的方式來獲取它:

class Loader:
    """
    方法
    ========
    L 爲該類的實例
    len(L)::返回 batch 的批數
    iter(L)::即爲數據迭代器

    Return
    ========
    可迭代對象(numpy 對象)
    """

    def __init__(self, X, Y, batch_size, shuffle):
        '''
        X, Y 均爲類 numpy 
        '''
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        n = len(self.X)
        idx = np.arange(n)

        if self.shuffle:
            np.random.shuffle(idx)

        for k in range(0, n, self.batch_size):
            K = idx[k:min(k + self.batch_size, n)].tolist()
            yield np.take(self.X, K, 0), np.take(self.Y, K, 0)

    def __len__(self):
        return round(len(self.X) / self.batch_size)
import tables as tb
import numpy as np

batch_size = 512
xpath = 'E:/xdata/X.h5'  # 文件所在路徑
h5 = tb.open_file(xpath)
cifar = h5.root.cifar100
train_cifar = Loader(cifar.trainX, cifar.train_fine_labels, batch_size, True)

for imgs, labels in iter(train_cifar):
    break

show_imgs(imgs[:25], labels[:25])

上面的大部分代碼被我放在了 Github:https://github.com/DataLoaderX/datasetsome/blob/master/dataloader/tabx.py

總結

上面的 API 設計過程當中,我發現到了許多自身的不足,不斷改進 API 的過程當中,我得到了學習和創造的喜悅。上面所介紹的 X.h5c 數據集不只僅是那些數據集的封裝,你還能夠繼續添加本身的數據集到該 數據庫中。同時,類 Loader 十分有用,它定義了一個標準,一個能夠延拓處處理其餘深度學習的數據集中去。

基於上述思想,我設計了以下 API:


  1. LeCun, Y., Cortes, C., & Burges, C. http://yann.lecun.com/exdb/mnist/

  2. Xiao, H., Rasul, K., & Vollgraf, R. (2017). Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747.

  3. https://github.com/DataLoaderX/datazone/tree/master/lab/utils/tools

  4. https://www.jianshu.com/p/29066e70ea5e

  5. http://people.csail.mit.edu/torralba/tinyimages/

  6. http://www.pytables.org/

  7. http://www.hdfgroup.org

  8. https://yq.aliyun.com/articles/614332?spm=a2c4e.11155435.0.0.30543312vFsboY

相關文章
相關標籤/搜索