一直以來都是用 tensorflow 框架實現深度學習算法和實驗,在網絡訓練時有一個重要的問題就是訓練數據的讀取。tensorflow 支持流水線並行讀取數據,這種方式將數據的讀取和網絡訓練並行,數據讀取效率和將全部數據載入內存後進行存取至關,卻又不會增長內存開銷,是很值得推薦的一種方式。這篇筆記就是總結一下本身在實際應用中的並行數據讀取,留個備份,隨時學習。git
主要參考了 Google HDRnet 代碼:https://github.com/mgharbi/hdrnet,CycleGAN 代碼:https://github.com/vanhuyz/CycleGAN-TensorFlowgithub
HDRnet工程裏的 data_pipeline.py 文件提供了很是清晰的流水線讀取數據示例,在官方代碼的基礎上,能夠很輕鬆地針對本身的應用實現一套數據讀取接口,假設咱們的訓練數據存儲在目錄 training_data/input 和 training_data/output,input 存儲網絡訓練輸入,output 存儲網絡目標輸出,一對訓練樣本的輸入和目標輸出名稱相同,均爲二進制文件 *.dat,如下面代碼爲示例展現如何實現流水線並行數據讀取:算法
def data_generator(params, data_path): filelist = os.listdir(data_path) # 獲取訓練目錄下的文件名列表 if params.shuffle: random.shuffle(filelist) # 隨機打亂訓練數據 input_files = [os.path.join(data_path, 'input', f) for f in filelist if f.endswith('.dat')] # 生成輸入數據文件名列表 output_files = [os.path.join(data_path, 'output', f) for f in filelist if f.endswith('.dat')] # 生成目標輸出文件名列表
# 基於給定的文件名列表,建立先入先出的文件名隊列,輸入能夠是多個文件名列表,輸出對應的對個文件名隊列 input_queue, output_queue = tf.train.slice_input_producer( [input_files, output_files], shuffle=params.shuffle, seed=params.seed, num_epochs=params.num_epochs) input_reader = tf.read_file(input_queue) # 建立 reader,讀取輸入數據 output_reader = tf.read_file(output_queue) # 建立 reader,讀取目標輸出
# 根據文件類型的不一樣解析數據,若是文件是圖像,可使用 tf.image.decode_jpeg 等函數解析 if os.path.splitext(input_files[0])[-1] == '.jpg': input = tf.image.decode_jpeg(input_reader, channels=3) else: input = tf.decode_raw(input_reader, data_type=tf.uint16) # 若是是二進制信息存儲,則可使用 tf.decode_raw 函數解析 input = tf.reshape(input, [params.height, params.width, params.channel]) # 將數據 reshape 爲正確的形狀,此處以圖像 (height, width, channel) 爲例 if os.path.splitext(output_files[0])[-1] == '.jpg': output = tf.image.decode_jpeg(output_reader, channels=3) else: output = tf.decode_raw(output_reader, data_type=tf.uint16) input = tf.reshape(input, [params.height, params.width, params.channel])
# 上面讀取了單個輸入和對應的目標輸出,網絡訓練時如需數據增廣,能夠在讀取單個訓練對以後,使用函數對數據進行處理,擴大訓練集 input, output = augment_data(input, output) samples = {} # 將增廣後的一對訓練數據組織爲字典的形式,便於後面組織成 batch samples['input'] = input samples['output'] = output if param.shuffle: # 建立批樣例訓練數據 samples = tf.train.shuffle_batch( sample, batch_size=params.batch_size, num_threads=params.nthreads, capacity=params.capacity, min_after_dequeue=params.min_after_dequeue) else: samples = tf.train.batch( sample, batch_size=params.batch_size, num_threads=params.nthreads, capacity=params.capacity) return samples # 返回一個 batch 的訓練數據
代碼中具體函數的接口能夠經過 tensorflow 的文檔查清。以上,只是聲明瞭多線程的文件讀取操做,並不會真正的讀取數據,爲了在會話執行時順利地獲取輸入數據,須要使用 tf.train.start_queue_runners 來啓動執行入隊列操做的全部線程,具體過程包括:文件名入隊到文件名隊列,樣例入隊到樣例隊列。示例代碼以下:安全
params.shuffle = true params.seed = 1234 params.height = 224 params.width = 224 params.channel = 3 training_path = 'dir/to/training/data' training_samples = data_generator(params, training_path) batch_inputs = training_samples['input'] batch_outputs = training_sample['output']
# 網絡計算圖建立
conv_1 = Conv2D(batch_inputs, ...)
...
conv_n = Conv2D(conv_n-1, ...)
output = tf.sigmoid(conv_n)
loss = tf.reduce_mean(tf.squared_difference(output, batch_outputs))
train_op = tf.minimize(loss,...)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess = sess)
sess.run(train_op)
...
上面的代碼中,輸入輸出各只有一張圖像,展現瞭如何實現流水線讀取,以及如何使用讀取出的數據。當輸入或者輸出包含多個文件時,例如,輸入是圖像和其語義分割圖,能夠在 data_generator 函數中,增長對語義分割圖的讀取,相對應的,多了 seg_files、seg_queue、seg_reader、seg_map 以及最後的 samples['seg_map'] = seg_map;網絡
一樣,當輸入數據是其它格式時,只須要根據對應的格式修改數據讀取的代碼接,例如 CycleGAN 中,訓練數據存儲爲 tfrecord 格式,須要修改的其實就是對文件的讀取部分。多線程
咱們都知道,tensorflow 在建立網絡計算圖時,一般須要爲網絡輸入和目標輸出先聲明 placeholder,可是上面的第二段示例代碼則是直接使用數據讀取的輸出構建網絡計算圖,是否是說採用這種方式就不能採用常見方法那樣,先定義 placeholder,再在網絡訓練中使用 feed_dict 填充數據呢?答案是能夠的,方法也和一般的作法沒有太大區別,示例以下:框架
x = tf.placeholder(...)
y = tf.palceholder(...) conv_1 = Conv2D(y, ...) ...
loss = tf.reduce_mean(tf.squared_difference(net_y, y))
train_op = tf.minimize(loss, ...)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess=sess)
samples = data_generator(params, training_path)
sess.run(train_op, feed_dict={x: samples['input'], y: samples['output']})
和第一種方法的區別是 data_generator 是在會話 sess 中調用,而不是在構建網絡計算圖時調用;dom
須要注意的是,上面的方式容錯性比較差,主要是由於採用多線程方式讀取數據,隊列操做後臺線程的生命週期無管理機制,線程出現異常會致使程序崩潰,比較常見的異常是文件名隊列或者樣例隊列越界拋出的 tf.errors.OutOfRangeError。爲了處理這種異常,HDRnet、CycleGAN 工程代碼中都使用 tf.train.Coordinator 建立了管理多線程聲明週期的協調器,其工做原理是經過監控 tensorflow 全部後臺線程,當有線程出現異常時,協調器的 should_stop 成員方法返回 True,循環結束,而後會話執行協調器的 request_stop 方法,請求全部線程安全退出。一套完整的示例代碼以下:函數
params.shuffle = true
params.seed = 1234
params.height = 224
params.width = 224
params.channel = 3
training_path = 'dir/to/training/data'
training_samples = data_generator(params, training_path)
batch_inputs = training_samples['input']
batch_outputs = training_sample['output']
# 網絡計算圖建立
conv_1 = Conv2D(batch_inputs, ...)
...
conv_n = Conv2D(conv_n-1, ...)
output = tf.sigmoid(conv_n)
loss = tf.reduce_mean(tf.squared_difference(output, batch_outputs))
train_op = tf.minimize(loss,...)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
while not coord.should_stop():
sess.run(train_op)
except KeyboardInterrupt: # 響應 Ctrl+C 中止訓練
coord.request_stop()
except Exception as e: # 後臺線程出現異常
coord.request_stop(e)
finally: # 這一步總會執行
save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step) # 保存 checkpoint
coord.request_stop()
coord.join(threads)
以上,介紹 tensorflow 中如何使用多線程並行讀取數據,如何在訓練中使用讀取的數據,以及如何對多線程進行監視,提高網絡訓練的容錯性。分享給你們,也給本身學習。學習