tensorflow 使用tfrecords建立本身數據集

直接採用矩陣方式創建數據集見:http://www.javashuo.com/article/p-pjxbpkmd-kr.htmlhtml

製做本身的數據集(使用tfrecords)python

爲何採用這個格式?多線程

TFRecords文件格式在圖像識別中有很好的使用,其能夠將二進制數據和標籤數據(訓練的類別標籤)數據存儲在同一個文件中,它能夠在模型進行訓練以前經過預處理步驟將圖像轉換爲TFRecords格式,此格式最大的優勢實踐每幅輸入圖像和與之關聯的標籤放在同一個文件中.TFRecords文件是一種二進制文件,其不對數據進行壓縮,因此能夠被快速加載到內存中.格式不支持隨機訪問,所以它適合於大量的數據流,但不適用於快速分片或其餘非連續存取。函數

前戲:ui

tf.train.Feature
tf.train.Feature有三個屬性爲tf.train.bytes_list    tf.train.float_list    tf.train.int64_list,顯然咱們只須要根據上一步獲得的值來設置tf.train.Feature的屬性就能夠了,以下所示:spa

1 tf.train.Feature(int64_list=data_id)
2 tf.train.Feature(bytes_list=data)

tf.train.Features
從名字來看,咱們應該能猜出tf.train.Features是tf.train.Feature的複數,事實上tf.train.Features有屬性爲feature,這個屬性的通常設置方法是傳入一個字典,字典的key是字符串(feature名),而值是tf.train.Feature對象。所以,咱們能夠這樣獲得tf.train.Features對象:.net

1 feature_dict = {
2 "data_id": tf.train.Feature(int64_list=data_id),
3 "data": tf.train.Feature(bytes_list=data)
4 }
5 features = tf.train.Features(feature=feature_dict)

tf.train.Example
終於到咱們的主角了。tf.train.Example有一個屬性爲features,咱們只須要將上一步獲得的結果再次當作參數傳進來便可。
另外,tf.train.Example還有一個方法SerializeToString()須要說一下,這個方法的做用是把tf.train.Example對象序列化爲字符串,由於咱們寫入文件的時候不能直接處理對象,須要將其轉化爲字符串才能處理。
固然,既然有對象序列化爲字符串的方法,那麼確定有從字符串反序列化到對象的方法,該方法是FromString(),須要傳遞一個tf.train.Example對象序列化後的字符串進去作爲參數才能獲得反序列化的對象。
在咱們這裏,只須要構建tf.train.Example對象並序列化就能夠了,這一步的代碼爲:線程

1 example = tf.train.Example(features=features)
2 example_str = example.SerializeToString()

實例(高潮部分):code

首先看一下咱們的文件夾路徑:htm

create_tfrecords.py中寫咱們的函數

生成數據文件階段代碼以下:

 1 def creat_tf(imgpath):  2     cwd = os.getcwd()  #獲取當前路徑
 3     classes = os.listdir(cwd + imgpath)  #獲取到[1, 2]文件夾
 4     # 此處定義tfrecords文件存放
 5     writer = tf.python_io.TFRecordWriter("train.tfrecords")  6     for index, name in enumerate(classes):   #循環獲取倆文件夾(倆類別)
 7         class_path = cwd + imgpath + name + "/"
 8         if os.path.isdir(class_path):  9             for img_name in os.listdir(class_path): 10                 img_path = class_path + img_name 11                 img = Image.open(img_path) 12                 img = img.resize((224, 224)) 13                 img_raw = img.tobytes() 14                 example = tf.train.Example(features=tf.train.Features(feature={ 15                     'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(name)])), 16                     'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) 17  })) 18  writer.write(example.SerializeToString()) 19                 print(img_name) 20     writer.close()

這段代碼主要生成  train.tfrecords 文件。

讀取數據階段代碼以下:

 1 def read_and_decode(filename):  2     # 根據文件名生成一個隊列
 3     filename_queue = tf.train.string_input_producer([filename])  4 
 5     reader = tf.TFRecordReader()  6     _, serialized_example = reader.read(filename_queue)  # 返回文件名和文件
 7     features = tf.parse_single_example(serialized_example,  8                                        features={  9                                            'label': tf.FixedLenFeature([], tf.int64), 10                                            'img_raw': tf.FixedLenFeature([], tf.string), 11  }) 12 
13     img = tf.decode_raw(features['img_raw'], tf.uint8) 14     img = tf.reshape(img, [224, 224, 3]) 15     # 轉換爲float32類型,並作歸一化處理
16     img = tf.cast(img, tf.float32)  # * (1. / 255)
17     label = tf.cast(features['label'], tf.int64) 18     return img, label

訓練階段咱們獲取數據的代碼:

 1 images, labels = read_and_decode('./train.tfrecords')  2 img_batch, label_batch = tf.train.shuffle_batch([images, labels],  3                                                 batch_size=5,  4                                                 capacity=392,  5                                                 min_after_dequeue=200)  6 init = tf.global_variables_initializer()  7 with tf.Session() as sess:  8  sess.run(init)  9     coord = tf.train.Coordinator()  #線程協調器
10     threads = tf.train.start_queue_runners(sess=sess,coord=coord) 11     # 訓練部分代碼--------------------------------
12     IMG, LAB = sess.run([img_batch, label_batch]) 13     print(IMG.shape) 14 
15     #----------------------------------------------
16     coord.request_stop()  # 協調器coord發出全部線程終止信號
17     coord.join(threads) #把開啓的線程加入主線程,等待threads結束

總結(流程):

  1. 生成tfrecord文件
  2. 定義record reader解析tfrecord文件
  3. 構造一個批生成器(batcher
  4. 構建其餘的操做
  5. 初始化全部的操做
  6. 啓動QueueRunner

備註:關於tf.train.Coordinator 詳見:

https://blog.csdn.net/dcrmg/article/details/79780331

TensorFlow的Session對象是支持多線程的,能夠在同一個會話(Session)中建立多個線程,並行執行。在Session中的全部線程都必須能被同步終止,異常必須能被正確捕獲並報告,會話終止的時候, 隊列必須能被正確地關閉。

  1. 調用 tf.train.slice_input_producer,從 本地文件裏抽取tensor,準備放入Filename Queue(文件名隊列)中;
  2. 調用 tf.train.batch,從文件名隊列中提取tensor,使用單個或多個線程,準備放入文件隊列;
  3. 調用 tf.train.Coordinator() 來建立一個線程協調器,用來管理以後在Session中啓動的全部線程;
  4. 調用tf.train.start_queue_runners, 啓動入隊線程,由多個或單個線程,按照設定規則,把文件讀入Filename Queue中。函數返回線程ID的列表,通常狀況下,系統有多少個核,就會啓動多少個入隊線程(入隊具體使用多少個線程在tf.train.batch中定義);
  5. 文件從 Filename Queue中讀入內存隊列的操做不用手動執行,由tf自動完成;
  6. 調用sess.run 來啓動數據出列和執行計算;
  7. 使用 coord.should_stop()來查詢是否應該終止全部線程,當文件隊列(queue)中的全部文件都已經讀取出列的時候,會拋出一個 OutofRangeError 的異常,這時候就應該中止Sesson中的全部線程了;
  8. 使用coord.request_stop()來發出終止全部線程的命令,使用coord.join(threads)把線程加入主線程,等待threads結束。
相關文章
相關標籤/搜索