首先介紹數據讀取問題,如今TensorFlow官方推薦的數據讀取方法是使用tf.data.Dataset,具體的細節不在這裏贅述,看官方文檔更清楚,這裏主要記錄一下官方文檔沒有提到的坑,以示"後人"。由於是記錄踩過的坑,因此行文混亂,見諒。html
不感興趣的可跳過此節。python
最近在研究ENAS的代碼,這個網絡的做用是基於加強學習,可以自動生成合適的網絡結構。原做者使用TensorFlow在cifar10上成功自動生成了網絡結構,並取得了不錯的效果。git
但問題來了,此時我須要將代碼轉移到本身的數據集上,都知道cifar10圖像大小是32*32,並非特別大,因此原做者"喪心病狂"
地採用了一次性將數據讀進顯存的操做,絲絕不考慮我等渣渣的感覺。個人數據集原圖基本在500*800或以上,通過反覆試驗,若是採用源代碼我必須將圖像經過縮放和中心裁剪到160*160才能正常運行,並且運行結果並非很理想,十分類跑了一天左右最好的結果才30%左右。github
我在想若是把圖片放大後是否會提升準確度,因此第一個坑是修改數據讀取方式,適應大數據集讀取。網絡
再仔細閱讀源代碼後我還發現做者使用了tf.train.shuffle_batch
這個函數用來批量讀取,這個函數也讓我頭疼了好久,由於一直不知道它和tf.data.Dataset.batch.shuffle()
有什麼區別,因此第二個坑時tf.train.shuffle_batch
和tf.data.Dataset.batch.shuffle()
到底什麼關係(區別)ide
tf.train.batch
和tf.data.Dataset.batch.shuffle()
什麼區別其實這兩個談不上什麼區別,由於後者是前者的升級版,233333。函數
官方文檔對tf.train.batch
的描述是這樣的:學習
THIS FUNCTION IS DEPRECATED. It will be removed in a future version. Instructions for updating: Queue-based input pipelines have been replaced by tf.data. Use tf.data.Dataset.batch(batch_size) (or padded_batch(...) if dynamic_pad=True).大數據
在這裏我也推薦你們用tf.data,由於他相比於原來的tf.train.batch好用太多。ui
這裏的大數據集指的是稍微比較大的,像ImageNet這樣的數據集還沒嘗試過。因此下面的方法不敢確定是否使用於ImageNet。
要想讀取大數據集,我找到的官方給出的方案有兩種:
個人數據集是以已經分好類的文件夾進行存儲的,大體結構是這樣的
├───test │ ├───Acne_Vulgaris │ ├───Actinic_solar_Damage__Actinic_Keratosis │ ├───Basal_Cell_Carcinoma │ ├───Rosacea │ └───Seborrheic_Keratosis ├───train │ ├───Acne_Vulgaris │ ├───Actinic_solar_Damage__Actinic_Keratosis │ ├───Basal_Cell_Carcinoma │ ├───Rosacea │ └───Seborrheic_Keratosis └───valid ├───Acne_Vulgaris ├───Actinic_solar_Damage__Actinic_Keratosis ├───Basal_Cell_Carcinoma ├───Rosacea └───Seborrheic_Keratosis
個人方法很是適合懶人,具體流程以下:
pytorch提供了torchvision這個庫,這個庫堪稱瑰寶,torchvision.datasets裏有個函數是ImageFolder,你只須要指明路徑便可把圖片數據都讀進來,不用再苦逼地手寫for循環遍歷了。其餘的細節,好比data augmentation等等就不介紹了,具體代碼可參看官方文檔以及以下連接: https://github.com/marsggbo/enas/blob/master/src/skin5_placeholder/data_utils.py
假設上一步已經圖像數據讀取完畢,並保存成numpy文件,下面參看官方文檔例子
# 讀取numpy數據 with np.load("/var/data/training_data.npy") as data: features = data["features"] labels = data["labels"] # 查看圖像和標籤維度是否保持一致 assert features.shape[0] == labels.shape[0] # 建立placeholder features_placeholder = tf.placeholder(features.dtype, features.shape) labels_placeholder = tf.placeholder(labels.dtype, labels.shape) # 建立dataset dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder)) # 批量讀取,打散數據,repeat() dataset = dataset.shuffle(20).batch(5).repeat() # [Other transformations on `dataset`...] dataset_other = ... iterator = dataset.make_initializable_iterator() data_element = iterator.get_nex() sess = tf.Session() sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels}) for e in range(EPOCHS): for step in range(num_batches): x_batch, y_batch = sess.run(data_element) y_pred = model(x_batch) ... ... sess.close()
插播一條廣告:上面代碼中batch(), shuffle(), repeat()的具體用法參見Tensorflow datasets.shuffle repeat batch方法。
上面邏輯很清楚:
注意,每次一運行sess.run(data_element)這個語句,TensorFlow會自動的調取下一個批次的數據。不只如此,只要sess.run一個把data_element做爲輸入的節點,也都會自動調取下一個批次的數據。說的有點繞,看例子就明白了
能夠看到若是在讀取數據的時候還sess.run與數據有關的操做,那麼有的數據就根本沒遍歷到,因此這個問題要特別注意。
那我爲何會連這種坑都能踩到呢,由於原做者的代碼寫的太「好」了,對於我這種剛入門的人來講太難理解和修改了。
原做者的代碼結構並無寫for循環遍歷讀取數據,而後傳入到模型。相反他把數據操做寫到了另外一個類(文件)中,好比說在model.py
中他定義了
class Model(): def __init__(): ... def _model(self, img, label): y_pred = other_function(img) acc = calculate_acc(y_pred, label) ...
而後在main.py
中他只是sess.run(model.acc),即
with tf.Session() as sess: ... while epoch < EPOCHS: global_step = sess.run(model.global_step) if global_step % 50: acc = sess.run(model.acc) ... ...
抱怨一下: 它這代碼結構寫得和官方文檔不同,因此一直不知道怎麼修改。你若是從最開始看到這,你應該以爲很好改啊,可是你看着官方文檔真不知道怎麼修改,由於最開始我並不知道每次sess.run以後都會自動調用下一個batch的數據,並且也尚未習慣TensorFlow數據流的思惟。在這裏特別感謝這個問題幫助我解答了困惑:Tensorflow: create minibatch from numpy array > 2 GB。
因此這種狀況怎麼讀取數據呢?很簡單,只須要在循環語句以前初始化迭代器便可。
ops = { "global_step": model.global_step, "acc": model.acc } with tf.Session() as sess: ... sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels}) while epoch < EPOCHS: global_step = sess.run(ops['global_step']) if global_step % 50: acc = sess.run(ops['acc']) ... ...
若是你想要查看數據是否正確讀取,千萬不要在上面的while循環中加入這麼一行代碼x_batch, y_batch=sess.run([model.x_batch, model.y_batch])
,這樣就會致使上面所說的數據沒法完整遍歷的問題。那怎麼辦呢?
咱們能夠考慮修改ops
來獲取數據,代碼以下:
ops = { "global_step": model.global_step, "acc": model.acc, "x_batch": model.x_batch, "y_batch": model.y_batch } with tf.Session() as sess: ... sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels}) while epoch < EPOCHS: global_step = sess.run(ops['global_step']) if global_step % 50: acc = sess.run([ops["acc"], ops["x_batch"], ops["y_batch"]]) ...
這樣之因此能完整遍歷,是由於咱們將x_batch和acc放在一塊兒啦~,因此這能夠當作只是一個運算。