MXNet 的設計結構是以 C++ 作為後端運算,Python、R 等作為前端來使用,這樣既兼顧了效率,又讓使用者方便了不少。要完整地使用 MXNet 訓練自己的數據集,需要了解幾個方面,今天我們先談一談 Data iterators。
MXNet 中的 data iterator 和 Python 中的迭代器很類似:每次調用其內置方法 next() 時,它返回一個 data batch。所謂 data batch,就是神經網絡的輸入和 label,一般是 (n, c, h, w) 格式的圖片輸入,以及 (n, h, w) 或標量形式的 label。直接拿官網上的一個簡單例子來說說吧。
import numpy as np


class SimpleIter:
    """Minimal MXNet-style data iterator that yields generated batches.

    This snippet does not import MXNet itself: ``mx`` (``import mxnet as mx``)
    must already be in scope when ``next()`` builds a batch.

    Parameters
    ----------
    data_names, data_shapes : sequences
        Names and shapes of the data inputs, paired positionally.
    data_gen : sequence of callables
        One generator per data input. Each is called with that input's shape,
        and its result is wrapped with ``mx.nd.array``.
    label_names, label_shapes, label_gen : sequences
        Same as above, for the labels.
    num_batches : int
        Number of batches produced before ``StopIteration`` is raised.
    """

    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        # list(...) is essential on Python 3: a bare zip() is a one-shot
        # iterator, so provide_data/provide_label would be empty after the
        # first batch had been built from them.
        self._provide_data = list(zip(data_names, data_shapes))
        self._provide_label = list(zip(label_names, label_shapes))
        self.num_batches = num_batches
        self.data_gen = data_gen
        self.label_gen = label_gen
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        """Rewind to the first batch so a new epoch can start."""
        self.cur_batch = 0

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        """List of ``(name, shape)`` pairs describing the data inputs."""
        return self._provide_data

    @property
    def provide_label(self):
        """List of ``(name, shape)`` pairs describing the labels."""
        return self._provide_label

    def next(self):
        """Return the next batch, or raise ``StopIteration`` once
        ``num_batches`` batches have been produced since the last reset."""
        if self.cur_batch >= self.num_batches:
            raise StopIteration
        self.cur_batch += 1
        data = [mx.nd.array(gen(shape))
                for (_, shape), gen in zip(self._provide_data, self.data_gen)]
        assert len(data) > 0, "Empty batch data."
        label = [mx.nd.array(gen(shape))
                 for (_, shape), gen in zip(self._provide_label, self.label_gen)]
        assert len(label) > 0, "Empty batch label."
        # The original snippet returned an undefined ``SimpleBatch``;
        # MXNet's own batch container is mx.io.DataBatch.
        return mx.io.DataBatch(data=data, label=label)
上面的代碼是最簡單的一個 data iterator 了:沒有對數據做預處理,甚至沒有自己去讀取數據,但是基本的意思到了。一個 data iterator 必須要實現上面的幾個方法。provide_data 返回一個 (data_name, shape) 的 list,其中 shape 為 (batch_size, channel, height, width);provide_label 同樣返回一個 (label_name, shape) 的 list,shape 為 (batch_size, height, width)。reset() 在每個 epoch 結束後被調用,作用是把迭代器重置回起點。通常也順便在這裏打亂讀取圖片的順序,這樣隨機採樣的訓練效果會好一點,一般是通過 shuffle 你的 lst(上篇用來讀取圖片的 lst)實現的。next() 的作用就很顯然了,用來返回你的 data batch。當數據已經取完時,記得 raise StopIteration:這是通知訓練循環本輪 epoch 結束的方式,而不是用來處理錯誤的。需要注意的是,data batch 中返回的數據類型是 mx.nd.NDArray。
下面是我最近做 segmentation 時用的一個稍微複雜的 data iterator,多了預處理和 shuffle 等步驟:
# pylint: skip-file
import os
import random

import cv2
import mxnet as mx
import numpy as np
from mxnet.io import DataIter, DataBatch


