深度學習中常常會使用一些基準數據集進行一些測試。其中 MNIST, Cifar 10, cifar100, Fashion-MNIST 數據集經常被人們拿來看成練手的數據集。爲了方便,諸如 Keras
、MXNet
、Tensorflow
都封裝了本身的基礎數據集,如 MNIST
、cifar
等。若是咱們要在不一樣平臺使用這些數據集,還須要瞭解那些框架是如何組織這些數據集的,須要花費一些沒必要要的時間學習它們的 API。爲此,咱們爲什麼不建立屬於本身的數據集呢?下面我僅僅使用了 Numpy
來實現數據集 MNIST
、Fashion MNIST
、Cifa 10
、Cifar 100
的操做,並封裝爲 HDF5,這樣該數據集的可擴展性就會大大的加強,而且還能夠被其餘的編程語言 (如 Matlab) 來獲取和使用。下面主要介紹如何經過建立的 API 來實現數據集的封裝。html
我使用了 Anaconda3
這個十分好用的包管理工具, 來減小管理和安裝一些必須的包。下面咱們載入該 API 必備的包:node
import struct # 處理二進制文件 import numpy as np # 對矩陣運算很友好 import gzip, tarfile # 對壓縮文件進行處理 import os # 管理本地文件 import pickle # 序列化和反序列化 import time # 記時
我是在 Jupyter Notebook 交互環境中運行代碼的。python
爲了更好的使用該 API, 我利用了 Bunch 結構。在 Python 中,咱們能夠定義 Bunch Pattern, 字面意思大概是指鏈式的束式結構。主要用於存儲鬆散的數據結構。git
它能讓咱們以命令行參數的形式建立相關對象,並設置任何屬性。下面咱們來看看 Bunch 的魅力!Bunch 的定義利用了 dict
的特性。github
class Bunch(dict): def __init__(self, *args, **kwds): super().__init__(*args, **kwds) self.__dict__ = self
下面咱們構建一個 Bunch 的實例 Tom
, 它表明一個住在北京的 54 歲的人。算法
Tom = Bunch(age="54", address="Beijing")
咱們能夠查看 Tom 的一些信息:數據庫
print('Tom 的年齡是 {},他住在 {}.'.format(Tom.age, Tom.address))
Tom 的年齡是 54,他住在 Beijing.
咱們還能夠直接對 Tom 增長屬性,好比:編程
Tom.sex = 'male' print(Tom)
{'age': '54', 'address': 'Beijing', 'sex': 'male'}
你也許會奇怪,Bunch 結構與 dict
結構好像沒有太大的的區別,只不過是多了一個點號運算,那麼,Bunch 到底有什麼神奇之處呢?咱們先看一個例子:json
T = Bunch t = T(left=T(left='a',right='b'), right=T(left='c')) for first in t: print('第一層的節點:', first) for second in t[first]: print('\t第二層的節點:', second) for node in t[first][second]: print('\t\t第三層的節點:', node)
第一層的節點: left 第二層的節點: left 第三層的節點: a 第二層的節點: right 第三層的節點: b 第一層的節點: right 第二層的節點: left 第三層的節點: c
從上面的輸出咱們能夠看出,t
即是一個簡單的二叉樹結構。這樣,咱們即可使用 Bunch 構建許多具備分層結構的數據類型。api
連接:
咱們將上述數據集均下載到同一個目錄下,好比:'E:/Data/Zip/'
,下面咱們將逐一介紹上述數據集。
MNIST 數據集能夠說是深度學習中的 hello world
級別的數據集,不少教程都是把它做爲入門級的數據集。不過有些人可能對它還不是很瞭解, 下面咱們簡單的瞭解一下!
MNIST 數據集來自美國國家標準與技術研究所(National Institute of Standards and Technology, NIST). 訓練集 (training set) 由來自 250 個不一樣人手寫的數字構成, 其中 \(50\%\) 是高中學生, \(50\%\) 來自人口普查局 (the Census Bureau) 的工做人員. 測試集(test set) 也是一樣比例的手寫數字數據.
MNIST 有一組 \(60\, 000\) 個樣本的訓練集和一組 \(10\, 000\) 個樣本的測試集。它是 NIST 的子集。數字圖像已被大小規範化, 並以固定大小的圖像居中。
MNIST 數據集可在 http://yann.lecun.com/exdb/mnist/ 獲取, 它包含了四個部分:
圖像分類數據集中最經常使用的是手寫數字識別數據集 MNIST1。但大部分模型在 MNIST 上的分類精度都超過了 \(95\%\)。爲了更直觀地觀察算法之間的差別,咱們可使用一個圖像內容更加複雜的數據集 Fashion-MNIST2。Fashion-MNIST 和 MNIST 同樣,也包括了 \(10\) 個類別,分別爲:t-shirt(T 恤)、trouser(褲子)、pullover(套衫)、dress(連衣裙)、coat(外套)、sandal(涼鞋)、shirt(襯衫)、sneaker(運動鞋)、bag(包)和 ankle boot(短靴)。
Fashion-MNIST 的存儲方式和 MNIST 是同樣的,故而,咱們可使用相同的方式對其進行處理。
下面我以 MNIST
類來處理 MNIST 和 Fashion MNIST:
class MNIST: def __init__(self, root, namespace, train=True, transform=None): """ (MNIST handwritten digits dataset from http://yann.lecun.com/exdb/mnist) (A dataset of Zalando's article images consisting of fashion products, a drop-in replacement of the original MNIST dataset from https://github.com/zalandoresearch/fashion-mnist) Each sample is an image (in 3D NDArray) with shape (28, 28, 1). Parameters ---------- root : 數據根目錄,如 'E:/Data/Zip/' namespace : 'mnist' or 'fashion_mnist' train : bool, default True Whether to load the training or testing set. transform : function, default None A user defined callback that transforms each sample. For example: :: transform=lambda data, label: (data.astype(np.float32)/255, label) """ self._train = train self.namespace = namespace root = root + namespace self._train_data = f'{root}/train-images-idx3-ubyte.gz' self._train_label = f'{root}/train-labels-idx1-ubyte.gz' self._test_data = f'{root}/t10k-images-idx3-ubyte.gz' self._test_label = f'{root}/t10k-labels-idx1-ubyte.gz' self._get_data() def _get_data(self): ''' 官方網站的數據是以 `[offset][type][value][description]` 的格式封裝的, 於是 `struct.unpack` 時須要注意 ''' if self._train: data, label = self._train_data, self._train_label else: data, label = self._test_data, self._test_label with gzip.open(label, 'rb') as fin: struct.unpack(">II", fin.read(8)) self.label = np.frombuffer(fin.read(), dtype=np.uint8) with gzip.open(data, 'rb') as fin: Y = struct.unpack(">IIII", fin.read(16)) data = np.frombuffer(fin.read(), dtype=np.uint8) self.data = data.reshape(Y[1:])
下面,咱們來看看如何載入這兩個數據集?
考慮到代碼的可複用性,我將上述代碼封裝在個人 GitHub3
。將其下載到本地,你即可以直接使用。下面我將展現如何使用該 API。
首先,須要找到你下載的 API 目錄,好比:D:\GitHub\basedataset\loader
,而後載入到你當前的 Python 環境變量中。
import sys sys.path.append('D:/GitHub/basedataset/loader/') from zdata import MNIST
下面你即可以自如的調用 MNIST 類了。
root = 'E:/Data/Zip/' namespace = 'mnist' train_mnist = MNIST(root, namespace, train=True, transform=None) # 獲取訓練集 test_mnist = MNIST(root, namespace, train=False, transform=None) # 獲取測試集 print('MNIST 的訓練集規模:{}'.format((train_mnist.data.shape))) print('MNIST 的測試集規模:{}'.format((test_mnist.data.shape)))
MNIST 的訓練集規模:(60000, 28, 28) MNIST 的測試集規模:(10000, 28, 28)
下面咱們以 MNIST 的測試集爲例,來看看 MNIST 具體長什麼樣吧!
from matplotlib import pyplot as plt def show_imgs(imgs): ''' 展現 多張圖片 ''' n = imgs.shape[0] h, w = 4, int(n / 4) _, figs = plt.subplots(h, w, figsize=(5, 5)) K = np.arange(n).reshape((h, w)) for i in range(h): for j in range(w): img = imgs[K[i, j]] figs[i][j].imshow(img) figs[i][j].axes.get_xaxis().set_visible(False) figs[i][j].axes.get_yaxis().set_visible(False) plt.show()
imgs = test_mnist.data[:16] show_imgs(imgs)
namespace = 'fashion_mnist' train_mnist_f = MNIST(root, namespace, train=True, transform=None) test_mnist_f = MNIST(root, namespace, train=False, transform=None) print('Fashion MNIST 的訓練集規模:{}'.format((train_mnist_f.data.shape))) print('Fashion MNIST 的測試集規模:{}'.format((test_mnist_f.data.shape)))
Fashion MNIST 的訓練集規模:(60000, 28, 28) Fashion MNIST 的測試集規模:(10000, 28, 28)
再看看 Fashion MNIST 具體長什麼樣吧!
imgs_f = test_mnist_f.data[:16] show_imgs(imgs_f)
MNIST 和 Fashion MNIST 數據集仍是太簡單了,爲了知足更多的需求,下面咱們將進入 Cifar 數據集的 API 開發和使用環節。
class Bunch(dict): def __init__(self, *args, **kwds): super().__init__(*args, **kwds) self.__dict__ = self class Cifar(Bunch): def __init__(self, root, namespace, transform=None, *args, **kwds): """CIFAR image classification dataset from https://www.cs.toronto.edu/~kriz/cifar.html Each sample is an image (in 3D NDArray) with shape (32, 32, 3). Parameters ---------- meta : 保存了類別信息 root : str, 數據根目錄 namespace : 'cifar-10' 或 'cifar-100' transform : function, default None A user defined callback that transforms each sample. For example: :: transform=lambda data, label: (data.astype(np.float32)/255, label) """ super().__init__(*args, **kwds) self.url = 'https://www.cs.toronto.edu/~kriz/cifar.html' self.namespace = namespace self._extract(root) self._read_batch() def _extract(self, root): tar_name = f'{root}{self.namespace}-python.tar.gz' names = extractall(tar_name, root) # print('載入數據的字典信息:') #start = time.time() for name in names: path = f'{root}{name}' if os.path.isfile(path): if not (path.endswith('.html') or path.endswith('.txt~')): k = name.split('/')[-1] if path.endswith('meta'): with open(path, 'rb') as fp: self['meta'] = pickle.load(fp) else: with open(path, 'rb') as fp: self[k] = pickle.load(fp, encoding='bytes') # #time.sleep(0.2) # t = int(time.time() - start) * '-' # print(t, end='') # print('\n載入數據的字典信息完畢!') def _read_batch(self): if self.namespace == 'cifar-10': self.trainX = np.concatenate([ self[f'data_batch_{str(i)}'][b'data'] for i in range(1, 6) ]).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1)) self.trainY = np.concatenate([ np.asanyarray(self[f'data_batch_{str(i)}'][b'labels']) for i in range(1, 6) ]) self.testX = self.test_batch[b'data'].reshape( -1, 3, 32, 32).transpose((0, 2, 3, 1)) self.testY = np.asanyarray(self.test_batch[b'labels']) elif self.namespace == 'cifar-100': self.trainX = self.train[b'data'].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1)) self.train_fine_labels = np.asanyarray( self.train[b'fine_labels']) # 子類標籤 self.train_coarse_labels = np.asanyarray( self.train[b'coarse_labels']) # 超類標籤 self.testX = self.test[b'data'].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1)) self.test_fine_labels = np.asanyarray( self.test[b'fine_labels']) # 子類標籤 self.test_coarse_labels = np.asanyarray( self.test[b'coarse_labels']) # 超類標籤
爲了方便管理和調用數據集,我定義了一個 DataBunch
類:
class DataBunch(Bunch): ''' 將數據集轉換爲 Bunch ''' def __init__(self, root, *args, **kwds): super().__init__(*args, **kwds) B = Bunch self.mnist = B(MNIST(root, 'mnist')) self.fashion_mnist = B(MNIST(root, 'fashion_mnist')) self.cifar10 = B(Cifar(root, 'cifar-10')) self.cifar100 = B(Cifar(root, 'cifar-100'))
一樣將上述代碼放入 zdata
模塊中。
下面咱們即可以直接利用 DataBunch
類來調用上述介紹的數據集:
import sys sys.path.append('D:/GitHub/basedataset/loader/') from zdata import DataBunch, show_imgs
root = 'E:/Data/Zip/' db = DataBunch(root)
咱們能夠查看,咱們封裝的數據集:
db.keys()
dict_keys(['mnist', 'fashion_mnist', 'cifar10', 'cifar100'])
因爲前面已經展現過 'mnist', 'fashion_mnist',下面咱們將展現 Cifar API 的使用。更多詳細內容參考個人博文 關於 『AI 專屬數據庫的定製』的改進4。
cifar-10 和 CIFAR-10 標記爲 \(8000\) 萬個 微小圖像數據集5的子集。它們是由 Alex Krizhevsky, Vinod Nair, 和 Geoffrey Hinton 收集的。
cifar-10 數據集由 \(10\) 類 \(32\times 32\) 彩色圖像組成, 每類有 \(6\,000\) 張圖像。被劃分爲 \(50\,000\) 張訓練圖像和 \(10\,000\) 張測試圖像。
cifar10 = db.cifar10 imgs = cifar10.trainX[:16] show_imgs(imgs)
爲了方便數據的使用,咱們能夠將 db
寫入到本地磁盤:
import pickle def write_bunch(path): ''' path:: 寫入數據集的文件路徑 ''' with open(path, 'wb') as fp: pickle.dump(db, fp)
root = 'E:/Data/Zip/' path = f'{root}X.json' # 寫入數據集的文件路徑 write_bunch(path)
這樣之後咱們就能夠直接複製 f'{root}X.dat
或 f'{root}X.json'
到你能夠放置的任何地方,而後你就能夠經過 load
函數來調用 MNIST
、Fashion MNIST
、Cifa 10
、Cifar 100
這些數據集。即:
def read_bunch(path): with open(path, 'rb') as fp: bunch = pickle.load(fp) # 即爲上面的 DataBunch 的實例 return bunch
db = read_bunch(path) # path 即你的數據集所在的路徑
考慮到 JSON 對於其餘編程語言的不友好,下面咱們將介紹如何將 Bunch 數據集存儲爲 HDF5 格式的數據。
PyTables
6 是 Python 與 HDF5 數據庫/文件標準的結合7。它專門爲優化 I/O 操做的性能、最大限度地利用可用硬件而設計,而且它還支持壓縮功能。
下面的代碼均是在 Jupyter NoteBook 下完成的:
import tables as tb import numpy as np def bunch2hdf5(root): ''' 這裏我僅僅封裝了 Cifar十、Cifar100、MNIST、Fashion MNIST 數據集, 使用者還能夠本身追加數據集。 ''' db = DataBunch(root) filters = tb.Filters(complevel=7, shuffle=False) # 這裏我採用了壓縮表,於是保存爲 `.h5c` 但也能夠保存爲 `.h5` with tb.open_file(f'{root}X.h5c', 'w', filters=filters, title='Xinet\'s dataset') as h5: for name in db.keys(): h5.create_group('/', name, title=f'{db[name].url}') if name != 'cifar100': h5.create_array(h5.root[name], 'trainX', db[name].trainX, title='訓練數據') h5.create_array(h5.root[name], 'trainY', db[name].trainY, title='訓練標籤') h5.create_array(h5.root[name], 'testX', db[name].testX, title='測試數據') h5.create_array(h5.root[name], 'testY', db[name].testY, title='測試標籤') else: h5.create_array(h5.root[name], 'trainX', db[name].trainX, title='訓練數據') h5.create_array(h5.root[name], 'testX', db[name].testX, title='測試數據') h5.create_array(h5.root[name], 'train_coarse_labels', db[name].train_coarse_labels, title='超類訓練標籤') h5.create_array(h5.root[name], 'test_coarse_labels', db[name].test_coarse_labels, title='超類測試標籤') h5.create_array(h5.root[name], 'train_fine_labels', db[name].train_fine_labels, title='子類訓練標籤') h5.create_array(h5.root[name], 'test_fine_labels', db[name].test_fine_labels, title='子類測試標籤') for k in ['cifar10', 'cifar100']: for name in db[k].meta.keys(): name = name.decode() if name.endswith('names'): label_names = np.asanyarray([label_name.decode() for label_name in db[k].meta[name.encode()]]) h5.create_array(h5.root[k], name, label_names, title='標籤名稱')
Bunch
到 HDF5
的轉換root = 'E:/Data/Zip/' bunch2hdf5(root)
h5c = tb.open_file('E:/Data/Zip/X.h5c') h5c
File(filename=E:/Data/Zip/X.h5c, title="Xinet's dataset", mode='r', root_uep='/', filters=Filters(complevel=7, complib='zlib', shuffle=False, bitshuffle=False, fletcher32=False, least_significant_digit=None)) / (RootGroup) "Xinet's dataset" /cifar10 (Group) 'https://www.cs.toronto.edu/~kriz/cifar.html' /cifar10/label_names (Array(10,)) '標籤名稱' atom := StringAtom(itemsize=10, shape=(), dflt=b'') maindim := 0 flavor := 'numpy' byteorder := 'irrelevant' chunkshape := None /cifar10/testX (Array(10000, 32, 32, 3)) '測試數據' atom := UInt8Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'irrelevant' chunkshape := None /cifar10/testY (Array(10000,)) '測試標籤' atom := Int32Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'little' chunkshape := None /cifar10/trainX (Array(50000, 32, 32, 3)) '訓練數據' atom := UInt8Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'irrelevant' chunkshape := None /cifar10/trainY (Array(50000,)) '訓練標籤' atom := Int32Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'little' chunkshape := None /cifar100 (Group) 'https://www.cs.toronto.edu/~kriz/cifar.html' /cifar100/coarse_label_names (Array(20,)) '標籤名稱' atom := StringAtom(itemsize=30, shape=(), dflt=b'') maindim := 0 flavor := 'numpy' byteorder := 'irrelevant' chunkshape := None /cifar100/fine_label_names (Array(100,)) '標籤名稱' atom := StringAtom(itemsize=13, shape=(), dflt=b'') maindim := 0 flavor := 'numpy' byteorder := 'irrelevant' chunkshape := None /cifar100/testX (Array(10000, 32, 32, 3)) '測試數據' atom := UInt8Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'irrelevant' chunkshape := None /cifar100/test_coarse_labels (Array(10000,)) '超類測試標籤' atom := Int32Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'little' chunkshape := None /cifar100/test_fine_labels (Array(10000,)) '子類測試標籤' atom := Int32Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'little' chunkshape := None /cifar100/trainX (Array(50000, 32, 32, 3)) '訓練數據' atom := UInt8Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'irrelevant' chunkshape := None /cifar100/train_coarse_labels (Array(50000,)) '超類訓練標籤' atom := Int32Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'little' chunkshape := None /cifar100/train_fine_labels (Array(50000,)) '子類訓練標籤' atom := Int32Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'little' chunkshape := None /fashion_mnist (Group) 'https://github.com/zalandoresearch/fashion-mnist' /fashion_mnist/testX (Array(10000, 28, 28, 1)) '測試數據' atom := UInt8Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'irrelevant' chunkshape := None /fashion_mnist/testY (Array(10000,)) '測試標籤' atom := Int32Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'little' chunkshape := None /fashion_mnist/trainX (Array(60000, 28, 28, 1)) '訓練數據' atom := UInt8Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'irrelevant' chunkshape := None /fashion_mnist/trainY (Array(60000,)) '訓練標籤' atom := Int32Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'little' chunkshape := None /mnist (Group) 'http://yann.lecun.com/exdb/mnist' /mnist/testX (Array(10000, 28, 28, 1)) '測試數據' atom := UInt8Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'irrelevant' chunkshape := None /mnist/testY (Array(10000,)) '測試標籤' atom := Int32Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'little' chunkshape := None /mnist/trainX (Array(60000, 28, 28, 1)) '訓練數據' atom := UInt8Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'irrelevant' chunkshape := None /mnist/trainY (Array(60000,)) '訓練標籤' atom := Int32Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'little' chunkshape := None
從上面的結構可看出我將 Cifar10
、Cifar100
、MNIST
、Fashion MNIST
進行了封裝,而且還附帶了它們各類的數據集信息。好比標籤名,數字特徵(以數組的形式進行封裝)等。
%%time arr = h5c.root.cifar100.trainX.read() # 讀取數據十分快速
Wall time: 125 ms
arr.shape
(50000, 32, 32, 3)
h5c.root
/ (RootGroup) "Xinet's dataset" children := ['cifar10' (Group), 'cifar100' (Group), 'fashion_mnist' (Group), 'mnist' (Group)]
X.h5c
使用說明下面咱們以 Cifar100
爲例來展現咱們自創的數據集 X.h5c
(我將其上傳到了百度雲盤「連接:https://pan.baidu.com/s/12jzaJ2d2kvHCXbQa_HO6YQ 提取碼:2clg」能夠下載直接使用;亦可你本身生成,不過我推薦本身生成,能夠對數據集加深理解)
cifar100 = h5c.root.cifar100 cifar100
/cifar100 (Group) 'https://www.cs.toronto.edu/~kriz/cifar.html' children := ['coarse_label_names' (Array), 'fine_label_names' (Array), 'testX' (Array), 'test_coarse_labels' (Array), 'test_fine_labels' (Array), 'trainX' (Array), 'train_coarse_labels' (Array), 'train_fine_labels' (Array)]
'coarse_label_names'
指的是粗粒度或超類標籤名,'fine_label_names'
則是細粒度標籤名。
可使用 read()
方法直接獲取信息,也可使用索引的方式獲取。
coarse_label_names = cifar100.coarse_label_names[:] # 或者 coarse_label_names = cifar100.coarse_label_names.read() coarse_label_names.astype('str')
array(['aquatic_mammals', 'fish', 'flowers', 'food_containers', 'fruit_and_vegetables', 'household_electrical_devices', 'household_furniture', 'insects', 'large_carnivores', 'large_man-made_outdoor_things', 'large_natural_outdoor_scenes', 'large_omnivores_and_herbivores', 'medium_mammals', 'non-insect_invertebrates', 'people', 'reptiles', 'small_mammals', 'trees', 'vehicles_1', 'vehicles_2'], dtype='<U30')
fine_label_names = cifar100.fine_label_names[:].astype('str') fine_label_names
array(['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'], dtype='<U13')
'testX'
與 'trainX'
分別表明數據的測試數據和訓練數據,而其餘的節點所表明的含義也是相似的。
例如,咱們能夠看看訓練集的數據和標籤:
trainX = cifar100.trainX train_coarse_labels = cifar100.train_coarse_labels
array([11, 15, 4, ..., 8, 7, 1])
shape
爲 (50000, 32, 32, 3)
,數據的獲取,咱們同樣能夠採用索引的形式或者使用 read()
:
train_data = trainX[:] print(train_data[0].shape) print(train_data.dtype)
(32, 32, 3) uint8
固然,咱們也能夠直接使用 trainX
作運算。
for x in cifar100.trainX: y = x * 2 break print(y.shape)
(32, 32, 3)
h5c.get_node(h5c.root.cifar100, 'trainX')
/cifar100/trainX (Array(50000, 32, 32, 3)) '訓練數據' atom := UInt8Atom(shape=(), dflt=0) maindim := 0 flavor := 'numpy' byteorder := 'irrelevant' chunkshape := None
更甚者,咱們能夠直接定義迭代器來獲取數據:
trainX = cifar100.trainX train_coarse_labels = cifar100.train_coarse_labels
def data_iter(X, Y, batch_size): n = X.nrows idx = np.arange(n) if X.name.startswith('train'): np.random.shuffle(idx) for i in range(0, n ,batch_size): k = idx[i: min(n, i + batch_size)].tolist() yield np.take(X, k, 0), np.take(Y, k, 0)
for x, y in data_iter(trainX, train_coarse_labels, 8): print(x.shape, y) break
(8, 32, 32, 3) [ 7 7 0 15 4 8 8 3]
更多使用詳情見:使用 迭代器 獲取 Cifar 等經常使用數據集8
爲了更加形象的說明該數據集,咱們將其可視化:
from pylab import plt, mpl mpl.rcParams['font.sans-serif'] = ['SimHei'] # 指定默認字體 mpl.rcParams['axes.unicode_minus'] = False # 解決保存圖像是負號 '-' 顯示爲方塊的問題 def show_imgs(imgs, labels): ''' 展現 多張圖片 ''' imgs = np.transpose(imgs, (0, 2, 3, 1)) n = imgs.shape[0] h, w = 5, int(n / 5) fig, ax = plt.subplots(h, w, figsize=(7, 7)) K = np.arange(n).reshape((h, w)) names = np.asanyarray([cifar.fine_label_names[label] for label in labels], dtype='U') names = names.reshape((h, w)) for i in range(h): for j in range(w): img = imgs[K[i, j]] ax[i][j].imshow(img) ax[i][j].axes.get_yaxis().set_visible(False) ax[i][j].axes.set_xlabel(names[i][j]) ax[i][j].set_xticks([]) plt.show()
爲了高效使用數據集 X.h5
,咱們使用迭代器的方式來獲取它:
class Loader: """ 方法 ======== L 爲該類的實例 len(L)::返回 batch 的批數 iter(L)::即爲數據迭代器 Return ======== 可迭代對象(numpy 對象) """ def __init__(self, X, Y, batch_size, shuffle): ''' X, Y 均爲類 numpy ''' self.X = X self.Y = Y self.batch_size = batch_size self.shuffle = shuffle def __iter__(self): n = len(self.X) idx = np.arange(n) if self.shuffle: np.random.shuffle(idx) for k in range(0, n, self.batch_size): K = idx[k:min(k + self.batch_size, n)].tolist() yield np.take(self.X, K, 0), np.take(self.Y, K, 0) def __len__(self): return round(len(self.X) / self.batch_size)
import tables as tb import numpy as np batch_size = 512 xpath = 'E:/xdata/X.h5' # 文件所在路徑 h5 = tb.open_file(xpath) cifar = h5.root.cifar100 train_cifar = Loader(cifar.trainX, cifar.train_fine_labels, batch_size, True) for imgs, labels in iter(train_cifar): break show_imgs(imgs[:25], labels[:25])
上面的大部分代碼被我放在了 Github:https://github.com/DataLoaderX/datasetsome/blob/master/dataloader/tabx.py。
上面的 API 設計過程當中,我發現到了許多自身的不足,不斷改進 API 的過程當中,我得到了學習和創造的喜悅。上面所介紹的 X.h5c
數據集不只僅是那些數據集的封裝,你還能夠繼續添加本身的數據集到該 數據庫中。同時,類 Loader
十分有用,它定義了一個標準,一個能夠延拓處處理其餘深度學習的數據集中去。
基於上述思想,我設計了以下 API:
LeCun, Y., Cortes, C., & Burges, C. http://yann.lecun.com/exdb/mnist/↩
Xiao, H., Rasul, K., & Vollgraf, R. (2017). Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747.↩
https://github.com/DataLoaderX/datazone/tree/master/lab/utils/tools↩
https://yq.aliyun.com/articles/614332?spm=a2c4e.11155435.0.0.30543312vFsboY↩