1.構建文件列表,使用 CIFAR-10 數據集
import os

# Directory holding the CIFAR-10 binary batches.
path = './data/cifar/'
# os.listdir returns names in arbitrary order; sort them so the list is reproducible.
file_names = sorted(os.listdir(path))
# Keep only the .bin files and join them with the directory (os.path.join, not os.join).
file_list = [os.path.join(path, file_name) for file_name in file_names if file_name.endswith('.bin')]
2.開啓隊列讀取文件列表
# Wrap the list of file names in a TF1 filename queue consumed by the reader below.
queue_list = tf.train.string_input_producer(string_tensor=file_list)
3.構建閱讀器讀取數據隊列
# Every CIFAR-10 record has the same length (self.bytes = label bytes + image bytes),
# so a fixed-length record reader can split the files into records.
reader = tf.FixedLengthRecordReader(self.bytes)
# read() returns (key, value); value holds the raw bytes of one record.
key, value = reader.read(queue_list)
4.解碼原始字節數據(tf.decode_raw)
# Decode the raw record bytes into a flat uint8 tensor (tf.decode_raw, not decode_row).
# The name must be label_image: the slicing step below reads it under that name.
label_image = tf.decode_raw(value, tf.uint8)
5.將數據切成特徵值和目標值
# The first self.label_bytes bytes are the label (target value) ...
label = tf.slice(label_image, [0], [self.label_bytes])
# ... and the following self.image_bytes bytes are the image (feature values).
image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
6.改變特徵值的形狀(注意:CIFAR-10 二進制文件中的圖片按通道優先存儲,即 3×32×32)
# CIFAR-10 stores pixels channel-major (all red values, then green, then blue),
# so reshape to [channel, height, width] first and then move channels last.
# Reshaping straight to [height, width, channel] would scramble the pixels.
image_reshape = tf.transpose(
    tf.reshape(image, [self.channel, self.height, self.width]), [1, 2, 0])
7.進行批處理
# Batch the reshaped [height, width, channel] image (not the flat one) with its label.
image_batch, label_batch = tf.train.batch(
    [image_reshape, label], batch_size=10, num_threads=3, capacity=20)
8.開啓會話讀取數據(此處僅打印一個批次,並未進行訓練)
with tf.Session() as sess:
    # The coordinator lets us stop the queue-runner threads cleanly.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    try:
        print(sess.run([image_batch, label_batch]))
    finally:
        # Stop and join the reader threads even if sess.run raises.
        coord.request_stop()
        coord.join(threads)
完整代碼
import os

import tensorflow as tf


class CirarReader:
    """Read CIFAR-10 binary files into batched tensors using TF1 input queues.

    The class name (and the ``filelsit`` parameter) keep their original
    spelling so existing callers continue to work.

    Args:
        filelsit: list of paths to CIFAR-10 ``.bin`` files.
    """

    def __init__(self, filelsit):
        self.file_list = filelsit
        self.height = 32
        self.width = 32
        self.channel = 3
        # Each record is label_bytes of label followed by image_bytes of pixels.
        self.label_bytes = 1
        self.image_bytes = self.width * self.height * self.channel
        self.bytes = self.label_bytes + self.image_bytes

    def read_decode_cifar(self, batch_size=10, num_threads=3, capacity=20):
        """Build the input pipeline and return ``(image_batch, label_batch)``.

        Args:
            batch_size: number of examples per batch.
            num_threads: number of threads enqueuing examples.
            capacity: maximum number of elements held in the batching queue.

        Returns:
            image_batch: uint8 tensor of shape [batch_size, height, width, channel].
            label_batch: uint8 tensor of shape [batch_size, label_bytes].
        """
        queue_list = tf.train.string_input_producer(self.file_list)
        # Records have a fixed length, so a fixed-length reader splits them.
        reader = tf.FixedLengthRecordReader(self.bytes)
        _, value = reader.read(queue_list)  # the record key is not needed
        # Decode: raw record bytes -> flat uint8 vector of length self.bytes.
        label_image = tf.decode_raw(value, tf.uint8)
        # Split into the label (target value) and the image (feature values).
        label = tf.slice(label_image, [0], [self.label_bytes])
        image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
        # CIFAR-10 stores pixels channel-major (R plane, G plane, B plane), so
        # reshape to [channel, height, width], then transpose to channels-last.
        image_reshape = tf.transpose(
            tf.reshape(image, [self.channel, self.height, self.width]), [1, 2, 0])
        image_batch, label_batch = tf.train.batch(
            [image_reshape, label],
            batch_size=batch_size,
            num_threads=num_threads,
            capacity=capacity,
        )
        return image_batch, label_batch


if __name__ == '__main__':
    path = './data/cifar/'
    # os.listdir order is arbitrary; sort for a reproducible file list.
    file_list = [
        os.path.join(path, file_name)
        for file_name in sorted(os.listdir(path))
        if file_name.endswith('.bin')
    ]
    reader = CirarReader(file_list)
    image_batch, label_batch = reader.read_decode_cifar()
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)
        try:
            print(sess.run([image_batch, label_batch]))
        finally:
            # Always shut the reader threads down, even if sess.run fails.
            coord.request_stop()
            coord.join(threads)