TensorFlow學習筆記(10):讀取文件

簡介

TensorFlow讀取數據共有三種方法:python

  • Feeding:當TensorFlow運行每步計算的時候,從Python獲取數據。在Graph的設計階段,用placeholder佔住Graph的位置,完成Graph的表達;當Graph傳給Session後,在運算時再把須要的數據從Python傳過來。segmentfault

  • Preloaded data:數據直接預加載到TensorFlow的Graph中,再把Graph傳入Session運行。只適用於小數據。多線程

  • Reading from file:在Graph中定義好文件讀取的運算節點,把Graph傳入Session運行時,執行讀取文件的運算,這樣能夠避免在Python和TensorFlow C++執行環境之間反覆傳遞數據。app

本文講解Reading from file的代碼。ide

其餘關於TensorFlow的學習筆記,請點擊入門教程函數

實現

#!/usr/bin/env python
# -*- coding=utf-8 -*-
# @author: 陳水平
# @date: 2017-02-19
# @description: modified program to illustrate reading from file based on TF offitial tutorial
# @ref: https://www.tensorflow.org/programmers_guide/reading_data

def read_my_file_format(filename_queue):
  """從文件名隊列讀取一行數據
  
  輸入:
  -----
  filename_queue:文件名隊列,舉個例子,能夠使用`tf.train.string_input_producer(["file0.csv", "file1.csv"])`方法建立一個包含兩個CSV文件的隊列
  
  輸出:
  -----
  一個樣本:`[features, label]`
  """
  reader = tf.SomeReader()  # 建立Reader
  key, record_string = reader.read(filename_queue)  # 讀取一行記錄
  example, label = tf.some_decoder(record_string)  # 解析該行記錄
  processed_example = some_processing(example)  # 對特徵進行預處理
  return processed_example, label

def input_pipeline(filenames, batch_size, num_epochs=None):
  """ 從一組文件中讀取一個批次數據
  
  輸入:
  -----
  filenames:文件名列表,如`["file0.csv", "file1.csv"]`
  batch_size:每次讀取的樣本數
  num_epochs:每一個文件的讀取次數
  
  輸出:
  -----
  一批樣本,`[[example1, label1], [example2, label2], ...]`
  """
  filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)  # 建立文件名隊列
  example, label = read_my_file_format(filename_queue)  # 讀取一個樣本
  # 將樣本放進樣本隊列,每次輸出一個批次樣本
  #   - min_after_dequeue:定義輸出樣本後的隊列最小樣本數,越大隨機性越強,但start up時間和內存佔用越多
  #   - capacity:隊列大小,必須比min_after_dequeue大
  min_after_dequeue = 10000
  capacity = min_after_dqueue + 3 * batch_size
  example_batch, label_batch = tf.train.shuffle_batch(
    [example, label], batch_size=batch_size, capacity=capacity,
    min_after_dequeue=min_after_dequeue)
  return example_batch, label_batch
  
def main(_):
  x, y = input_pipeline(['file0.csv', 'file1.csv'], 1000, 5)
  train_op = some_func(x, y)
  init_op = tf.global_variables_initializer()
  local_init_op = tf.local_variables_initializer()  # local variables like epoch_num, batch_size
  sess = tf.Session()
  
  sess.run(init_op)
  sess.run(local_init_op)
  
  # `QueueRunner`用於建立一系列線程,反覆地執行`enqueue` op
  # `Coordinator`用於讓這些線程一塊兒結束
  # 典型應用場景:
  #   - 多線程準備樣本數據,執行enqueue將樣本放進一個隊列
  #   - 一個訓練線程從隊列執行dequeu獲取一批樣本,執行training op
  # `tf.train`的許多函數會在graph中添加`QueueRunner`對象,如`tf.train.string_input_producer`
  # 在執行training op以前,須要保證Queue裏有數據,所以須要先執行`start_queue_runners`
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  
  try:
    while not coord.should_stop():
      sess.run(train_op)
  except tf.errors.OutOfRangeError:
    print 'Done training -- epoch limit reached'
  finally:
    coord.request_stop()
  
  # Wait for threads to finish  
  coord.join(threads)
  sess.close()
  
if __name__ == '__main__':
  tf.app.run()
相關文章
相關標籤/搜索