十圖詳解TensorFlow數據讀取機制python
1). 建立文件名列表api
相關函數:tf.train.match_filenames_once網絡
2). 建立文件名隊列ide
相關函數:tf.train.string_input_producer函數
3). 建立Reader讀取數據ui
tf.ReaderBase 、 tf.TFRecordReader 、 tf.TextLineReader 、 tf.WholeFileReader 、 tf.IdentityReader 、 tf.FixedLengthRecordReader …spa
4).建立decoder解碼器轉換格式線程
tf.decode_csv 、 tf.decode_raw 、 tf.image.decode_image …code
5). 建立樣例隊列orm
相關函數:tf.train.shuffle_batch
閱讀器:tf.TextLineReader
解析器:tf.decode_csv
閱讀器:tf.FixedLengthRecordReader
解析器:tf.decode_raw
閱讀器:tf.WholeFileReader
解析器:tf.image.decode_image, tf.image.decode_gif, tf.image.decode_jpeg, tf.image.decode_png
閱讀器:tf.TFRecordReader
解析器:tf.parse_single_example
又或者使用slim提供的簡便方法:slim.dataset.Dataset以及slim.dataset_data_provider.DatasetDataProvider方法,通常slim.dataset.Dataset做爲函數返回,須要接收Reader和Decoder做爲參數。
def get_split(record_file_name, num_sampels, size): reader = tf.TFRecordReader keys_to_features = { "image/encoded": tf.FixedLenFeature((), tf.string, ''), "image/format": tf.FixedLenFeature((), tf.string, 'jpeg'), "image/height": tf.FixedLenFeature([], tf.int64, tf.zeros([], tf.int64)), "image/width": tf.FixedLenFeature([], tf.int64, tf.zeros([], tf.int64)), } items_to_handlers = { "image": slim.tfexample_decoder.Image(shape=[size, size, 3]), "height": slim.tfexample_decoder.Tensor("image/height"), "width": slim.tfexample_decoder.Tensor("image/width"), } decoder = slim.tfexample_decoder.TFExampleDecoder( keys_to_features, items_to_handlers ) return slim.dataset.Dataset( data_sources=record_file_name, reader=reader, decoder=decoder, items_to_descriptions={}, num_samples=num_sampels ) def get_image(num_samples, resize, record_file="image.tfrecord", shuffle=False): provider = slim.dataset_data_provider.DatasetDataProvider( get_split(record_file, num_samples, resize), # slim.dataset.Dataset 作參數 shuffle=shuffle ) [data_image] = provider.get(["image"]) # Provider經過TFR字段獲取batch size數據 return data_image
filename_queue = tf.train.string_input_producer(filenames, shuffle=shuffle, num_epochs=epochs) reader = tf.WholeFileReader() _, img_bytes = reader.read(filename_queue) image = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3)
filename_queue = tf.train.string_input_producer(filenames)
# 初始化閱讀器,這裏以定長字節閱讀器爲例,實際讀取圖片通常使用WholeFileReader reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) # 指定被閱讀文件 result.key, value = reader.read(filename_queue)
# Convert from a string to a vector of uint8 that is record_bytes long. # read出來的是一個二進制的string,將它解碼依照uint8格式解碼 record_bytes = tf.decode_raw(value, tf.uint8) …… ……
因爲讀取來的tensor不具備靜態shape,須要使用tensor.set_shape()指定shape(或者在處理中顯示的賦予shape如使用reshape等函數),不然沒法創建圖
read_input.label.set_shape([1])
將最後的規則tensor傳入batch生成池節點中,輸出的張量能夠直接feed進網絡
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size) …… …… image_batch, label_batch = sess.run([images_train, labels_train]) _, loss_value = sess.run( [train_op, loss], feed_dict={image_holder:image_batch, label_holder:label_batch})
# 啓動數據加強隊列 tf.train.start_queue_runners()
附上線程控制組件使用示意,
import tensorflow as tf sess = tf.Session() coord = tf.train.coordinator() threads = tf.train.start_queue_runners(sess=sess,coord=coord) # 訓練過程 coord.request_stop() coord.join(threads)