關於Tensorflow 的數據讀取環節 tensorflow 1.0 學習:十圖詳解tensorflow數據讀取機制

Tensorflow讀取數據的通常方式有下面3種:html

  • preloaded直接建立變量:在tensorflow定義圖的過程當中,建立常量或變量來存儲數據
  • feed:在運行程序時,經過feed_dict傳入數據
  • reader從文件中讀取數據:在tensorflow圖開始時,經過一個輸入管線從文件中讀取數據

Preloaded方法的簡單例子

 1 import tensorflow as tf
 2 
 3 """定義常量"""
 4 const_var = tf.constant([1, 2, 3])
 5 """定義變量"""
 6 var = tf.Variable([1, 2, 3])
 7 
 8 with tf.Session() as sess:
 9     sess.run(tf.global_variables_initializer())
10     print(sess.run(var))
11     print(sess.run(const_var))

Feed方法

能夠在tensorflow運算圖的過程當中,將數據傳遞到事先定義好的placeholder中。方法是在調用session.run函數時,經過feed_dict參數傳入。簡單例子:python

 1 import tensorflow as tf
 2 """定義placeholder"""
 3 x1 = tf.placeholder(tf.int16)
 4 x2 = tf.placeholder(tf.int16)
 5 result = x1 + x2
 6 """定義feed_dict"""
 7 feed_dict = {
 8 x1: [10],
 9 x2: [20]
10 }
11 """運行圖"""
12 with tf.Session() as sess:
13     print(sess.run(result, feed_dict=feed_dict))

上面的兩個方法在面對大量數據時,都存在性能問題。這時候就須要使用到第3種方法,文件讀取,讓tensorflow本身從文件中讀取數據git

從文件中讀取數據

 

圖引用自 https://zhuanlan.zhihu.com/p/27238630github

步驟:
  1. 獲取文件名列表list
  2. 建立文件名隊列,調用tf.train.string_input_producer,參數包含:文件名列表,num_epochs【定義重複次數】,shuffle【定義是否打亂文件的順序】
  3. 定義對應文件的閱讀器>* tf.ReaderBase >* tf.TFRecordReader >* tf.TextLineReader >* tf.WholeFileReader >* tf.IdentityReader >* tf.FixedLengthRecordReader
  4. 解析器 >* tf.decode_csv >* tf.decode_raw >* tf.image.decode_image >* …
  5. 預處理,對原始數據進行處理,以適應network輸入所需
  6. 生成batch,調用tf.train.batch() 或者 tf.train.shuffle_batch()
  7. prefetch【可選】使用預加載隊列slim.prefetch_queue.prefetch_queue()
  8. 啓動填充隊列的線程,調用tf.train.start_queue_runners

圖引用自http://www.yyliu.cn/post/89458415.htmlapi

 讀取文件格式舉例

tensorflow支持讀取的文件格式包括:CSV文件,二進制文件,TFRecords文件,圖像文件,文本文件等等。具體使用時,須要根據文件的不一樣格式,選擇對應的文件格式閱讀器,再將文件名隊列傳爲參數,傳入閱讀器的read方法中。方法會返回key與對應的record value。將value交給解析器進行解析,轉換成網絡能進行處理的tensor。網絡

CSV文件讀取:

閱讀器:tf.TextLineReadersession

解析器:tf.decode_csvide

 1 filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
 2 """閱讀器"""
 3 reader = tf.TextLineReader()
 4 key, value = reader.read(filename_queue)
 5 """解析器"""
 6 record_defaults = [[1], [1], [1], [1]]
 7 col1, col2, col3, col4 = tf.decode_csv(value, record_defaults=record_defaults)
 8 features = tf.concat([col1, col2, col3, col4], axis=0)
 9 
10 with tf.Session() as sess:
11     coord = tf.train.Coordinator()
12     threads = tf.train.start_queue_runners(coord=coord)
13     for i in range(100):
14         example = sess.run(features)
15     coord.request_stop()
16     coord.join(threads)
 二進制文件讀取:

閱讀器:tf.FixedLengthRecordReader函數

解析器:tf.decode_rawpost

圖像文件讀取:

閱讀器:tf.WholeFileReader

解析器:tf.image.decode_image, tf.image.decode_gif, tf.image.decode_jpeg, tf.image.decode_png

 TFRecords文件讀取

TFRecords文件是tensorflow的標準格式。要使用TFRecords文件讀取,事先須要將數據轉換成TFRecords文件,具體可察看:convert_to_records.py 在這個腳本中,先將數據填充到tf.train.Example協議內存塊(protocol buffer),將協議內存塊序列化爲字符串,再經過tf.python_io.TFRecordWriter寫入到TFRecords文件中去。

閱讀器:tf.TFRecordReader

解析器:tf.parse_single_example

又或者使用slim提供的簡便方法:slim.dataset.Data以及slim.dataset_data_provider.DatasetDataProvider方法

 1 def get_split(record_file_name, num_sampels, size):
 2     reader = tf.TFRecordReader
 3 
 4     keys_to_features = {
 5         "image/encoded": tf.FixedLenFeature((), tf.string, ''),
 6         "image/format": tf.FixedLenFeature((), tf.string, 'jpeg'),
 7         "image/height": tf.FixedLenFeature([], tf.int64, tf.zeros([], tf.int64)),
 8         "image/width": tf.FixedLenFeature([], tf.int64, tf.zeros([], tf.int64)),
 9     }
10 
11     items_to_handlers = {
12         "image": slim.tfexample_decoder.Image(shape=[size, size, 3]),
13         "height": slim.tfexample_decoder.Tensor("image/height"),
14         "width": slim.tfexample_decoder.Tensor("image/width"),
15     }
16 
17     decoder = slim.tfexample_decoder.TFExampleDecoder(
18         keys_to_features, items_to_handlers
19     )
20     return slim.dataset.Dataset(
21         data_sources=record_file_name,
22         reader=reader,
23         decoder=decoder,
24         items_to_descriptions={},
25         num_samples=num_sampels
26     )
27 
28 
29 def get_image(num_samples, resize, record_file="image.tfrecord", shuffle=False):
30     provider = slim.dataset_data_provider.DatasetDataProvider(
31         get_split(record_file, num_samples, resize),
32         shuffle=shuffle
33     )
34     [data_image] = provider.get(["image"])
35     return data_image

參考資料: 

TensorFlow數據讀取方式(3種方法)

tensorflow 1.0 學習:十圖詳解tensorflow數據讀取機制

相關文章
相關標籤/搜索