直接採用矩陣方式創建數據集見:http://www.javashuo.com/article/p-pjxbpkmd-kr.htmlhtml
製做本身的數據集(使用tfrecords)python
爲何採用這個格式?多線程
TFRecords文件格式在圖像識別中有很好的使用,其能夠將二進制數據和標籤數據(訓練的類別標籤)數據存儲在同一個文件中,它能夠在模型進行訓練以前經過預處理步驟將圖像轉換爲TFRecords格式,此格式最大的優勢實踐每幅輸入圖像和與之關聯的標籤放在同一個文件中.TFRecords文件是一種二進制文件,其不對數據進行壓縮,因此能夠被快速加載到內存中.格式不支持隨機訪問,所以它適合於大量的數據流,但不適用於快速分片或其餘非連續存取。函數
前戲:ui
tf.train.Feature
tf.train.Feature有三個屬性爲tf.train.bytes_list tf.train.float_list tf.train.int64_list,顯然咱們只須要根據上一步獲得的值來設置tf.train.Feature的屬性就能夠了,以下所示:spa
1 tf.train.Feature(int64_list=data_id) 2 tf.train.Feature(bytes_list=data)
tf.train.Features
從名字來看,咱們應該能猜出tf.train.Features是tf.train.Feature的複數,事實上tf.train.Features有屬性爲feature,這個屬性的通常設置方法是傳入一個字典,字典的key是字符串(feature名),而值是tf.train.Feature對象。所以,咱們能夠這樣獲得tf.train.Features對象:.net
1 feature_dict = { 2 "data_id": tf.train.Feature(int64_list=data_id), 3 "data": tf.train.Feature(bytes_list=data) 4 } 5 features = tf.train.Features(feature=feature_dict)
tf.train.Example
終於到咱們的主角了。tf.train.Example有一個屬性爲features,咱們只須要將上一步獲得的結果再次當作參數傳進來便可。
另外,tf.train.Example還有一個方法SerializeToString()須要說一下,這個方法的做用是把tf.train.Example對象序列化爲字符串,由於咱們寫入文件的時候不能直接處理對象,須要將其轉化爲字符串才能處理。
固然,既然有對象序列化爲字符串的方法,那麼確定有從字符串反序列化到對象的方法,該方法是FromString(),須要傳遞一個tf.train.Example對象序列化後的字符串進去作爲參數才能獲得反序列化的對象。
在咱們這裏,只須要構建tf.train.Example對象並序列化就能夠了,這一步的代碼爲:線程
1 example = tf.train.Example(features=features) 2 example_str = example.SerializeToString()
實例(高潮部分):code
首先看一下咱們的文件夾路徑:htm
create_tfrecords.py中寫咱們的函數
生成數據文件階段代碼以下:
1 def creat_tf(imgpath): 2 cwd = os.getcwd() #獲取當前路徑 3 classes = os.listdir(cwd + imgpath) #獲取到[1, 2]文件夾 4 # 此處定義tfrecords文件存放 5 writer = tf.python_io.TFRecordWriter("train.tfrecords") 6 for index, name in enumerate(classes): #循環獲取倆文件夾(倆類別) 7 class_path = cwd + imgpath + name + "/" 8 if os.path.isdir(class_path): 9 for img_name in os.listdir(class_path): 10 img_path = class_path + img_name 11 img = Image.open(img_path) 12 img = img.resize((224, 224)) 13 img_raw = img.tobytes() 14 example = tf.train.Example(features=tf.train.Features(feature={ 15 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(name)])), 16 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) 17 })) 18 writer.write(example.SerializeToString()) 19 print(img_name) 20 writer.close()
這段代碼主要生成 train.tfrecords 文件。
讀取數據階段代碼以下:
1 def read_and_decode(filename): 2 # 根據文件名生成一個隊列 3 filename_queue = tf.train.string_input_producer([filename]) 4 5 reader = tf.TFRecordReader() 6 _, serialized_example = reader.read(filename_queue) # 返回文件名和文件 7 features = tf.parse_single_example(serialized_example, 8 features={ 9 'label': tf.FixedLenFeature([], tf.int64), 10 'img_raw': tf.FixedLenFeature([], tf.string), 11 }) 12 13 img = tf.decode_raw(features['img_raw'], tf.uint8) 14 img = tf.reshape(img, [224, 224, 3]) 15 # 轉換爲float32類型,並作歸一化處理 16 img = tf.cast(img, tf.float32) # * (1. / 255) 17 label = tf.cast(features['label'], tf.int64) 18 return img, label
訓練階段咱們獲取數據的代碼:
1 images, labels = read_and_decode('./train.tfrecords') 2 img_batch, label_batch = tf.train.shuffle_batch([images, labels], 3 batch_size=5, 4 capacity=392, 5 min_after_dequeue=200) 6 init = tf.global_variables_initializer() 7 with tf.Session() as sess: 8 sess.run(init) 9 coord = tf.train.Coordinator() #線程協調器 10 threads = tf.train.start_queue_runners(sess=sess,coord=coord) 11 # 訓練部分代碼-------------------------------- 12 IMG, LAB = sess.run([img_batch, label_batch]) 13 print(IMG.shape) 14 15 #---------------------------------------------- 16 coord.request_stop() # 協調器coord發出全部線程終止信號 17 coord.join(threads) #把開啓的線程加入主線程,等待threads結束
總結(流程):
record reader
解析tfrecord文件batcher
)QueueRunner
備註:關於tf.train.Coordinator 詳見:
https://blog.csdn.net/dcrmg/article/details/79780331
TensorFlow的Session對象是支持多線程的,能夠在同一個會話(Session)中建立多個線程,並行執行。在Session中的全部線程都必須能被同步終止,異常必須能被正確捕獲並報告,會話終止的時候, 隊列必須能被正確地關閉。