TFRecord 存入圖像和標籤

#-*- coding:utf-8 -*-
import os
import tensorflow as tf
import cv2

'''
文件目錄爲
chiwawa/
     xx.jpg
     xx.jpg
     .....
japandog/
     xx.jpg
     xx.jpg
     .....
'''
cwd = 'f:/py/tfrecord/'
classes={'chiwawa','japandog'} # 須要存入的標籤,儘可能與文件名一致,方便操做

sess = tf.Session()
writer = tf.python_io.TFRecordWriter("f:/py/tfrecord/train.tfrecords") # 創建一個writer
for index, name in enumerate(classes):
    class_path = cwd + name + "/"           # 構建文件路徑
    for img_name in os.listdir(class_path): # 遍歷目錄下的文件
        img_path = class_path + img_name     # 構建具體每一張圖片的路徑
        image = cv2.imread(img_path)        # 讀取圖片

        # 獲取圖片的寬,高和通道數
        img_w = image.shape[0]
        img_h = image.shape[1]
        img_c = image.shape[2]

        # tf讀取圖片
        img = tf.read_file(img_path)
        img = tf.image.decode_jpeg(img)

        # img = tf.image.resize_images(img,(224, 224)) 改變大小
        img_raw = sess.run(tf.cast(img,tf.uint8)).tostring()              #將圖片轉化爲原生bytes
        

        label = name.encode('utf-8')  #將標籤轉化爲bytes
        '''
        如下是Example類的經常使用固定格式,但要注意第一個features有s,對應的是tf.train.Features
        tf.train.Features裏的feature是沒有s的,bytes_list對應的是tf.train.BytesList,
        int64_list對應的是tf.train.Int64List,輸入的value的格式也要一致,可輸入的格式有int,float,bytes
        label和img_raw的格式是bytes,寬、高、通道數的格式是int
        '''
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            'img_w': tf.train.Feature(int64_list=tf.train.Int64List(value=[img_w])),
            'img_h': tf.train.Feature(int64_list=tf.train.Int64List(value=[img_h])),
            'img_c': tf.train.Feature(int64_list=tf.train.Int64List(value=[img_c]))
        }))
        writer.write(example.SerializeToString())  #序列化爲字符串
        
writer.close() 
相關文章
相關標籤/搜索