"笨方法"學習CNN圖像識別(二)—— tfrecord格式高效讀取數據

 

原文地址:https://finthon.com/learn-cnn-two-tfrecord-read-data/
-- 全文閱讀5分鐘 --python

在本文中,你將學習到如下內容:


  • 將圖片數據製做成tfrecord格式
  • 將tfrecord格式數據還原成圖片

前言

tfrecord是TensorFlow官方推薦的標準格式,可以將圖片數據和標籤一塊兒存儲成二進制文件,在TensorFlow中實現快速地複製、移動、讀取和存儲操做。訓練網絡的時候,經過創建隊列系統,能夠預先將tfrecord格式的數據加載進隊列,隊列會自動實現數據隨機或有序地進出棧,而且隊列系統和模型訓練是獨立進行的,這就加速了咱們模型的讀取和訓練。swift

準備圖片數據

按照圖片預處理教程,咱們得到了兩組resize成224*224大小的商標圖片集,把標籤分別命名成1和2兩類,以下圖:網絡

 
兩類圖片數據集

 

 
label:1

 

 
label:2


咱們如今就將這兩個類別的圖片集製做成tfrecord格式。數據結構

 

製做tfrecord格式

導入必要的庫:app

import os from PIL import Image import tensorflow as tf 

定義一些路徑和參數:函數

# 圖片路徑,兩組標籤都在該目錄下 cwd = r"./brand_picture/" # tfrecord文件保存路徑 file_path = r"./" # 每一個tfrecord存放圖片個數 bestnum = 1000 # 第幾個圖片 num = 0 # 第幾個TFRecord文件 recordfilenum = 0 # 將labels放入到classes中 classes = [] for i in os.listdir(cwd): classes.append(i) # tfrecords格式文件名 ftrecordfilename = ("traindata_63.tfrecords-%.3d" % recordfilenum) writer = tf.python_io.TFRecordWriter(os.path.join(file_path, ftrecordfilename)) 

bestnum控制每一個tfrecord的大小,這裏使用1000,首先定義tf.python_io.TFRecordWriter,方便後面寫入存儲數據。
製做tfrecord格式時,其實是將圖片和標籤一塊兒存儲在tf.train.Example中,它包含了一個字典,鍵是一個字符串,值的類型能夠是BytesList,FloatList和Int64List。學習

for index, name in enumerate(classes): class_path = os.path.join(cwd, name) for img_name in os.listdir(class_path): num = num + 1 if num > bestnum: #超過1000,寫入下一個tfrecord num = 1 recordfilenum += 1 ftrecordfilename = ("traindata_63.tfrecords-%.3d" % recordfilenum) writer = tf.python_io.TFRecordWriter(os.path.join(file_path, ftrecordfilename)) img_path = os.path.join(class_path, img_name) # 每個圖片的地址 img = Image.open(img_path, 'r') img_raw = img.tobytes() # 將圖片轉化爲二進制格式 example = tf.train.Example( features=tf.train.Features(feature={ 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), })) writer.write(example.SerializeToString()) # 序列化爲字符串 writer.close() 

在這裏咱們保存的label是classes中的編號索引,即0和1,你也能夠改爲文件名做爲label,可是必定是int類型。圖片讀取之後轉化成了二進制格式。最後經過writer寫入數據到tfrecord中。
最終咱們在當前目錄下生成一個tfrecord文件:ui

 
tfrecord文件

讀取tfrecord文件

讀取tfrecord文件是存儲的逆操做,咱們定義一個讀取tfrecord的函數,方便後面調用。spa

import tensorflow as tf def read_and_decode_tfrecord(filename): filename_deque = tf.train.string_input_producer(filename) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_deque) features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string)}) label = tf.cast(features['label'], tf.int32) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [224, 224, 3]) img = tf.cast(img, tf.float32) / 255.0 return img, label train_list = ['traindata_63.tfrecords-000'] img, label = read_and_decode_tfrecord(train_list) 

這段代碼主要是經過tf.TFRecordReader讀取裏面的數據,而且還原數據類型,最後咱們對圖片矩陣進行歸一化。到這裏咱們就完成了tfrecord輸出,能夠對接後面的訓練網絡了。
若是咱們想直接還原成原來的圖片,就須要先註釋掉讀取tfrecord函數中的歸一化一行,並添加部分代碼,完整代碼以下:線程

import tensorflow as tf from PIL import Image import matplotlib.pyplot as plt def read_and_decode_tfrecord(filename): filename_deque = tf.train.string_input_producer(filename) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_deque) features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string)}) label = tf.cast(features['label'], tf.int32) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [224, 224, 3]) # img = tf.cast(img, tf.float32) / 255.0 #將矩陣歸一化0-1之間 return img, label train_list = ['traindata_63.tfrecords-000'] img, label = read_and_decode_tfrecord(train_list) img_batch, label_batch = tf.train.batch([img, label], num_threads=2, batch_size=2, capacity=1000) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 建立一個協調器,管理線程 coord = tf.train.Coordinator() # 啓動QueueRunner,此時文件名隊列已經進隊 threads = tf.train.start_queue_runners(sess=sess, coord=coord) while True: b_image, b_label = sess.run([img_batch, label_batch]) b_image = Image.fromarray(b_image[0]) plt.imshow(b_image) plt.axis('off') plt.show() coord.request_stop() # 其餘全部線程關閉以後,這一函數才能返回 coord.join(threads) 

在後面創建了一個隊列tf.train.batch,經過Session調用順序隊列系統,輸出每張圖片。Session部分在訓練網絡的時候還會講到。咱們學習tfrecord過程,能加深對數據結構和類型的理解。到這裏咱們對tfrecord格式的輸入輸出有了必定了解,咱們訓練網絡的準備工做已完成,接下來就是咱們CNN模型的搭建工做了。

可能感興趣
相關文章
相關標籤/搜索