圖像的亮度、對比度等屬性對圖像的影響是很是大的,然而在不少圖像識別問題中,這些因素都不該該影響最後的識別結果,因此在訓練模型以前,須要對圖像數據進行預處理,使訓練獲得的模型儘量小地被無關因素影響。python
7.1 TFRecord輸入數據格式git
7.1.1. TFRecord 格式介紹正則表達式
7.1.2 TFRecord 樣例程序算法
把mnist數據保存爲tfrecord格式:windows
1 #!coding:utf8 2 3 import tensorflow as tf 4 from tensorflow.examples.tutorials.mnist import input_data 5 import numpy as np 6 7 def _int64_feature(value): 8 return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 9 10 def _bytes_feature(value): 11 return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 12 13 mnist = input_data.read_data_sets('D:\\files\\tf\mnist', one_hot=True) # 肯定label是由0、1組成的數組,仍是單個整數。 14 images = mnist.train.images # (55000, 784) 15 labels = mnist.train.labels 16 17 pixels = images.shape[1] # 784 18 num_examples = mnist.train.num_examples 19 20 filename = 'D:\\files\\tf\yangxl.tfrecords' 21 writer = tf.python_io.TFRecordWriter(filename) 22 for index in range(num_examples): 23 image_raw = images[index].tostring() 24 example = tf.train.Example(features=tf.train.Features(feature={ 25 'pixels': _int64_feature(pixels), 26 'label': _int64_feature(np.argmax(labels[index])), 27 'image_raw': _bytes_feature(image_raw) 28 })) 29 writer.write(example.SerializeToString()) 30 writer.close()
讀取tfrecord文件:數組
1 import tensorflow as tf 2 import matplotlib.pyplot as plt 3 import numpy as np 4 5 reader = tf.TFRecordReader() 6 7 # 建立一個隊列來維護輸入文件列表 8 file_queue = tf.train.string_input_producer(['/home/yangxl/files/mnist.tfrecords']) 9 10 # 從文件中讀取一個樣例; 一次性讀取多個樣例使用read_up_to函數 11 _, serialized_example = reader.read(file_queue) # tensor 12 # 解析樣例; 一次性解析多個樣例使用parse_example函數 13 features = tf.parse_single_example( 14 serialized_example, 15 features={ 16 # """ 17 # tf提供了兩種屬性解析方法,一種是定長tf.FixedLenFeature,解析結果爲一個Tensor; 18 # 另外一種是變長tf.VarLenFeature,解析結果爲SparseTensor,用於處理稀疏數據。 19 # 這裏解析數據的格式須要和寫入數據的格式一致。 20 # """ 21 # 使用多行註釋會報錯:`KeyError: 'pixels'`, 我擦淚... 22 23 # 解析時的鍵須要與保存時的鍵一致 24 'pixels': tf.FixedLenFeature([], tf.int64), 25 'label': tf.FixedLenFeature([], tf.int64), 26 'image_raw': tf.FixedLenFeature([], tf.string) 27 } 28 ) 29 30 # decode_raw能夠把字符串解析爲圖像對應的像素數組 31 # cast轉換數據類型 32 image = tf.decode_raw(features['image_raw'], tf.uint8) 33 label = tf.cast(features['label'], tf.int32) 34 pixels = tf.cast(features['pixels'], tf.int32) 35 36 with tf.Session() as sess: 37 coord = tf.train.Coordinator() 38 threads = tf.train.start_queue_runners(sess=sess, coord=coord) 39 40 for i in range(15): # 每次執行sess.run()都會從隊列中取出一個樣例,這樣就會致使以後處理時可能不是同一個樣例,沒注意這個問題,這兩天就卡在這上面了 41 image_value, label_value, pixels_value = sess.run([image, label, pixels]) 42 print(label_value) 43 44 # 可視化, 可視化以前須要把一維數組轉爲二維數組 45 image_value = np.reshape(image_value, [28, 28]) 46 plt.imshow(image_value) 47 plt.show()
7.2 圖像數據處理網絡
一張RGB圖像能夠當作一個三維矩陣,矩陣中的每一個數字表示圖像上不一樣位置、不一樣顏色的亮度。圖像在存儲時,並非直接記錄這些矩陣中的數字,而是記錄通過壓縮編碼以後的結果。因此要將一張圖像還原成一個三維矩陣,須要解碼的過程。TF提供了對jpeg、png格式圖像的編碼/解碼函數。數據結構
圖像編碼、解碼多線程
1 import tensorflow as tf 2 import matplotlib.pyplot as plt 3 4 with tf.gfile.GFile('/home/error/cat.jpg', 'rb') as f: 5 image_raw_data = f.read() 6 # 對jpeg格式的圖像進行解碼, 獲得圖像對應的三維矩陣, 獲得一個tensor 7 image_data = tf.image.decode_jpeg(image_raw_data) # 獲得一個tensor, (1797, 2673, 3) dtype=uint8 8 9 # 編碼 10 encoded_image = tf.image.encode_jpeg(image_data) # 獲得一個tensor 11 12 13 with tf.Session() as sess: 14 plt.imshow(sess.run(image_data)) 15 plt.show() 16 17 with tf.gfile.GFile('/home/error/cat_bk.jpg', 'wb') as f: 18 f.write(sess.run(encoded_image))
圖像大小調整app
圖像大小是不固定的,但神經網絡輸入節點的個數是固定的,因此在將圖像的像素做爲輸入提供給神經網絡以前,須要先將圖像的大小統一。
1 import tensorflow as tf 2 import matplotlib.pyplot as plt 3 4 with tf.gfile.FastGFile('/home/error/cat.jpg', 'rb') as f: 5 image_raw_data = f.read() 6 # 對jpeg格式的圖像進行解碼, 獲得圖像對應的三維矩陣, 獲得一個tensor 7 image_data = tf.image.decode_jpeg(image_raw_data) # (1797, 2673, 3) 8 9 # 在圖像處理以前將圖像由uint8轉爲實數類型 10 img_data = tf.image.convert_image_dtype(image_data, dtype=tf.float32) 11 12 resized_image = tf.image.resize_images(img_data, [300, 300], method=0) # method取值爲0~3 13 14 with tf.Session() as sess: 15 print(resized_image) 16 17 plt.imshow(sess.run(resized_image)) 18 plt.show()
裁剪和填充,居中
1 # 裁剪, 若是原始圖像的尺寸大於目標圖像, 會自動截取原始圖像居中的部分 2 croped = tf.image.resize_image_with_crop_or_pad(img_data, 1000, 1000) 3 # 填充, 若是目標圖像的尺寸大於原始圖像, 會自動在原始圖像的四周填充全0背景 4 paded = tf.image.resize_image_with_crop_or_pad(img_data, 3000, 3000)
按比例裁剪,居中
1 # 按比例截取原始圖像居中的部分, 比例爲(0, 1]之間的實數 2 central_cropped = tf.image.central_crop(img_data, 0.5)
在指定區域進行裁剪和填充
1 # 裁剪給定區域的圖像, 該函數對給出的尺寸有必定的要求, 不然報錯 2 croped_bound = tf.image.crop_to_bounding_box(img_data, 500, 500, 800, 800) 3 4 # 填充, 圖像從(500, 500)開始, 左側和上側全0背景, 顯示圖像後, 繼續是全0背景。該函數對給出的尺寸有必定的要求, 不然報錯 5 paded_bound = tf.image.pad_to_bounding_box(img_data, 500, 500, 2500, 3500) # offset_height + img_height < target_height
圖像翻轉
圖像的翻轉不該該影響識別的效果,所以在訓練圖像識別神經網絡時,能夠隨機地翻轉訓練圖像,這樣訓練獲得的模型能夠識別不一樣角度的實體。
1 # 上下翻轉 2 flipped = tf.image.flip_up_down(img_data) 3 # 左右翻轉 4 flipped = tf.image.flip_left_right(img_data) 5 # 沿對角線翻轉, 主對角線 6 flipped = tf.image.transpose_image(img_data) 7 8 # 以50%的機率上下翻轉圖像 9 flipped = tf.image.random_flip_up_down(img_data) 10 # 以50%的機率左右翻轉圖像 11 flipped = tf.image.random_flip_left_right(img_data)
圖像色彩的調整
調整圖像的亮度、對比度、飽和度和色相都不會影響識別結果,所以能夠隨機地調整這些屬性
調整亮度
1 # 調整亮度, 負號是調暗, -1爲黑屏, 正數是調亮, 1爲白屏 2 adjusted = tf.image.adjust_brightness(img_data, -0.5) 3 4 # 在[-max_delta, max_delta]範圍內隨機調整圖像亮度 5 adjusted = tf.image.random_brightness(img_data, 1) 6 7 8 # 截斷 9 # 色彩調整的API可能致使像素的實數值超出0.0~1.0的範圍,所以在最終輸出圖像前須要將其截斷在0.0~1.0範圍內, 10 # 不然不只圖像不能正常可視化,以此爲輸入的神經網絡的訓練質量也可能會受到影響。 11 # 若是對圖像有多項處理,那麼截斷應該在全部處理完成以後進行 12 adjusted = tf.clip_by_value(adjusted, 0.0, 1.0)
調整對比度
1 # 調整對比度,將對比度減小到0.5倍 2 adjusted = tf.image.adjust_contrast(img_data, 0.5) 3 # 調整對比度,將對比度增長5倍 4 adjusted = tf.image.adjust_contrast(img_data, 5) 5 # 將對比度在[0.5, 5]範圍內隨機調整 6 adjusted = tf.image.random_contrast(img_data, 0.5, 5)
調整色相
1 # 分別取值[0.1, 0.3, 0.6, 0.9], 色彩從綠變爲藍,又變爲紅 2 adjusted = tf.image.adjust_hue(img_data, 0.9) 3 # 取值在[0.0, 0.5]以前 4 adjusted = tf.image.random_hue(0, 0.8)
調整飽和度
1 # 調整飽和度 2 adjusted = tf.image.adjust_saturation(img_data, -5) # 飽和度-5(+5就是加5) 3 # 在[-5, 5]範圍內隨機調整飽和度 4 tf.image.random_saturation(img_data, -5, 5)
注意:對於色相、飽和度,須要輸入數據的channels爲3,例如mnist數據就不行,亮度、對比度沒有限制。
將圖像標準化
即將圖像的亮度均值變爲0,方差變爲1
1 adjusted = tf.image.per_image_standardization(img_data)
處理標註框
1 import tensorflow as tf 2 import matplotlib.pyplot as plt 3 4 with tf.gfile.GFile('/home/error/cat.jpg', 'rb') as f: 5 image_raw_data = f.read() 6 # 對jpeg格式的圖像進行解碼, 獲得圖像對應的三維矩陣, 獲得一個tensor 7 decoded_image_data = tf.image.decode_jpeg(image_raw_data) # (1797, 2673, 3) 8 9 # 在圖像處理以前將圖像由uint8轉爲實數類型 10 converted_image_data = tf.image.convert_image_dtype(decoded_image_data, dtype=tf.float32) 11 12 # 把圖像縮小一些,讓標註框更清楚 13 resized_image_data = tf.image.resize_images(converted_image_data, [180, 267]) 14 15 # 輸入是一個batch的數據,也就是多張圖片組成的四維矩陣,因此須要加1個維度 16 expanded_image_data = tf.expand_dims(resized_image_data, 0) # (1, 180, 267, ?) 17 # 標註框,數值爲比例,秩爲3,設置了兩個標註框,爲啥秩爲3呢?? 18 boxes = tf.constant([[[0.05, 0.05, 0.9, 0.7], [0.35, 0.47, 0.5, 0.56]]]) 19 drawn_image_data = tf.image.draw_bounding_boxes(expanded_image_data, boxes) 20 21 adjusted = tf.clip_by_value(drawn_image_data[0], 0.0, 1.0) 22 23 with tf.Session() as sess: 26 27 plt.imshow(sess.run(adjusted)) 28 plt.show()
隨機截取圖像
隨機截取圖像上有信息含量的部分也是一種提升模型健壯性的方式。這樣可使訓練獲得的模型不受識別物體大小的影響。
1 import tensorflow as tf 2 import matplotlib.pyplot as plt 3 4 with tf.gfile.FastGFile('/home/yangxl/files/cat.jpg', 'rb') as f: 5 image_raw_data = f.read() 6 # 對jpeg格式的圖像進行解碼, 獲得圖像對應的三維矩陣, 獲得一個tensor 7 decoded_image_data = tf.image.decode_jpeg(image_raw_data) # (1797, 2673, 3) 8 # 在圖像處理以前將圖像由uint8轉爲實數類型 9 converted_image_data = tf.image.convert_image_dtype(decoded_image_data, dtype=tf.float32) 10 # 把圖像縮小一些,讓標註框更清楚 11 resized_image_data = tf.image.resize_images(converted_image_data, [180, 267], method=1) 12 13 # 輸入是一個batch的數據,也就是多張圖片組成的四維矩陣,因此須要加1個維度 14 expanded_image_data = tf.expand_dims(resized_image_data, 0) # (1, 180, 267, ?) 15 # 標註框,數值爲比例,秩爲3,設置了兩個標註框 16 boxes = tf.constant([[[0.05, 0.05, 0.9, 0.7], [0.35, 0.47, 0.5, 0.56]]]) 17 18 # 擴維以前的shape 19 begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( 20 tf.shape(resized_image_data), bounding_boxes=boxes, 21 min_object_covered=0.4 22 ) 23 image_with_box = tf.image.draw_bounding_boxes(expanded_image_data, bbox_for_draw) 24 25 adjusted = tf.clip_by_value(image_with_box[0], 0.0, 1.0) 26 distorted_image = tf.slice(adjusted, begin, size) 27 28 with tf.Session() as sess: 29 plt.imshow(sess.run(distorted_image)) # 若是不使用slice, 像這樣plt.imshow(sess.run(adjusted)), 可視化結果爲不截取只隨機標註 30 plt.show() 31 # 其實就多了兩行: 32 # sample_distorted_bounding_box和slice
7.2.2 圖像預處理完整樣例
由於調整亮度、對比度、飽和度和色相的順序會影響最後獲得的結果,因此能夠定義多種不一樣的順序。具體使用哪種能夠隨機選定。這樣能夠進一步下降無關因素對模型的影響。
1 import tensorflow as tf 2 import numpy as np 3 import matplotlib.pyplot as plt 4 5 def distort_color(image, color_ordering=0): 6 if color_ordering == 0: 7 image = tf.image.random_brightness(image, max_delta=32./255.) 8 image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 9 image = tf.image.random_hue(image, max_delta=0.2) 10 image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 11 elif color_ordering == 1: 12 image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 13 image = tf.image.random_brightness(image, max_delta=32. / 255.) 14 image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 15 image = tf.image.random_hue(image, max_delta=0.2) 16 return tf.clip_by_value(image, 0.0, 1.0) 17 18 # 給定一張解碼後的圖像、目標圖像的尺寸以及圖像上的標註框,此函數能夠對給出的圖像進行預處理。 19 # 這個函數的輸入圖像是圖像識別問題中的原始訓練數據,而輸出是神經網絡模型的輸入層。 20 # 注意,只處理模型的訓練數據,對於預測數據,通常不須要使用隨機變換的步驟。 21 def preprocessed_for_train(image, height, width, bbox): 22 # 若是沒有提供標註框,則認爲整個圖像就是須要關注的部分。 23 if bbox is None: 24 bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) # 秩爲3 25 26 # 轉換圖像張量的類型 27 if image.dtype != tf.float32: 28 image = tf.image.convert_image_dtype(image, tf.float32) 29 30 # 隨機截取圖像,減小須要關注的物體大小對圖像識別算法的影響。 31 bbox_begin, bbox_size, bbox_for_draw = tf.image.sample_distorted_bounding_box(tf.shape(image), bounding_boxes=bbox) 32 distorted_image = tf.slice(image, bbox_begin, bbox_size) 33 34 # 調整大小 35 distorted_image = tf.image.resize_images(distorted_image, [height, width], method=np.random.randint(4)) 36 37 # 隨機翻轉 38 distorted_image = tf.image.random_flip_left_right(distorted_image) 39 40 # 調整色彩 41 distorted_image = distort_color(distorted_image, np.random.randint(2)) 42 43 return distorted_image 44 45 46 image_raw_data = tf.gfile.GFile('/home/yangxl/files/cat.jpg', 'rb').read() 47 with tf.Session() as sess: 48 img_data = tf.image.decode_jpeg(image_raw_data) 49 boxes = tf.constant([[[0.05, 0.05, 0.9, 0.7], [0.35, 0.47, 0.5, 0.56]]]) 50 51 # 運行6次得到6中不一樣的圖像 52 for i in range(6): 53 result = preprocessed_for_train(img_data, 299, 299, boxes) 54 plt.imshow(result.eval()) 55 plt.show()
7.3 多線程輸入數據處理框架
雖然圖像預處理方法能夠減少無關因素對圖像識別模型效果的影響,可是這些複雜的預處理過程也會減慢整個訓練過程。爲了不圖像預處理成爲神經網絡模型訓練效率的瓶頸,tensorflow提供了一套多線程處理輸入數據的框架。
tensorflow中,隊列不只是一種數據結構,更提供了多線程機制。隊列也是tensorflow多線程輸入數據處理框架的基礎。
7.3.1介紹隊列和多線程
7.3.2介紹前三步
7.3.3介紹最後一步
7.3.4完整示例
7.4介紹數據集
7.3.1 隊列與多線程
隊列和變量相似,都是計算圖上有狀態的節點。其餘狀態節點能夠修改它們的狀態。
隊列操做:
1 import tensorflow as tf 2 3 # 先進先出隊列 4 q = tf.FIFOQueue(2, 'int32') 5 # 隊列初始化 6 init = q.enqueue_many([[0, 10], ]) # 這個至少要有兩層括號,不然報錯:Shape () must have rank at least 1 7 8 x = q.dequeue() 9 y = x + 1 10 q_inc = q.enqueue([y]) # 能夠沒有括號 11 12 with tf.Session() as sess: 13 init.run() # 隊列初始化須要明確調用 14 for i in range(5): 15 # 10, 1 1, 11 11, 2 2, 12 12, 3 16 sess.run(q_inc) 17 18 print(sess.run(x)) # 12 19 print(sess.run(x)) # 3
tf提供了FIFOQueue和RandomShuffleQueue兩種隊列。FIFOQueue是先進先出隊列,RandomShuffleQueue會將隊列中的元素打亂,每次出隊列操做獲得的是從當前隊列全部元素中隨機選擇的一個。在訓練神經網絡時但願每次使用的訓練數據儘可能隨機,RandomShuffleQueue就提供了這樣的功能。
tf提供了tf.train.Coordinator和tf.QueueRunner兩個類來完成多線程協同的功能。
tf.train.Coordinator主要用於協同多個線程一塊兒中止,並提供了should_stop、request_stop和join三個函數。在啓動線程以前,須要先聲明一個tf.train.Coordinator類,並將這個類傳入每一個建立的線程中。啓動的線程須要一直查詢tf.Coordinator類中提供的should_stop函數,當這個函數的返回值爲True時,則當前線程退出。每一個線程均可以經過調用request_stop函數來通知其餘線程退出,即當某一個線程調用request_stop函數以後,should_stop函數的返回值被設置爲True,這樣其餘線程就能夠同時退出了。
1 import tensorflow as tf 2 import numpy as np 3 import threading 4 import time 5 6 def MyLoop(coord, worker_id): 7 while not coord.should_stop(): 8 if np.random.rand() < 0.05: 9 print('Stoping from id: %d\n' % worker_id) 10 coord.request_stop() 11 else: 12 print('working on id: %d\n' % worker_id) 13 time.sleep(1) 14 15 16 coord = tf.train.Coordinator() 17 threads = [threading.Thread(target=MyLoop, args=(coord, i)) for i in range(5)] 18 19 for t in threads: 20 t.start() 21 22 # 等待全部線程退出 23 coord.join(threads)
tf.train.QueueRunner主要用於啓動多個線程來操做同一個隊列。啓動的線程能夠經過tf.Coordinator類來統一管理。
1 queue = tf.FIFOQueue(100, 'float') 2 enqueue_op = queue.enqueue([tf.random_normal([1])]) 3 4 # 啓動5個線程來操做隊列,每一個線程中運行的是enqueue_op 5 qr = tf.train.QueueRunner(queue, [enqueue_op] * 5) 6 # 將qr加入到計算圖指定的集合中,若是沒有指定集合則默認加到tf.GraphKeys.QUEUE_RUNNERS 7 tf.train.add_queue_runner(qr) 8 9 out_tensor = queue.dequeue() 10 11 with tf.Session() as sess: 12 coord = tf.train.Coordinator() 13 # 使用tf.train.QueueRunner時,須要明確調用tf.train.start_queue_runners來啓動全部線程。 14 # tf.train.start_queue_runners會默認啓動tf.GraphKeys.QUEUE_RUNNERS集合中的全部QueueRunner。 15 # 由於這個函數只支持啓動指定集合中的QueueRunner,因此tf.train.add_queue_runner和tf.train.start_queue_runners會指定同一個集合。 16 threads = tf.train.start_queue_runners(sess=sess, coord=coord) 17 18 for i in range(3): 19 print(sess.run(out_tensor)) 20 21 coord.request_stop() 22 coord.join(threads)
7.3.2 輸入文件隊列
使用TF中的隊列管理輸入文件列表。
雖然一個TFRecord文件中能夠保存多個訓練樣例,可是當訓練數據量較大時,能夠將數據分紅多個TFRecord文件來提升處理效率。tensorflow提供了tf.train.match_filenames_once函數來獲取符合一個正則表達式的全部文件,獲得的文件列表能夠經過tf.train.string_input_producer函數進行有效的管理。注意,在使用tf.train.match_filenames_once時須要初始化一些變量,tf.local_variables_initizliaer().run()。
tf.train.string_input_producer函數會使用初始化時提供的文件列表建立一個輸入隊列,輸入隊列中原始的元素爲文件列表中的全部文件,建立好的輸入隊列能夠做爲文件讀取函數的參數。
每次調用文件讀取函數時,該函數會先判斷當前是否已有打開的文件可讀,若是沒有或者打開的文件已經讀完,這個函數就會從輸入隊列中出隊一個文件並從這個文件中讀取數據。
1 reader = tf.TFRecordReader() 2 # 建立輸入隊列 3 filename_queue = tf.train.string_input_producer(['/home/error/output.tfrecords']) 4 # 讀取樣例 5 _, serializd_example = reader.read(filename_queue)
經過設置shuffle參數,tf.train.string_input_producer函數支持隨機打亂文件列表中文件出隊的順序。隨機打亂文件順序以及加入輸入隊列的過程會跑在一個單獨的線程上,這樣不會影響獲取文件的速度。
tf.train.string_input_producer生成的輸入隊列能夠同時被多個文件讀取線程操做,並且輸入隊列會將隊列中的文件均勻地分配給不一樣的線程,不出現有些文件被處理屢次而有些文件還沒被處理的狀況。
當一個輸入隊列中的全部文件都被處理後,它會將初始化時提供的文件列表中的文件所有從新加入隊列。能夠經過設置num_epochs參數來限制加載初始文件列表的最大輪數。當全部文件都已經被使用了設定的輪數後,若是繼續嘗試讀取新的文件,輸入隊列會報錯:OutOfRange。在測試神經網絡時,由於全部測試數據只須要使用一次,全部能夠將num_epochs設置爲1,這樣在計算完一輪以後程序將自動中止。
在展現tf.train.match_filenames_once和tf.train.string_input_producer函數的使用方法以前,先生成兩個TFRecords文件,
1 num_shards = 2 2 instances_per_shard = 2 3 4 def _int64_feature(value): 5 return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 6 7 # 生成兩個文件,每一個文件保存2個樣例 8 for i in range(num_shards): 9 filename = '/home/yangxl/files/tfrecords/data.tfrecords-%.5d-of-%.5d' % (i, num_shards) # 書上是帶括號的,('...') 10 writer = tf.python_io.TFRecordWriter(filename) 11 12 for j in range(instances_per_shard): 13 example = tf.train.Example(features=tf.train.Features(feature={ 14 'i': _int64_feature(i), 15 'j': _int64_feature(j) 16 })) 17 writer.write(example.SerializeToString()) 18 writer.close()
讀取多個TFRecord文件,獲取樣例數據,
1 files_list = tf.train.match_filenames_once('/home/error/tfrecord/data.tfrecords-*') # 參數爲正則表達式 2 filename_queue = tf.train.string_input_producer(files_list, num_epochs=2, shuffle=True) 3 4 reader = tf.TFRecordReader() 5 _, serialized_example = reader.read(filename_queue) 6 features = tf.parse_single_example( 7 serialized_example, 8 features={ 9 'i': tf.FixedLenFeature([], tf.int64), 10 'j': tf.FixedLenFeature([], tf.int64) 11 } 12 ) 13 14 with tf.Session() as sess: 15 # 雖然在本段程序中沒有聲明任何變量,可是使用tf.train.match_filenames_once函數時須要初始化一些變量 16 tf.local_variables_initializer().run() 17 18 print(sess.run(files_list)) 19 20 coord = tf.train.Coordinator()
# tf.train.string_input_producer建立文件隊列也是調用了FIFOQueue、enqueue_many、QueueRunner、add_queue_runner這幾個操做,因此須要明確調用啓動線程的語句。 21 threads = tf.train.start_queue_runners(sess=sess, coord=coord) 22 23 for i in range(6): 24 print(sess.run([features['i'], features['j']])) 25 26 coord.request_stop() 27 coord.join(threads)
7.3.3 組合訓練數據(batching)
從文件列表中讀取單個樣例,將單個樣例進行預處理,將通過預處理的單個樣例組織成batch,提供給神經網絡輸入層。tensorflow提供了tf.train.batch和tf.train.shuffle_batch函數來將單個的樣例組織成batch形式輸出。這兩個函數都會生成一個隊列,隊列的入隊操做是生成單個樣例的方法,而每次出隊獲得的是一個batch的樣例,兩者惟一的區別在因而否將數據順序打亂。
tf.train.batch和tf.train.shuffle_batch的使用方法,
1 files_list = tf.train.match_filenames_once('/home/error/tfrecord/data.tfrecords-*') # 參數爲正則表達式 2 filename_queue = tf.train.string_input_producer(files_list, shuffle=False) 3 4 reader = tf.TFRecordReader() 5 _, serialized_example = reader.read(filename_queue) 6 features = tf.parse_single_example( 7 serialized_example, 8 features={ 9 'i': tf.FixedLenFeature([], tf.int64), 10 'j': tf.FixedLenFeature([], tf.int64) 11 } 12 ) 13 example, label = features['i'], features['j'] 14 15 batch_size = 5 16 # 隊列中最多能夠存儲的樣例個數。通常來講,隊列的大小與每一個batch的大小相關。 17 capacity = 1000 + 3 * batch_size 18 19 20 # 使用batch來組合樣例。
# capacity給出了隊列的最大容量,當隊列長度等於容量時,tensorflow暫停入隊操做,而只是等待元素出隊;當隊列長度小於容量時,tensorflow自動從新啓動入隊操做。 21 example_batch, label_batch = tf.train.batch( 22 [example, label], batch_size=batch_size, capacity=capacity 23 ) 24 25 with tf.Session() as sess: 26 # 雖然在本段程序中沒有聲明任何變量,可是使用tf.train.match_filenames_once函數時須要初始化一些變量 27 tf.local_variables_initializer().run() 28 print(sess.run(files_list)) 29 30 coord = tf.train.Coordinator() 31 threads = tf.train.start_queue_runners(sess=sess, coord=coord) 32 33 for i in range(3): 34 cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch]) 35 print(cur_example_batch, cur_label_batch) 36 37 coord.request_stop() 38 coord.join(threads)
tf.train.batch和tf.train.shuffle_batch的區別在於,shuffle_batch多一個參數min_after_dequeue,限制了出隊時隊列中元素的最少個數。當隊列中元素太少時,隨機打亂樣例順序的做用就不大了。當隊列中元素不夠時,出隊操做將等待更多的元素入隊纔會完成。
# min_after_dequeue參數限制了出隊時最少元素的個數來保證隨機打亂順序的做用。當出隊函數被調用可是隊列中元素不夠時,出隊操做將等待更多的元素入隊纔會完成。 example_batch, label_batch = tf.train.shuffle_batch( [example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=30 )
這兩個函數除了能夠將單個訓練數據整理成輸入batch,還提供了並行化處理輸入數據的方法。經過設置num_threads參數,能夠指定多個線程同時執行入隊操做。入隊操做就是數據讀取以及預處理過程。當num_threads大於1時,多個線程會同時讀取一個文件中的不一樣樣例並進行預處理。
若是須要多個線程處理不一樣文件中的樣例,可使用tf.train.batch_join和tf.train.shuffle_batch_join函數。此函數會從輸入文件隊列中獲取不一樣文件分配給不一樣的線程。通常來講,輸入文件隊列時經過tf.train.string_input_producer函數生成的,這個函數會平均分配文件以保證不一樣文件中的數據會盡可能平均地使用。
tf.train.shuffle_batch和tf.train.shuffle_batch_join均可以完成多線程並行的方式來進行數據處理,但它們各有優劣。對於shuffle_batch,不一樣線程會讀取同一個文件,若是一個文件中的樣例比較類似(好比都屬於同一個類別),那麼神經網絡的訓練效果有可能受到影響。因此使用shuffle_batch時,須要儘可能將同一個TFRecord文件中的樣例隨機打亂。而使用shuffle_batch_join時,不一樣線程會讀取不一樣文件,若是讀取數據的線程數比文件數還多,那麼多個線程可能會讀取同一個文件中相近部分的數據。並且多個線程讀取多個文件可能致使過多的硬盤尋址,從而下降讀取效率。
3個shuffle:string_input_producer中的shuffle打亂隊列中的文件;shuffle_batch中的shuffle打亂隊列中的元素;shuffle_batch_join中的shuffle。
7.3.4 輸入數據處理框架
事先準備,
把mnist數據集轉爲10個TFRecord文件,
1 import tensorflow as tf 2 import numpy as np 3 from tensorflow.examples.tutorials.mnist import input_data 4 import math 5 6 mnist = input_data.read_data_sets('/home/error/MNIST_DATA/', dtype=tf.uint8, one_hot=True) 7 8 images = mnist.train.images 9 num_examples = mnist.train.num_examples 10 labels = mnist.train.labels 11 pixels = images.shape[1] 12 height = width = int(math.sqrt(pixels)) 13 14 num_shards = 10 15 # 每一個文件有多少數據 16 instances_per_shard = int(mnist.train.num_examples / num_shards) # 5500 17 18 def _int64_feature(value): 19 return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 20 21 def _bytes_feature(value): 22 return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 23 24 25 for i in range(num_shards): 26 filename = '/home/error/tfrecord_mnist/data.tfrecords-%.5d-of-%.5d' % (i, num_shards) 27 writer = tf.python_io.TFRecordWriter(filename) 28 29 for j in range(instances_per_shard * i, instances_per_shard * (i+1)): 30 example = tf.train.Example(features=tf.train.Features(feature={ 31 'image': _bytes_feature(images[j].tostring()), # image[j]爲長度爲784的一維數組 32 'label': _int64_feature(np.argmax(labels[j])), 33 'height': _int64_feature(height), 34 'width': _int64_feature(width), 35 'channels': _int64_feature(1) 36 })) 37 writer.write(example.SerializeToString()) 38 writer.close()
預處理mnist數據遇到的問題:
1). 判斷類型
2). channels
33 # 隨機翻轉 34 distorted_image = tf.image.random_flip_left_right(distorted_image) 35 # 調整色彩 36 distorted_image = distort_color(distorted_image, np.random.randint(2)) 37 38 return distorted_image 39 40 ###################### 41 42 import tensorflow as tf 43 import matplotlib.pyplot as plt 44 import numpy as np 45 from meng42 import preprocessed_for_train 46 from tensorflow.examples.tutorials.mnist import input_data 47 48 49 mnist = input_data.read_data_sets('/home/yangxl/files/mnist/', dtype=tf.uint8, one_hot=True) 50 image = mnist.train.images[4] 51 image = image.reshape([28, 28, 1]) 52 53 # 預處理過程當中,`if image.dtype != tf.float32:`報錯:TypeError: data type not understood 54 # 緣由是image.dtype的類型爲numpy, 而tf.float32的類型爲tensor, 比較以前必須先統一類型。 55 image = tf.constant(image) 56 57 # 定義神經網絡的輸入大小 58 image_size = 28 59 # 預處理 60 distort_image = preprocessed_for_train(image, image_size, image_size, None) 61 distort_image = tf.squeeze(distort_image, axis=2) 62 63 with tf.Session() as sess: 64 tf.global_variables_initializer().run() 65 66 distort_image_val = sess.run(distort_image) 67 print(distort_image_val.shape) 68 plt.imshow(distort_image_val) 69 plt.show()
完整示例:
1 import tensorflow as tf 2 from meng42 import preprocessed_for_train 3 import mnist_inference 4 import os 5 6 7 files = tf.train.match_filenames_once(pattern='/home/yangxl/files/mnist_tfrecords/mnist.tfrecords-*') 8 filename_queue = tf.train.string_input_producer(files, shuffle=False, num_epochs=1) 9 10 reader = tf.TFRecordReader() 11 _, serialized_example = reader.read(filename_queue) 12 13 features = tf.parse_single_example(serialized_example, features={ 14 'image': tf.FixedLenFeature([], tf.string), 15 'label': tf.FixedLenFeature([], tf.int64), 16 'height': tf.FixedLenFeature([], tf.int64), 17 'width': tf.FixedLenFeature([], tf.int64), 18 'channels': tf.FixedLenFeature([], tf.int64) 19 }) 20 21 image, label = features['image'], features['label'] 22 height, width = features['height'], features['width'] 23 channels = features['channels'] 24 25 decoded_image = tf.decode_raw(image, tf.uint8) # shape=(?,) 26 decoded_image = tf.reshape(decoded_image, [28, 28, 1]) # shape=(28, 28, 1) 27 28 # 定義神經網絡的輸入大小 29 image_size = 28 30 # 預處理 31 distort_image = preprocessed_for_train(decoded_image, image_size, image_size, None) # shape=(28, 28, ?) 32 distort_image = tf.reshape(distort_image, [28, 28, 1]) # 預處理過程損壞了shape,會在`shuffle_batch`時報錯。 33 34 min_after_dequeue = 1000 35 batch_size = 100 36 capacity = min_after_dequeue + 3 * batch_size 37 image_batch, label_batch = tf.train.shuffle_batch([distort_image, label], batch_size, capacity, min_after_dequeue) 38 39 # 訓練 40 BATCH_SIZE = 100 41 42 LEARNING_RATE_BASE = 0.9 43 LEARNING_RATE_DECAY = 0.9 44 REGULARIZATION_RATE = 0.0001 # lambda 45 TRAINING_STEPS = 20000 46 MOVING_AVERAGE_DACAY = 0.99 47 48 MODEL_SAVE_PATH = '/home/yangxl/files/save_model2' 49 MODEL_NAME = 'yangxl.ckpt' 50 51 52 def train(image_batch, label_batch): 53 # 由於從池化層到全鏈接層要進行reshape,因此不能爲shape[0]不能爲None。 54 x = tf.placeholder(tf.float32, [BATCH_SIZE, mnist_inference.IMAGE_SIZE, mnist_inference.IMAGE_SIZE, mnist_inference.NUM_CHANNELS], 'x-input') 55 y_ = tf.placeholder(tf.int64, [BATCH_SIZE], 'y-input') 56 # 由於從tfrecords文件中讀取的label.shape=(), 因此這裏進行了相應調整(y_以及用到y_的節點,測試代碼也要對應)。 57 58 # 正則化 59 regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 60 y = mnist_inference.inference(x, True, regularizer) 61 62 global_step = tf.Variable(0, trainable=False) 63 64 # 滑動平均 65 variables_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DACAY, global_step) 66 variables_averages_op = variables_averages.apply(tf.trainable_variables()) 67 68 cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=y_) 69 cross_entropy_mean = tf.reduce_mean(cross_entropy) 70 loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses')) 71 72 # learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, 10000 / BATCH_SIZE, LEARNING_RATE_DECAY, staircase=True) 73 learning_rate = 0.01 74 train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step) 75 with tf.control_dependencies([train_step, variables_averages_op]): 76 train_op = tf.no_op(name='train') 77 78 with tf.Session() as sess: 79 tf.local_variables_initializer().run() 80 tf.global_variables_initializer().run() 81 82 coord = tf.train.Coordinator() 83 threads = tf.train.start_queue_runners(sess=sess, coord=coord) 84 85 saver = tf.train.Saver() 86 87 image_batch_val, label_batch_val = sess.run([image_batch, label_batch]) 88 for i in range(TRAINING_STEPS): 89 _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: image_batch_val, y_: label_batch_val}) 90 91 if i % 1000 == 0: 92 print('after %d training steps, loss on training batch is %g ' % (i, loss_value)) 93 saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step) 94 95 coord.request_stop() 96 coord.join(threads) 97 98 99 if __name__ == '__main__': 100 train(image_batch, label_batch)
把flower文件轉爲TFRecord文件,
1 import tensorflow as tf 2 import os 3 import glob 4 from tensorflow.python.platform import gfile 5 import numpy as np 6 7 INPUT_DATA = '/home/error/flower_photos' # 輸入文件 8 9 10 VALIDATION_PERCENTAGE = 10 11 TEST_PERCENTAGE = 10 12 13 def _int64_feature(value): 14 return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 15 16 def _bytes_feature(value): 17 return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 18 19 def create_image_lists(sess): 20 sub_dirs = [x[0] for x in os.walk(INPUT_DATA)] # 當前目錄和子目錄 21 # print(sub_dirs) 22 is_root_dir = True 23 24 current_labels = 0 25 26 # 讀取全部子目錄 27 for sub_dir in sub_dirs: 28 if is_root_dir: # 把第一個排除了 29 is_root_dir = False 30 continue 31 32 # 獲取一個子目錄中全部的圖片文件 33 extensions = ['jpg', 'jpeg', 'JPG', 'JPEG'] 34 file_list = [] 35 dir_name = os.path.basename(sub_dir) # '/'最後面的部分 36 print(dir_name) 37 for extension in extensions: 38 file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension) 39 file_list.extend(glob.glob(file_glob)) # glob.glob返回一個匹配該模式的列表, glob和os配合使用來操做文件 40 if not file_list: 41 continue 42 43 OUTPUT_DATA = '/home/error/inception_v3_data/inception_v3_data_' + dir_name + '.tfrecords' # 輸出文件 44 writer = tf.python_io.TFRecordWriter(OUTPUT_DATA) 45 46 # 處理圖片數據 47 for file_name in file_list: 48 print(file_name) 49 image_raw_data = gfile.FastGFile(file_name, 'rb').read() # 二進制數據 50 image = tf.image.decode_jpeg(image_raw_data) # tensor, dtype=uint8 333×500×3 色道0~255 51 # if image.dtype != tf.float32: 52 # image = tf.image.convert_image_dtype(image, dtype=tf.float32) # 色道值0~1 53 # image = tf.image.resize_images(image, [299, 299]) 54 image_value = sess.run(image) # numpy.ndarray 55 # print(image_value.shape) 56 height, width, channles = image_value.shape 57 label = current_labels 58 example = tf.train.Example(features=tf.train.Features(feature={ 59 'image': _bytes_feature(image_value.tostring()), 60 'label': _int64_feature(np.argmax(label)), 61 'height': _int64_feature(height), 62 'width': _int64_feature(width), 63 'channels': _int64_feature(channles) 64 })) 65 writer.write(example.SerializeToString()) 66 writer.close() 67 68 current_labels += 1 69 70 71 with tf.Session() as sess: 72 create_image_lists(sess)
7.4 數據集
除隊列外,tensorflow提供了一套更高層的數據處理框架。在新的框架中,每個數據來源被抽象成一個「數據集」,開發者能夠以數據集爲基本對象,方便地進行batching、shuffle等操做。推薦使用數據集做爲輸入數據的首選框架。數據集是tensorflow的核心部件。
7.4.1 數據集的基本使用方法
在數據集框架中,每一個數據集表明一個數據來源:數據可能來自一個tensor,一個TFRecord文件,一個文本文件,或者通過sharding的一系列文件等。
因爲訓練數據一般沒法所有寫入內存中,從數據集中讀取數據時須要使用一個迭代器按順序進行讀取,這點與隊列的dequeue()操做和Reader的read()操做相似。與隊列類似,數據集也是計算圖上的一個節點。
示例,從一個張量建立一個數據集,
1 # 從數組建立數據集。不一樣數據來源,須要調用不一樣的構造方法。 2 input_data = [1, 2, 3, 4, 5] 3 dataset = tf.data.Dataset.from_tensor_slices(input_data) 4 5 # 定義一個迭代器用於遍歷數據集。由於上面定義的數據集沒有使用placeholder做爲輸入參數,因此可使用最簡單的one_shot_iterator。 6 iterator = dataset.make_one_shot_iterator() 7 8 x = iterator.get_next() 9 y = x * x 10 11 with tf.Session() as sess: 12 for i in range(len(input_data)): 13 print(sess.run([x, y]))
在真實項目中,訓練數據一般保存在硬盤文件中。好比在天然語言處理任務中,訓練數據一般以每行一條數據的形式存在文本文件中。這時能夠用TextLineDataset來構造。
1 # 從文件建立數據集 2 # windows中必需要加後綴。'D:\\files\\tf\\firsts.txt' 3 # 只有一個文件時,能夠只傳一個字符串格式的文件名。 4 input_files = ['/home/error/checkpoint', '/home/error/ten'] 5 dataset = tf.data.TextLineDataset(input_files) 6 7 iterator = dataset.make_one_shot_iterator() 8 9 x = iterator.get_next() 10 11 with tf.Session() as sess: 12 for i in range(20): 13 print(sess.run(x))
在圖像相關任務中,訓練數據一般以TFRecords形式存儲,這時能夠用TFRecordDataset來讀取數據。與文本文件不一樣的是,每一個tfrecord都有本身不一樣的feature格式,所以須要提供一個parser函數來解析所讀取的tfrecord格式的數據。
1 # 從TFRecord文件建立數據集 2 input_files = ['/home/error/tt.tfrecords', '/home/error/tt2.tfrecords'] 3 dataset = tf.data.TFRecordDataset(input_files) 4 5 # map()函數表示對數據集中的每一條數據調用相應的方法。 6 # TFRecordDataset讀出的是二進制數據,須要經過map調用parser來對二進制數據進行解析。 7 dataset = dataset.map(parser) 8 9 iterator = dataset.make_one_shot_iterator() 10 features = iterator.get_next() 11 12 with tf.Session() as sess: 13 for i in range(5): # 不能超過樣例個數,不然報錯 14 print(sess.run(features['name']))
把上面的實例改爲含有佔位符的形式:
1 def parser(record): 2 features = tf.parse_single_example( 3 record, 4 features={ 5 'name': tf.FixedLenFeature([], tf.string), 6 'image': tf.FixedLenFeature([], tf.string), 7 'label': tf.FixedLenFeature([], tf.int64), 8 'height': tf.FixedLenFeature([], tf.int64), 9 'width': tf.FixedLenFeature([], tf.int64), 10 'channels': tf.FixedLenFeature([], tf.int64) 11 } 12 ) 13 return features 14 15 # 從TFRecord文件建立數據集 16 input_files = tf.placeholder(tf.string) 17 dataset = tf.data.TFRecordDataset(input_files) 18 19 # map()函數表示對數據集中的每一條數據調用相應的方法。 20 # TFRecordDataset讀出的是二進制數據,須要經過map調用parser來對二進制數據進行解析。 21 dataset = dataset.map(parser) 22 23 iterator = dataset.make_initializable_iterator() 24 features = iterator.get_next() 25 26 with tf.Session() as sess: 27 sess.run(iterator.initializer, feed_dict={input_files: ['/home/error/tt.tfrecords', '/home/error/tt2.tfrecords']}) 28 # 由於不一樣數據來源的數據量大小難以預知。使用while True能夠把全部數據遍歷一遍。 29 while True: 30 try: 31 print(sess.run([features['name'], features['height']])) 32 except tf.errors.OutOfRangeError: 33 break
7.4.2 數據集的高層操做
dataset = dataset.map(parser)
對數據集中的每一條數據調用參數中指定的parser方法,通過處理後的數據從新組合成一個數據集。
1 distorted_image = preprocess_for_train( 2 decoded_image, image_size, image_size, None 3 ) 4 轉爲 5 dataset = dataset.map( 6 lambda x: preprocess_for_train(x, image_size, image_size, None) 7 )
這樣處理的優勢是,返回一個新數據集,能夠直接繼續調用其餘高層操做。
在隊列框架中,預處理、shuffle、batch等操做有的在隊列上進行,有的在圖片張量上進行,整個處理流程在處理隊列和張量的代碼片斷中來回切換。而在數據集操做中,全部操做都在數據集上進行。
dataset = dataset.shuffle(buffer_size) # 隨機打亂順序 dataset = dataset.batch(batch_size) # 將數據組合成batch
shuffle方法中的buffer_size等效於tf.train.shuffle_batch的min_after_dequeue,shuffle算法在內部使用一個緩衝區保存buffer_size條數據,每讀入一個新數據時,從這個緩衝區隨機選擇一條數據進行輸出。緩衝區越大,隨機性能越好,但佔用的內存也越多。
batch方法的batch_size表明要輸出的每一個batch由多少條數據組成。若是數據集包含多個張量,那麼batch操做將對每一個張量分開進行。例如,若是數據集中的每一個數據是image、label兩個張量,其中image的維度是[300, 300],label的維度是[],batch_size是128,那麼通過batch操做後的數據集的每一個輸出將包含兩個維度分別爲[128, 300, 300]和[128]的張量。
dataset = dataset.repeat(N) # 將數據集重複N份
將數據集重複N份,每一份數據被稱爲一個epoch。
須要指出的是,若是數據集在repeat以前進行了shuffle操做,輸出的每一個epoch中隨機shuffle的結果並不會相同。由於repeat和map、shuffle、batch等操做同樣,都只是計算圖上的一個計算節點,repeat只表明重複相同的處理過程,並不會記錄前一epoch的處理結果。
其餘方法,
dataset.concatenate() # 將兩個數據集順序鏈接起來 dataset.take(N) # 從數據集中讀取前N項數據 dataset.skip(N) # 在數據集中跳過前N項數據 dataset.flat_map() # 從多個數據集中輪流讀取數據
與隊列框架下的樣例不一樣的是,在訓練數據集以外,還另外讀取了測試數據集,並對測試集進行了略微不一樣的預處理。在訓練時,調用preprocessed_for_train對圖像進行隨機反轉等預處理操做;而在測試時,測試集以本來的樣子直接輸入測試。
1 import tensorflow as tf 2 from meng42 import preprocessed_for_train 3 4 train_files = tf.train.match_filenames_once('/home/yangxl/files/mnist_tfrecords/mnist.tfrecords-*') 5 test_files = tf.train.match_filenames_once('/home/yangxl/files/mnist_tfrecords/mnist.tfrecords-0000[49]-of-00010') 6 7 8 def parser(record): 9 features = tf.parse_single_example( 10 record, 11 features={ 12 'image': tf.FixedLenFeature([], tf.string), 13 'label': tf.FixedLenFeature([], tf.int64), 14 'height': tf.FixedLenFeature([], tf.int64), 15 'width': tf.FixedLenFeature([], tf.int64), 16 'channels': tf.FixedLenFeature([], tf.int64), 17 } 18 ) 19 20 decoded_image = tf.decode_raw(features['image'], tf.uint8) 21 decoded_image = tf.reshape(decoded_image, [features['height'], features['width'], features['channels']]) 22 label = features['label'] 23 return decoded_image, label 24 25 26 image_size = 28 27 batch_size = 100 28 shuffle_buffer = 1000 29 30 dataset = tf.data.TFRecordDataset(train_files) 31 dataset = dataset.map(parser) 32 # lambda中的參數image、label, 返回的是一個元組(image, label) 33 dataset = dataset.map(lambda image, label: (preprocessed_for_train(image, image_size, image_size, None), label)) 34 dataset = dataset.shuffle(shuffle_buffer).batch(batch_size) 35 # 重複NUM_EPOCHS個epoch。在7.3.4小節中TRAINING_ROUNDS指定了訓練輪數,而這裏指定了整個數據集重複的次數,這也間接肯定了訓練的輪數 36 NUM_EPOCHS = 10 37 dataset = dataset.repeat(NUM_EPOCHS) 38 39 # 雖然定義數據集時沒有直接使用placeholder來提供文件地址,可是tf.train.match_filenames_once方法獲得的結果與placeholder的機制相似,也須要初始化 40 iterator = dataset.make_initializable_iterator() 41 image_batch, label_batch = iterator.get_next() 42 print(image_batch.shape, label_batch.shape) 43 44 45 test_dataset = tf.data.TFRecordDataset(test_files) 46 # 對於測試集,不須要預處理、shuffle、repeat操做,只需用相同的parser進行解析、調整輸入層大小、batch便可 47 test_dataset = test_dataset.map(parser) 48 test_dataset = test_dataset.map(lambda image, label: (tf.image.resize_images(image, [image_size, image_size]), label)) 49 test_dataset = test_dataset.batch(batch_size) 50 51 test_iterator = test_dataset.make_initializable_iterator() 52 test_image_batch, test_label_batch = test_iterator.get_next() 53 print(test_image_batch.shape, test_label_batch.shape) 54 55 with tf.Session() as sess: 56 tf.local_variables_initializer().run() 57 print(test_files.eval())
ok!