TFrecord是一個Google提供的用於深度學習的數據格式,我的以爲很方便規範,值得學習。本文主要講的是怎麼存儲array,別的數據存儲較爲簡單,觸類旁通就行。python
在TFrecord中的數據都須要進行一個轉化的過程,這個轉化分紅三種數據結構
通常來說咱們的圖片讀進來之後是兩種形式,多線程
可是存儲在TFrecord裏面的不能是array的形式,因此咱們須要利用tostring()將上面的矩陣轉化成字符串再經過tf.train.BytesList轉化成能夠存儲的形式。學習
下面給個實例代碼,你們看看就懂了ui
adjust_pic.py : 做用就是轉化Image大小編碼
# -*- coding: utf-8 -*- import tensorflow as tf def resize(img_data, width, high, method=0): return tf.image.resize_images(img_data,[width, high], method)
pic2tfrecords.py :將圖片存成TFrecord.net
# -*- coding: utf-8 -*- # 將圖片保存成 TFRecord import os.path import matplotlib.image as mpimg import tensorflow as tf import adjust_pic as ap from PIL import Image SAVE_PATH = 'data/dataset.tfrecords' def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def load_data(datafile, width, high, method=0, save=False): train_list = open(datafile,'r') # 準備一個 writer 用來寫 TFRecord 文件 writer = tf.python_io.TFRecordWriter(SAVE_PATH) with tf.Session() as sess: for line in train_list: # 得到圖片的路徑和類型 tmp = line.strip().split(' ') img_path = tmp[0] label = int(tmp[1]) # 讀取圖片 image = tf.gfile.FastGFile(img_path, 'r').read() # 解碼圖片(若是是 png 格式就使用 decode_png) image = tf.image.decode_jpeg(image) # 轉換數據類型 # 由於爲了將圖片數據可以保存到 TFRecord 結構體中,因此須要將其圖片矩陣轉換成 string,因此爲了在使用時可以轉換回來,這裏肯定下數據格式爲 tf.float32 image = tf.image.convert_image_dtype(image, dtype=tf.float32) # 既然都將圖片保存成 TFRecord 了,那就先把圖片轉換成但願的大小吧 image = ap.resize(image, width, high) # 執行 op: image image = sess.run(image) # 將其圖片矩陣轉換成 string image_raw = image.tostring() # 將數據整理成 TFRecord 須要的數據結構 example = tf.train.Example(features=tf.train.Features(feature={ 'image_raw': _bytes_feature(image_raw), 'label': _int64_feature(label), })) # 寫 TFRecord writer.write(example.SerializeToString()) writer.close() load_data('train_list.txt_bak', 224, 224)
tfrecords2data.py :讀取Tfrecord裏的內容線程
# -*- coding: utf-8 -*- # 從 TFRecord 中讀取並保存圖片 import tensorflow as tf import numpy as np SAVE_PATH = 'data/dataset.tfrecords' def load_data(width, high): reader = tf.TFRecordReader() filename_queue = tf.train.string_input_producer([SAVE_PATH]) # 從 TFRecord 讀取內容並保存到 serialized_example 中 _, serialized_example = reader.read(filename_queue) # 讀取 serialized_example 的格式 features = tf.parse_single_example( serialized_example, features={ 'image_raw': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64), }) # 解析從 serialized_example 讀取到的內容 images = tf.decode_raw(features['image_raw'], tf.uint8) labels = tf.cast(features['label'], tf.int64) with tf.Session() as sess: # 啓動多線程 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 由於我這裏只有 2 張圖片,因此下面循環 2 次 for i in range(2): # 獲取一張圖片和其對應的類型 label, image = sess.run([labels, images]) # 這裏特別說明下: # 由於要想把圖片保存成 TFRecord,那就必須先將圖片矩陣轉換成 string,即: # pic2tfrecords.py 中 image_raw = image.tostring() 這行 # 因此這裏須要執行下面這行將 string 轉換回來,不然會沒法 reshape 成圖片矩陣,請看下面的小例子: # a = np.array([[1, 2], [3, 4]], dtype=np.int64) # 2*2 的矩陣 # b = a.tostring() # # 下面這行的輸出是 32,即: 2*2 以後還要再乘 8 # # 若是 tostring 以後的長度是 2*2=4 的話,那能夠將 b 直接 reshape([2, 2]),但如今的長度是 2*2*8 = 32,因此沒法直接 reshape # # 同理若是你的圖片是 500*500*3 的話,那 tostring() 以後的長度是 500*500*3 後再乘上一個數 # print len(b) # # 但在網上有不少提供的代碼裏都沒有下面這一行,大家那真的能 reshape ? image = np.fromstring(image, dtype=np.float32) # reshape 成圖片矩陣 image = tf.reshape(image, [224, 224, 3]) # 由於要保存圖片,因此將其轉換成 uint8 image = tf.image.convert_image_dtype(image, dtype=tf.uint8) # 按照 jpeg 格式編碼 image = tf.image.encode_jpeg(image) # 保存圖片 with tf.gfile.GFile('pic_%d.jpg' % label, 'wb') as f: f.write(sess.run(image)) load_data(224, 224)
以上代碼摘自TFRecord 的使用,以爲挺好的,沒改原樣照搬,我本身作實驗時改了不少,由於我是在im2txt的基礎上寫的。code