tensorflow製作tfrecord格式數據

tf.Example msg

tensorflow提供了一種統一的格式.tfrecord來存儲圖像數據.用的是自家的google protobuf.就是把圖像數據序列化成自定義格式的二進制數據.

To read data efficiently it can be helpful to serialize your data and store it in a set of files (100-200MB each) that can each be read linearly. This is especially true if the data is being streamed over a network. This can also be useful for caching any data-preprocessing.

The TFRecord format is a simple format for storing a sequence of binary records.
protobuf消息的格式如下:
https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/core/example/feature.proto

message BytesList {
  repeated bytes value = 1;
}
message FloatList {
  repeated float value = 1 [packed = true];
}
message Int64List {
  repeated int64 value = 1 [packed = true];
}

// Containers for non-sequential data.
message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

message Features {
  map<string, Feature> feature = 1;
};

message FeatureList {
  repeated Feature feature = 1;
};

message FeatureLists {
  map<string, FeatureList> feature_list = 1;
};

tf.Example是一個map. map的格式爲{"string": tf.train.Feature}
tf.train.Feature基本的格式有3種:

  • tf.train.BytesList
    • string
    • byte
  • tf.train.FloatList
    • float(float32)
    • double(float64)
  • tf.train.Int64List
    • bool
    • enum
    • int32
    • uint32
    • int64
    • uint64

參考tensorflow官方文檔

將自己的數據製作爲tfrecord格式

完整代碼

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import numpy as np
import IPython.display as display
import os
import cv2 as cv
import argparse

def _bytes_feature(value):
  """Wrap a single string / bytes value in a tf.train.Feature holding a BytesList.

  An eager tensor is first converted with ``.numpy()``, because BytesList
  won't unpack a string from an EagerTensor.
  """
  eager_tensor_cls = type(tf.constant(0))
  payload = value.numpy() if isinstance(value, eager_tensor_cls) else value
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

def _float_feature(value):
  """Wrap a single float / double value in a tf.train.Feature holding a FloatList."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wrap a single bool / enum / int / uint value in a tf.train.Feature holding an Int64List."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)

def convert_to_tfexample(img,label):
    """Convert one image array and its label into a tf.train.Example.

    Args:
        img: image array exposing ``tobytes()`` (e.g. a numpy ndarray).
        label: integer class index of the image.

    Returns:
        A tf.train.Example with two features: 'label' (int64 list) and
        'img' (bytes list holding the raw array buffer).

    Note:
        Only the raw bytes are stored; the array's shape and dtype are not
        written, so whoever decodes the record has to know them.
    """
    # ndarray.tostring() is a deprecated alias of tobytes(); both return the
    # same raw buffer, but tobytes() is the supported spelling.
    img_raw = img.tobytes()
    feature = {
        'label': _int64_feature(label),
        'img': _bytes_feature(img_raw),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

#path="/home/sc/disk/data/lishui/1"
def read_dataset(path):
    """Collect images and class labels from a directory tree.

    Every ``*.png`` file found under ``path`` is paired with the ``.txt``
    file of the same base name in the same directory. The class index is the
    first space-separated token on the first line of that label file.

    Args:
        path: root directory to walk.

    Returns:
        Tuple ``(imgs, labels)`` of two equally long lists: the arrays
        returned by ``cv.imread`` and the matching integer class indices.
        Images without a label file, or that cannot be decoded, are skipped.
    """
    imgs = []
    labels = []
    for root, _dirs, files in os.walk(path):
        for file_name in files:
            if not file_name.endswith(".png"):
                continue

            # Join with the directory currently being walked (root), not the
            # top-level path, so files in subdirectories resolve correctly.
            img_file = os.path.join(root, file_name)
            # Swap only the extension; str.replace('png', 'txt') would also
            # rewrite any "png" occurring elsewhere in the path.
            label_file = os.path.splitext(img_file)[0] + ".txt"
            if not os.path.isfile(label_file):
                continue

            # Context manager closes the label file even if parsing fails.
            with open(label_file) as f:
                class_index = int(f.readline().split(' ')[0])

            img = cv.imread(img_file)
            if img is None:
                # imread signals an unreadable image by returning None; skip
                # it so imgs and labels stay aligned and usable downstream.
                continue

            imgs.append(img)
            labels.append(class_index)

    return imgs,labels

def arg_parse(argv=None):
    """Parse the command-line options of the tfrecord conversion script.

    Args:
        argv: optional list of argument strings. When None (the default)
            argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``dir`` (directory holding image/label
        files) and ``output`` (tfrecord file to write).
    """
    parser = argparse.ArgumentParser(
        epilog='ex:python create_tfrecord.py -d /home/sc/disk/data/lishui/1 -o train.tfrecord')
    # `required` must be the boolean True; the string 'True' only worked
    # because any non-empty string is truthy. The former `default` values are
    # dropped: argparse never uses a default for a required option.
    parser.add_argument('-d','--dir',type=str,required=True,help='dir store images/label file')
    parser.add_argument('-o','--output',type=str,required=True,help='output tfrecord file name')

    return parser.parse_args(argv)

def main():
    """Read the labelled images from ``args.dir`` and write them to ``args.output``.

    Each (image, label) pair is converted to a tf.train.Example and written
    as one serialized record of the tfrecord file.
    """
    args = arg_parse()

    # Load the dataset before creating the output file, so a failure while
    # reading does not leave an empty tfrecord file behind.
    imgs,labels = read_dataset(args.dir)

    # The context manager closes the writer even if a write raises.
    with tf.io.TFRecordWriter(args.output) as writer:
        for img, label in zip(imgs, labels):
            example = convert_to_tfexample(img, label)
            writer.write(example.SerializeToString())

    print("write done")

if __name__ == '__main__':
    # usage: python create_tfrecord.py -d DATA_DIR -o OUTPUT_TFRECORD
    # ex:    python create_tfrecord.py -d /home/sc/disk/data/lishui/1 -o train.tfrecord
    # Both options are passed as flags; the parser defines no positional arguments.
    main()

首先就是需要有工具函數把byte/string/float/int..等等類型的數據轉換爲tf.train.Feature

def _bytes_feature(value):
  """Wrap a single string / bytes value in a tf.train.Feature holding a BytesList.

  An eager tensor is first converted with ``.numpy()``, because BytesList
  won't unpack a string from an EagerTensor.
  """
  eager_tensor_cls = type(tf.constant(0))
  payload = value.numpy() if isinstance(value, eager_tensor_cls) else value
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

def _float_feature(value):
  """Wrap a single float / double value in a tf.train.Feature holding a FloatList."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wrap a single bool / enum / int / uint value in a tf.train.Feature holding an Int64List."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)

