訓練一個機器學習深度學習模型通常能夠簡單歸納爲如下三個步驟:python
咱們能夠把整個過程用下面的一個Pipeline圖例來表示。
後端
其中的reader
就主要負責把數據按必定的格式feed
到深度學習網絡的輸入層上。不一樣的深度學習框架對爲放進網絡中的數據格式要求不同。在MXNet中對於Module的訓練與推理接口要求的數據都是一個data iterator
。下面咱們會詳細來介紹MXNet中的Data Iterator。網絡
MXNet裏的Date Iterators與Python中的iterator object很是相似。在Python中,有一類被稱爲iterable的對象,它容許咱們使用其中的next
方法來順序的抽取元素,好比list。迭代法器提供了一種遍歷整個容器的簡便方法,而不用關心容器具體的內容。框架
在MXNet中,data iterators
每次返回一個DataBatch
。一個DataBatch
通常包含n
個訓練樣本以及它們對應的標籤。這裏的n
通常等於指定的batch size,當整個數據流迭代到尾巴,沒有更多的數據返回時,迭代器將返回一個StopIteration
的異常。DataBatch裏包含了一些關於樣本的信息:名稱,形狀,數據類型以及內在佈局,能夠經過provide_data
和provide_label
這兩個訪法返回的DataDesc
對象來獲取。dom
全部MXNet關於IO的處理都是由mx.io.DataIter
以及它的子類來完成的。機器學習
下面咱們經過使用幾個典型的DataIter來講明它的用法。分佈式
當數據是在內存中,以NDArray或者numpy中的ndarray的形式存在時,咱們可使用NDArrayIter
來讀取。ide
import mxnet as mx %matplotlib inline import os import sys import subprocess import numpy as np import matplotlib.pyplot as plt import tarfile import warnings warnings.filterwarnings("ignore", category=DeprecationWarning)
import numpy as np data = np.random.rand(100,3) label = np.random.randint(0, 10, (100,)) data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=30) for batch in data_iter: print([batch.data, batch.label, batch.pad])
MXNet提供了CSVIter
來方便使用者直接從一個CSV文件中讀取數據函數
#lets save `data` into a csv file first and try reading it back np.savetxt('data.csv', data, delimiter=',') data_iter = mx.io.CSVIter(data_csv='data.csv', data_shape=(3,), batch_size=30) for batch in data_iter: print([batch.data, batch.pad])
當MXNet提供的一些數據迭代器不知足咱們的需求時,咱們能夠本身寫一個數據迭代器。那麼一個數據迭代器的對象,必定要包括下面幾個方法:佈局
__next()__
(python3),該方法返回一個DataBatch對象,而且當沒有剩餘數據時,返回一個StopIteration
的異常reset()
方法用於重置數據讀取到開始的位置provide_data
屬性,它是一個DataDesc對象的list,存儲了數據的名稱,形狀,數據類型及內在佈局信息。provide_label
屬性,它是一個DataDesc對象的list,存儲了標籤的名稱,形狀,數據類型及內在佈局信息。當咱們建立一個新的iterator時,咱們能夠選擇從頭建立,也能夠選擇從一個已經存在的迭代器那擴展。好比果咱們要作圖像描述(image captioning)的應用。那輸入的數據是圖像,而對應的Label是一個句子。那咱們可使用ImageRecordIter
建立一個image_iter
,而後經過NDArrayIter
建立一個caption_iter
。咱們的nxet()
方法將返回image_iter.next()
與caption_iter.next()
的一個合併。
下面是咱們自定義的一個迭代器。
class SimpleIter(mx.io.DataIter): def __init__(self, data_names, data_shapes, data_gen, label_names, label_shapes, label_gen, num_batches=10): self._provide_data = list(zip(data_names, data_shapes)) self._provide_label = list(zip(label_names, label_shapes)) self.num_batches = num_batches self.data_gen = data_gen self.label_gen = label_gen self.cur_batch = 0 def __iter__(self): return self def reset(self): self.cur_batch = 0 def __next__(self): return self.next() @property def provide_data(self): return self._provide_data @property def provide_label(self): return self._provide_label def next(self): if self.cur_batch < self.num_batches: self.cur_batch += 1 data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)] label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)] return mx.io.DataBatch(data, label) else: raise StopIteration
Record IO是MXNet爲了數據IO設計的一種文件格式。它將數據打包成一種十分便於在分佈式存儲系統,如HDFS和AWS S3上進行高效讀取的數據塊。MXNet提供了MXRecordIO
用於順序數據存儲的狀況,提供了MXIndexedRecordIO
用於隨機數據存取的狀況。
咱們先經過一個例子說明MXRecordIO用於順序數據讀寫的用法。
def str_or_bytes(str): """ A utility function for this tutorial that helps us convert string to bytes if we are using python3. Parameters ---------- str : string Returns ------- string (python2) or bytes (python3) """ if sys.version_info[0] < 3: return str else: return bytes(str, 'utf-8')
咱們將幾個連續的字符串寫到一個以.rec
結尾的文件中
record = mx.recordio.MXRecordIO('tmp.rec', 'w') for i in range(5): record.write(str_or_bytes('record_%d'%i)) record.close()
咱們再從一個.rec
文件中來順序的讀取
record = mx.recordio.MXRecordIO('tmp.rec', 'r') while True: item = record.read() if not item: break print (item) record.close()
不一樣與MXRecordIO對象,咱們只能不斷的調用read()
方法來順序的獲取裏面的數據。MXIndexedRecordIO
能夠隨機的訪問。
record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w') for i in range(5): record.write_idx(i, str_or_bytes('record_%d'%i)) record.close()
record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'r') record.read_idx(3)
# 也能夠單獨的把index輸出出來 record.keys
咱們放到RecordIO
裏面包含的是一個個pack,它能夠是任何二進制數據。可是對於大部分深度學習的任務來講,咱們每每須要的是數據/標籤
這樣的格式。mx.recordio
提供了一些接口函數來進行這些操做。
# pack data = 'data' label1 = 1.0 header1 = mx.recordio.IRHeader(flag=0, label=label1, id=1, id2=0) s1 = mx.recordio.pack(header1, str_or_bytes(data)) label2 = [1.0, 2.0, 3.0] header2 = mx.recordio.IRHeader(flag=3, label=label2, id=2, id2=0) s2 = mx.recordio.pack(header2, str_or_bytes(data))
# unpack print(mx.recordio.unpack(s1)) print(mx.recordio.unpack(s2))
data = np.ones((3,3,1), dtype=np.uint8) label = 1.0 header = mx.recordio.IRHeader(flag=0, label=label, id=0, id2=0) s = mx.recordio.pack_img(header, data, quality=100, img_fmt='.jpg')
# unpack_img print(mx.recordio.unpack_img(s))
當咱們作計算機視頻方面的應用時,要處理的大部分數據都是圖像與視頻(也會拆成視頻幀處理)。因此咱們這個小節重點介紹在MXNet中是如何處理輸入數據爲圖像的場景的。
有4種方法可讓咱們選擇來把數據加載到MXNet中
mx.image.imdecode
來加載原始的圖像數據mx.img.ImageIter
它是用Python來實現的,比較靈活,方便咱們修改 ,它能夠讀取.rec的文件或者原始文件。mx.io.ImageRecordIter
它在MXNet中是放在後端用C++實現的,因此不太便於修改。mx.io.DataIter
寫一個本身的迭代器fname = mx.test_utils.download(url='http://data.mxnet.io/data/test_images.tar.gz', dirname='data', overwrite=False) tar = tarfile.open(fname) tar.extractall(path='./data') tar.close()
img = mx.image.imdecode(open('data/test_images/ILSVRC2012_val_00000001.JPEG', 'rb').read()) plt.imshow(img.asnumpy()); plt.show()
# resize to w x h tmp = mx.image.imresize(img, 100, 70) plt.imshow(tmp.asnumpy()); plt.show()
# crop a random w x h region from image tmp, coord = mx.image.random_crop(img, (150, 200)) print(coord) plt.imshow(tmp.asnumpy()); plt.show()
咱們先下載一個數據集,Caltech 101,它包含了101類物體。咱們先將它轉換成RecordIO格式文件。
fname = mx.test_utils.download(url='http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz', dirname='data', overwrite=False) tar = tarfile.open(fname) tar.extractall(path='./data') tar.close()
咱們先看一下這個數據集,在根目錄下(./data/101_ObjectCategories),每個類別都是一個子文件平。咱們可使用腳本im2rec.py
來將整個目錄轉化爲成ReecordIO文件。第一步,咱們把全部的圖片路徑以及它們的label列到一個文本中。
os.system('python %s/tools/im2rec.py --list=1 --recursive=1 --shuffle=1 --test-ratio=0.2 data/caltech data/101_ObjectCategories'%os.environ['MXNET_HOME'])
上面的命令會生成一個caltech_train.lst的文件,文件的內容是index\t(one or more label)\tpath
的格式。在這個例子中,只有一個label。而後咱們就能夠用這個文件列表信息雲生成咱們的RecordIO文件了。
os.system("python %s/tools/im2rec.py --num-thread=4 --pass-through=1 data/caltech data/101_ObjectCategories"%os.environ['MXNET_HOME'])
ImageRecordIter
能夠經過RecordIO格式來加載圖片數據。
data_iter = mx.io.ImageRecordIter( path_imgrec="./data/caltech.rec", # the target record file data_shape=(3, 227, 227), # output data shape. An 227x227 region will be cropped from the original image. batch_size=4, # number of samples per batch resize=256 # resize the shorter edge to 256 before cropping # ... you can add more augumentation options as defined in ImageRecordIter. ) data_iter.reset() batch = data_iter.next() data = batch.data[0] for i in range(4): plt.subplot(1,4,i+1) plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1,2,0))) plt.show()
除了ImageRecordIter
外,咱們可使用ImageIter
來讀取一個RecordIO文件或者直接讀取原始格式的文件。
data_iter = mx.image.ImageIter(batch_size=4, data_shape=(3, 227, 227), path_imgrec="./data/caltech.rec", path_imgidx="./data/caltech.idx" ) data_iter.reset() batch = data_iter.next() data = batch.data[0] for i in range(4): plt.subplot(1,4,i+1) plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1,2,0))) plt.show()