TensorFlow 讀取二進制文件

1.構建文件列表,使用CIFAR-10數據集(二進制版本)

# Collect the full paths of every ".bin" file in the data directory.
path = './data/cifar/'
file_names = os.listdir(path)
# os.path.join (not os.join, which does not exist) builds each file path.
file_list = [os.path.join(path, file_name) for file_name in file_names if file_name.endswith('.bin')]

2.開啓隊列讀取文件列表

# Wrap the list of file names in a filename queue that the reader pulls from.
queue_list = tf.train.string_input_producer(string_tensor=file_list)

3.構建閱讀器讀取數據隊列

# Every record is exactly self.bytes long (label bytes + image bytes),
# so a fixed-length record reader can cut the files into records.
reader = tf.FixedLengthRecordReader(record_bytes=self.bytes)
key, value = reader.read(queue=queue_list)

4.解析數據圖片

# Decode the raw record string into a flat uint8 tensor: label byte(s) first, then image bytes.
# (tf.decode_raw, not tf.decode_row; the name must be label_image for the slicing step.)
label_image = tf.decode_raw(value, tf.uint8)

5.將數據切成特徵值和目標值

# The first label_bytes bytes are the label (target); the remaining image_bytes are the image (features).
label = tf.slice(label_image, [0], [self.label_bytes])
image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])

6.特徵值進行形狀改變

# CIFAR-10 binary images are stored channel-major (all red bytes, then green, then blue),
# so reshape to [channel, height, width] first and then transpose to [height, width, channel].
# Reshaping directly to [height, width, channel] would scramble the pixels.
image_reshape = tf.transpose(
    tf.reshape(image, [self.channel, self.height, self.width]), [1, 2, 0])

7.進行批處理

# Batch the reshaped images together with their labels (10 examples per batch, 3 enqueue threads).
image_batch, label_batch = tf.train.batch(
    [image_reshape, label], batch_size=10, num_threads=3, capacity=20)

8.開啓會話,啓動隊列線程並取出一批數據

with tf.Session() as sess:
    # The coordinator lets us stop the queue-runner threads cleanly.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    try:
        print(sess.run([image_batch, label_batch]))
    finally:
        # Stop and join the queue-runner threads even if sess.run raises,
        # so they are not left running.
        coord.request_stop()
        coord.join(threads)

完整代碼

import tensorflow as tf
import os


class CirarReader():
    """Read CIFAR-10 style fixed-length binary records into image/label batches.

    Each record is ``label_bytes`` label byte(s) followed by
    ``height * width * channel`` image bytes stored channel-major (all bytes of
    channel 0, then channel 1, ...), which is the CIFAR-10 binary layout.

    The class name keeps its original spelling ("Cirar") and the constructor
    keeps the ``filelsit`` parameter name so existing callers keep working.
    """

    def __init__(self, filelsit, height=32, width=32, channel=3, label_bytes=1):
        """Store the file list and the record geometry.

        Args:
            filelsit: list of paths to the binary record files.
            height, width, channel: image geometry; defaults are CIFAR-10's 32x32x3.
            label_bytes: number of label bytes at the start of each record
                (default 1, as in CIFAR-10).
        """
        self.file_list = filelsit
        self.height = height
        self.width = width
        self.channel = channel
        self.label_bytes = label_bytes
        self.image_bytes = self.width * self.height * self.channel
        # Every record has this fixed length: label bytes + image bytes.
        self.bytes = self.label_bytes + self.image_bytes

    def read_decode_cifar(self, batch_size=10, num_threads=3, capacity=20):
        """Build the queue-based input pipeline.

        Args:
            batch_size: examples per batch (default 10).
            num_threads: enqueue threads used by ``tf.train.batch`` (default 3).
            capacity: maximum number of examples held in the batch queue (default 20).

        Returns:
            ``(image_batch, label_batch)``: uint8 tensors of shape
            [batch_size, height, width, channel] and [batch_size, label_bytes].
        """
        queue_list = tf.train.string_input_producer(self.file_list)
        reader = tf.FixedLengthRecordReader(self.bytes)
        _key, value = reader.read(queue_list)  # the record key is not needed
        # Decode the raw record into a flat uint8 tensor.
        label_image = tf.decode_raw(value, tf.uint8)
        # Split into label (target) and image (features).
        label = tf.slice(label_image, [0], [self.label_bytes])
        image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
        # Pixels are stored channel-major: reshape to [C, H, W], then transpose to [H, W, C].
        # Reshaping directly to [H, W, C] would scramble the image.
        depth_major = tf.reshape(image, [self.channel, self.height, self.width])
        image_reshape = tf.transpose(depth_major, [1, 2, 0])
        # Batch the examples.
        image_batch, label_batch = tf.train.batch(
            [image_reshape, label],
            batch_size=batch_size,
            num_threads=num_threads,
            capacity=capacity,
        )
        return image_batch, label_batch


if __name__ == '__main__':
    path = './data/cifar/'
    # Keep only the ".bin" record files in the data directory.
    file_list = [os.path.join(path, file_name)
                 for file_name in os.listdir(path)
                 if file_name.endswith('.bin')]
    if not file_list:
        # Fail with a clear message instead of an opaque error from the input producer.
        raise FileNotFoundError("no .bin files found in %r" % path)
    reader = CirarReader(file_list)
    image_batch, label_batch = reader.read_decode_cifar()
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)
        try:
            print(sess.run([image_batch, label_batch]))
        finally:
            # Always stop and join the queue-runner threads, even if sess.run raises.
            coord.request_stop()
            coord.join(threads)
相關文章
相關標籤/搜索