接下來,對於圖片矩陣和標籤數據,咱們調用上述工具函數,將單幅圖片及其標籤信息轉換爲tf.train.Example消息.

def convert_to_tfexample(img,label):
    """Convert one image array and its label into a tf.train.Example.

    Args:
        img: image array exposing ``tobytes()`` (e.g. a numpy ndarray).
        label: integer class index of the image.

    Returns:
        A tf.train.Example with two features: 'label' (int64 list) and
        'img' (bytes list holding the raw array buffer).

    Note:
        Only the raw bytes are stored; the array's shape and dtype are not
        written, so whoever decodes the record has to know them.
    """
    # ndarray.tostring() is a deprecated alias of tobytes(); both return the
    # same raw buffer, but tobytes() is the supported spelling.
    img_raw = img.tobytes()
    feature = {
        'label': _int64_feature(label),
        'img': _bytes_feature(img_raw),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

對於我的數據,圖片以及label文件位於同一目錄.好比dir下有圖片a.png及相應的標籤信息a.txt.

def read_dataset(path):
    """Collect images and class labels from a directory tree.

    Every ``*.png`` file found under ``path`` is paired with the ``.txt``
    file of the same base name in the same directory. The class index is the
    first space-separated token on the first line of that label file.

    Args:
        path: root directory to walk.

    Returns:
        Tuple ``(imgs, labels)`` of two equally long lists: the arrays
        returned by ``cv.imread`` and the matching integer class indices.
        Images without a label file, or that cannot be decoded, are skipped.
    """
    imgs = []
    labels = []
    for root, _dirs, files in os.walk(path):
        for file_name in files:
            if not file_name.endswith(".png"):
                continue

            # Join with the directory currently being walked (root), not the
            # top-level path, so files in subdirectories resolve correctly.
            img_file = os.path.join(root, file_name)
            # Swap only the extension; str.replace('png', 'txt') would also
            # rewrite any "png" occurring elsewhere in the path.
            label_file = os.path.splitext(img_file)[0] + ".txt"
            if not os.path.isfile(label_file):
                continue

            # Context manager closes the label file even if parsing fails.
            with open(label_file) as f:
                class_index = int(f.readline().split(' ')[0])

            img = cv.imread(img_file)
            if img is None:
                # imread signals an unreadable image by returning None; skip
                # it so imgs and labels stay aligned and usable downstream.
                continue

            imgs.append(img)
            labels.append(class_index)

    return imgs,labels

遍歷data目錄,完成圖片讀取,及label讀取. 如果你的數據不是這麼存放的,就修改這個函數好了,返回值仍然是imgs,labels.

最後就是調用 tf.io.TFRecordWriter將每個tf.train.Example消息寫入文件保存.

def main():
    """Read the labelled images from ``args.dir`` and write them to ``args.output``.

    Each (image, label) pair is converted to a tf.train.Example and written
    as one serialized record of the tfrecord file.
    """
    args = arg_parse()

    # Load the dataset before creating the output file, so a failure while
    # reading does not leave an empty tfrecord file behind.
    imgs,labels = read_dataset(args.dir)

    # The context manager closes the writer even if a write raises.
    with tf.io.TFRecordWriter(args.output) as writer:
        for img, label in zip(imgs, labels):
            example = convert_to_tfexample(img, label)
            writer.write(example.SerializeToString())

    print("write done")
相關文章
相關標籤/搜索