class FileIter(DataIter):  # custom iterators usually subclass DataIter
    """Segmentation data iterator that reads image/label pairs from a list file.

    Each line of the list file has the form
    ``index<TAB>image_path<TAB>label_path``, with paths relative to
    ``root_dir``.

    Loading and augmentation:

    * Images are read as 3-channel colour and labels as single-channel
      grayscale. Every decoded file is cached in memory (``self.dict``).
    * Each sample is then augmented: optional horizontal flip, scale jitter,
      optional random crop to ``output_size``, and optional HSV colour jitter.
    * The image is converted from OpenCV's BGR to RGB, has ``rgb_mean``
      subtracted, and is returned as (3, h, w).
    * The label is downsampled by ``fac`` with nearest-neighbour
      interpolation.

    Batches have the fixed shape ``output_size``. NOTE(review): with
    ``random_crop`` disabled, images must already come out at
    ``output_size`` after scale jitter, or batch assembly will fail.

    Parameters
    ----------
    root_dir : str
        Directory that the list file and all image/label paths are relative to.
    flist_name : str
        Name of the list file inside ``root_dir``.
    rgb_mean : tuple of 3 numbers
        Per-channel mean in (R, G, B) order, subtracted from every image.
    data_name : str
        Name of the data input in the symbol.
    label_name : str
        Name of the label input in the symbol (default ``softmax_label``).
    p : config object
        Required despite the ``None`` default. Must provide ``batch_size``,
        ``output_size`` (h, w), ``fac``, ``random_crop``, ``random_flip``,
        ``random_color``, ``random_scale`` (lo, hi), ``color_aug_range``
        (hue, sat, val), ``use_rnn`` and ``num_hidden``.
    """

    def __init__(self, root_dir, flist_name, rgb_mean=(117, 117, 117),
                 data_name="data", label_name="softmax_label", p=None):
        super(FileIter, self).__init__()
        if p is None:
            raise ValueError("FileIter requires a config object `p`")

        self.fac = p.fac  # labels are downsampled by this factor
        self.root_dir = root_dir
        self.flist_name = os.path.join(self.root_dir, flist_name)
        self.mean = np.array(rgb_mean)  # (R, G, B)
        self.data_name = data_name
        self.label_name = label_name
        self.batch_size = p.batch_size
        self.random_crop = p.random_crop
        self.random_flip = p.random_flip
        self.random_color = p.random_color
        self.random_scale = p.random_scale
        self.output_size = p.output_size  # (height, width)
        self.color_aug_range = p.color_aug_range
        self.use_rnn = p.use_rnn
        self.num_hidden = p.num_hidden
        if self.use_rnn:
            self.init_h_name = 'init_h'
            # all-zero initial hidden state, fed as an extra data input
            self.init_h = mx.nd.zeros((self.batch_size, self.num_hidden))
        # Start one whole batch before index 0, exactly like reset(), so the
        # first iter_next() lands on sample 0 for any batch_size.
        self.cursor = -self.batch_size

        # Zero buffers with the batch shapes. They are not read by the methods
        # below. NOTE(review): kept for compatibility; confirm before removing.
        self.data = mx.nd.zeros(self._data_shape())
        self.label = mx.nd.zeros(self._label_shape())

        # Parse the list file so images can be read lazily later.
        self.data_list = []
        self.label_list = []
        with open(self.flist_name) as flist:
            for line in flist.read().splitlines():
                if not line.strip():
                    continue  # tolerate blank lines (e.g. at end of file)
                _, data_img_name, label_img_name = line.split("\t")
                self.data_list.append(data_img_name)
                self.label_list.append(label_img_name)
        self.num_data = len(self.data_list)
        self.order = list(range(self.num_data))  # per-epoch visiting order
        self.dict = {}  # absolute path -> decoded image (in-memory cache)
        self._shuffle()

    def _data_shape(self):
        """Shape of one data batch: (batch_size, 3, out_h, out_w)."""
        return (self.batch_size, 3, self.output_size[0], self.output_size[1])

    def _label_shape(self):
        """Shape of one label batch: (batch_size, out_h // fac, out_w // fac).

        Integer division keeps the shape integral on Python 3 as well.
        """
        return (self.batch_size,
                int(self.output_size[0] // self.fac),
                int(self.output_size[1] // self.fac))

    def _shuffle(self):
        random.shuffle(self.order)

    def _load_cached(self, rel_path, flags):
        """Read ``root_dir/rel_path`` with ``cv2.imread`` and memoise it.

        The dataset is small and disk I/O on the shared server was often
        saturated by other jobs, so every decoded file is kept in memory.

        Raises
        ------
        IOError
            If OpenCV cannot read the file. Caching ``None`` would only
            fail later with a confusing error.
        """
        path = os.path.join(self.root_dir, rel_path)
        img = self.dict.get(path)
        if img is None:
            img = cv2.imread(path, flags)
            if img is None:
                raise IOError("cannot read image: %s" % path)
            self.dict[path] = img
        return img

    def _read_img(self, img_name, label_name):
        """Load one image/label pair and apply the configured augmentation.

        Returns
        -------
        (img, label)
            ``img`` is a (3, h, w) float array in RGB order with
            ``rgb_mean`` subtracted. ``label`` is a uint8
            (h // fac, w // fac) array.
        """
        img = self._load_cached(img_name, cv2.IMREAD_COLOR)
        label = self._load_cached(label_name, cv2.IMREAD_GRAYSCALE)

        # Random horizontal flip, applied identically to image and label.
        if self.random_flip and random.randint(0, 1) == 1:
            img = cv2.flip(img, 1)
            label = cv2.flip(label, 1)

        # Scale jittering (always applied; random_scale=(1, 1) disables it).
        scale = random.uniform(self.random_scale[0], self.random_scale[1])
        new_size = (int(img.shape[1] * scale), int(img.shape[0] * scale))  # (w, h)
        img = cv2.resize(img, new_size, interpolation=cv2.INTER_NEAREST)
        label = cv2.resize(label, new_size, interpolation=cv2.INTER_NEAREST)

        if self.random_crop:
            out_h, out_w = self.output_size[0], self.output_size[1]
            if img.shape[0] < out_h or img.shape[1] < out_w:
                raise ValueError(
                    "image %s is %dx%d after scaling, smaller than crop %dx%d"
                    % (img_name, img.shape[0], img.shape[1], out_h, out_w))
            start_w = np.random.randint(0, img.shape[1] - out_w + 1)
            start_h = np.random.randint(0, img.shape[0] - out_h + 1)
            img = img[start_h:start_h + out_h, start_w:start_w + out_w, :]
            label = label[start_h:start_h + out_h, start_w:start_w + out_w]

        if self.random_color:
            # Jitter hue/saturation/value in HSV space, then convert back.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
            hue = random.uniform(-self.color_aug_range[0], self.color_aug_range[0])
            sat = random.uniform(-self.color_aug_range[1], self.color_aug_range[1])
            val = random.uniform(-self.color_aug_range[2], self.color_aug_range[2])
            img = np.array(img, dtype=np.float32)
            img[..., 0] += hue
            img[..., 1] += sat
            img[..., 2] += val
            img = np.clip(img, 0, 255)
            img = cv2.cvtColor(img.astype('uint8'), cv2.COLOR_HSV2BGR)

        img = np.array(img, dtype=np.float32)  # (h, w, c), BGR from OpenCV
        # Convert to RGB *before* subtracting the mean, because the mean is
        # given in (R, G, B) order.
        img = img[:, :, ::-1]
        img = img - self.mean.reshape(1, 1, 3)
        img = img.transpose(2, 0, 1)  # (c, h, w)

        # Downsample the label to the network's output resolution. Use
        # nearest-neighbour interpolation so class IDs are never blended into
        # invalid values. Use an explicit size so the result matches
        # provide_label instead of depending on cv2's rounding of fx/fy.
        label_size = (int(label.shape[1] // self.fac), int(label.shape[0] // self.fac))
        label_zoomed = cv2.resize(label, label_size, interpolation=cv2.INTER_NEAREST)
        return img, label_zoomed.astype('uint8')

    @property
    def provide_data(self):
        """The name and shape of data provided by this iterator"""
        data_desc = [(self.data_name, self._data_shape())]
        if self.use_rnn:
            data_desc.append((self.init_h_name, (self.batch_size, self.num_hidden)))
        return data_desc

    @property
    def provide_label(self):
        """The name and shape of label provided by this iterator"""
        return [(self.label_name, self._label_shape())]

    def get_batch_size(self):
        return self.batch_size

    def reset(self):
        """Rewind to the start of the data and reshuffle the visiting order."""
        self.cursor = -self.batch_size
        self._shuffle()

    def iter_next(self):
        """Advance the cursor by one batch; return False when data is exhausted."""
        self.cursor += self.batch_size
        return self.cursor < self.num_data

    def _getpad(self):
        """Number of samples in the current batch that were wrapped from the start."""
        if self.cursor + self.batch_size > self.num_data:
            return self.cursor + self.batch_size - self.num_data
        return 0

    def _getdata(self):
        """Assemble the batch starting at ``self.cursor``.

        Positions past the end of the epoch wrap around to the front of the
        shuffled order. The number of wrapped samples is reported as
        ``DataBatch.pad``.
        """
        assert self.cursor < self.num_data, "DataIter needs reset."
        data = np.zeros(self._data_shape())
        label = np.zeros(self._label_shape())
        for i in range(self.batch_size):
            # Modulo also covers batch_size > num_data (pad larger than the set).
            idx = self.order[(self.cursor + i) % self.num_data]
            data[i], label[i] = self._read_img(self.data_list[idx], self.label_list[idx])
        return mx.nd.array(data), mx.nd.array(label)

    def next(self):
        """return one dict which contains "data" and "label" """
        if not self.iter_next():
            raise StopIteration
        data, label = self._getdata()
        data = [data, self.init_h] if self.use_rnn else [data]
        return DataBatch(data=data, label=[label],
                         pad=self._getpad(), index=None,
                         provide_data=self.provide_data,
                         provide_label=self.provide_label)
到這裏,基本的訓練我們就可以開始了。但是當你有了很多新的想法時,又會遇到新的問題……比如:multi input/output 怎麼辦?
其實也很簡單,只需要修改幾個地方:
1. provide_data 和 provide_label:注意到之前我們 return 的都是一個 list,所以直接在裏面按照和之前同樣的 (name, shape) 格式添加新的項就好了。
2. next():如果你需要傳 data 和 depth 兩個輸入,只需要把 [data, depth] 傳給 DataBatch 的 data 就好了(順序要與 provide_data 中的順序一致);label 也同理。
值得一提的是,MXNet 的 multi loss 在寫 network 的 symbol 時需要注意一點:假設你有 softmax_loss 和 regression_loss,那麼只要在最後 return mx.symbol.Group([softmax_loss, regression_loss]) 就行了。
總之……That's all~~~~