TensorFlow TFRecords 寫入和讀取

一.圖片和圖片標籤寫入TFRecords
1.建立文件存儲器

writer = tf.io.TFRecordWriter('./data/tfrecords/cifar.tfrecords')

2.用for循環將讀取的數據存入example,再寫入TFRecords

# Each batch holds ten image samples.
# Fetch images and labels in ONE session run: evaluating them separately
# would dequeue two different batches and mismatch images with labels.
images, labels = tf.get_default_session().run([image_batch, label_batch])
for i in range(10):
    image = images[i].tobytes()
    label = int(labels[i][0])
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            }))
    writer.write(example.SerializeToString())
writer.close()

二.讀取TFRecords
1.構建文件隊列

queue_list=tf.train.string_input_producer(['./data/tfrecords/cifar.tfrecords'])

2.構造閱讀器隊列

reader=tf.TFRecordReader()
key,value=reader.read(queue_list)

3.解析讀取的example圖片

# Parse one serialized Example; both features are scalar (shape []).
features = tf.parse_single_example(value, features={
    'image': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64),
})

4.解碼內容

# Decode the raw image bytes to uint8 and convert the int64 label to float32.
image = tf.decode_raw(features['image'], tf.uint8)
label = tf.cast(features['label'], tf.float32)

5.固定圖片形狀

image_reshape=tf.reshape(image,[self.height, self.width, self.channel])

6.進行批處理

batch_image,label_batch=tf.train.batch([image_reshape,label],batch_size=10,num_threads=1,capacity=10)

完整的讀取代碼

import tensorflow as tf
import os

# Register the string flag holding the TFRecords file location, then bind the
# module-level handle to the shared flag values object.
cifar_tfrecords = tf.app.flags.DEFINE_string(
    'cifar_tfrecords',
    './data/tfrecords/cifar.tfrecords',
    'tfrecords目錄',
)
FLAGS = tf.app.flags.FLAGS


class CirarReader():
    """Read fixed-length CIFAR binary records and round-trip them via TFRecords.

    All read methods only build TensorFlow 1.x queue-based graph ops; the
    caller must start the queue runners inside a session before evaluating
    the returned tensors.
    """

    # Default location of the TFRecords file written and read by this class.
    DEFAULT_TFRECORDS_PATH = './data/tfrecords/cifar.tfrecords'

    def __init__(self, filelsit, tfrecords_path=DEFAULT_TFRECORDS_PATH):
        """Store the input file list and the record layout.

        Args:
            filelsit: List of paths to the binary files to read.
            tfrecords_path: TFRecords file used by ``write_to_tfrecords`` and
                ``read_from_tfrecords``. Defaults to the original hard-coded
                location.
        """
        self.file_list = filelsit
        self.tfrecords_path = tfrecords_path
        self.height = 32
        self.width = 32
        self.channel = 3
        # Each record is one label byte followed by the image bytes.
        self.label_bytes = 1
        self.image_bytes = self.width * self.height * self.channel
        self.bytes = self.label_bytes + self.image_bytes

    def read_decode_cifar(self):
        """Build ops that read, decode and batch the binary files.

        Returns:
            ``(image_batch, label_batch)`` tensors produced by
            ``tf.train.batch`` with a batch size of 10.
        """
        queue_list = tf.train.string_input_producer(self.file_list)
        reader = tf.FixedLengthRecordReader(self.bytes)
        key, value = reader.read(queue_list)
        # Decode the raw record into a flat uint8 vector.
        label_image = tf.decode_raw(value, tf.uint8)
        # Split the record into the label (target) and the image (features).
        label = tf.slice(label_image, [0], [self.label_bytes])
        image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
        # Give the image a fixed shape so it can be batched.
        # NOTE(review): if these are standard CIFAR-10 binary files, pixels
        # are stored channel-major (all R, then G, then B), in which case a
        # direct [height, width, channel] reshape scrambles the image and a
        # reshape to [channel, height, width] plus transpose is needed --
        # confirm against the data source before changing.
        image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
        # Batch the samples.
        image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=3, capacity=20)
        return image_batch, label_batch

    def write_to_tfrecords(self, image_batch, label_batch):
        """Evaluate one batch and write every sample to the TFRecords file.

        Must be called with a default session active (e.g. inside
        ``with tf.Session():``) and with the queue runners started.

        Args:
            image_batch: Batched image tensor, e.g. from ``read_decode_cifar``.
            label_batch: Batched label tensor whose rows hold the label at
                index 0.

        Raises:
            ValueError: If no default session is registered.
        """
        sess = tf.get_default_session()
        if sess is None:
            raise ValueError('write_to_tfrecords requires a default session')
        # Fetch images and labels in a single run. Evaluating them separately
        # (one ``.eval()`` each) dequeues a fresh batch per call, so images
        # and labels would come from different batches.
        images, labels = sess.run([image_batch, label_batch])
        # The context manager closes the writer even if a write fails.
        with tf.io.TFRecordWriter(self.tfrecords_path) as writer:
            # Every sample gets its own Example protocol buffer.
            for image, label in zip(images, labels):
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                    # Plain int avoids protobuf rejecting NumPy integer types.
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label[0])]))
                }))
                writer.write(example.SerializeToString())

    def read_from_tfrecords(self):
        """Build ops that read, parse and batch samples from the TFRecords file.

        Returns:
            ``(batch_image, batch_label)`` tensors produced by
            ``tf.train.batch`` with a batch size of 10; labels are cast to
            float32.
        """
        # Build the filename queue.
        queue_list = tf.train.string_input_producer([self.tfrecords_path])
        # Read one serialized Example at a time.
        reader = tf.TFRecordReader()
        key, value = reader.read(queue_list)
        # Parse the serialized Example into its two scalar features.
        features = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        # Decode the stored bytes and convert the label dtype.
        image = tf.decode_raw(features['image'], tf.uint8)
        label = tf.cast(features['label'], tf.float32)
        # Give the image a fixed shape so it can be batched.
        image_shape = tf.reshape(image, [self.height, self.width, self.channel])
        # Batch the samples.
        batch_image, batch_label = tf.train.batch([image_shape, label], batch_size=10, num_threads=1, capacity=10)
        return batch_image, batch_label


if __name__ == '__main__':
    # Collect every .bin file in the CIFAR data directory.
    path = './data/cifar/'
    file_names = os.listdir(path)
    file_list = [os.path.join(path, file_name) for file_name in file_names if file_name.endswith('.bin')]
    reader = CirarReader(file_list)
    # image_batch, label_batch = reader.read_decode_cifar()
    batch_image, batch_label = reader.read_from_tfrecords()
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)
        try:
            # Write the TFRecords file (enable together with the
            # read_decode_cifar line above).
            # print("-------start----------")
            # reader.write_to_tfrecords(image_batch, label_batch)
            # print("-------end------------")
            print(sess.run([batch_image, batch_label]))
        finally:
            # Always stop and join the queue-runner threads, even when the
            # run above raises, so the process does not hang on exit.
            coord.request_stop()
            coord.join(threads)
相關文章
相關標籤/搜索