tf.data模塊包含:dom
1 # author by FH. 2 # OverView: 3 # tf.data 4 # experimental ---Modules 5 # Dataset ---class 6 # FixedLengthRecordDataset ---class 7 # TFRecordDataset ---class 8 # TextLineDataset ---class 9 import tensorflow as tf 10 import numpy as np 11 12 13 # 1. 使用靜態方法 tf.data.Dataset.from_tensor_slices 14 # 將輸入的第一個維度切割,造成dataset 15 # 2. 使用 Dataset的 make_one_shot_iterator() 實例化一個 iterator 16 # 這個iterator 只能從頭至尾讀取一次。「one shot iterator」 17 def test1(): 18 sess = tf.Session() 19 dataset1 = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0])) 20 dataset2 = tf.data.Dataset.from_tensor_slices(np.array([[1,2],[3,4],[0,9]])) 21 dataset3 = tf.data.Dataset.from_tensor_slices( 22 { 23 "a":np.array([1.0,2,3,4,5.0]), 24 "b":np.random.uniform(size=(5,2)) 25 } 26 ) 27 # 使用 Dataset的 make_one_shot_iterator() 實例化一個 iterator 28 # 這個iterator 只能從頭至尾讀取一次。「one shot iterator」 29 oneShotIterator1 = dataset1.make_one_shot_iterator() 30 oneShotIterator2 = dataset2.make_one_shot_iterator() 31 oneShotIterator3 = dataset3.make_one_shot_iterator() 32 element1 = oneShotIterator1.get_next() 33 element2 = oneShotIterator2.get_next() 34 element3 = oneShotIterator3.get_next() 35 for i in range(5): 36 print(sess.run(element1)) 37 for i in range(3): 38 print(sess.run(element2)) 39 for i in range(5): 40 print(sess.run(element3)) 41 sess.close() 42 43 # 1.Dataset 中的數據元素轉換。 44 # map() :參數爲一個函數,將dataset中的每一個元素帶入獲取新的值 45 # batch(): 參數爲一個整數,將多個元素組合成一個batch 46 def test2(): 47 sess = tf.Session() 48 dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0,6])) 49 # map() 從新映射新的元素值 50 dataset1 = dataset.map(lambda x: x * 3) 51 # batch() 2個組成一個batch, 組成batch 以後size 爲3 52 dataset2 = dataset.batch(2) 53 # shuffle() 打亂dataset 54 dataset3 = dataset.shuffle(buffer_size=3) 55 # repeat() 將整個序列重複屢次,重複4次 size 爲24 56 dataset4 = dataset.repeat(4) 57 58 oneShotIterator1 = dataset1.make_one_shot_iterator() 59 oneShotIterator2 = dataset2.make_one_shot_iterator() 60 oneShotIterator3 = dataset3.make_one_shot_iterator() 61 oneShotIterator4 = dataset4.make_one_shot_iterator() 62 element1 = oneShotIterator1.get_next() 63 element2 = oneShotIterator2.get_next() 64 element3 = oneShotIterator3.get_next() 65 element4 = oneShotIterator4.get_next() 66 for i in range(6): # map() 67 print(sess.run(element1)) 68 for i in range(3): # batch() 69 print(sess.run(element2)) 70 for i in range(6): # shuffle() 71 print(sess.run(element3)) 72 for i in range(24): # repeat() 73 print(sess.run(element4)) 74 sess.close() 75 76 # example1: 讀取圖片和相應的標籤並打亂,組成 77 # batch_size=2 的數據集,重複10 epoch 78 def _parse_function(imgfilename,label): 79 image_value = tf.read_file(imgfilename) 80 img = tf.image.decode_image(image_value) 81 img = tf.image.resize_images(img,[256,256]) 82 return img,label 83 def example1(): 84 # 圖片列表 85 filesnames = tf.constant(['name1.jpg','name3.jpg','name5.jpg','name6.jpg','name7.jpg','name8.jpg']) 86 # 對應標籤 87 labels = tf.constant([0,1,0,1,1,0]) 88 # dataset (名稱,標籤) 89 dataset = tf.data.Dataset.from_tensor_slices((filesnames,labels)) 90 # map 映射成圖片和標籤 91 dataset = dataset.map(_parse_function) 92 # shuffle ,batch , repeat 93 dataset = dataset.shuffle(buffersize=3).batch(2).repeat(10) 94 return dataset 95 96 if __name__ == '__main__': 97 test2()