tensorflow進階筆記 --- #"3"# --- 關於怎麼在TFrecord中存儲圖像的array

TFrecord是一個Google提供的用於深度學習的數據格式,我的以爲很方便規範,值得學習。本文主要講的是怎麼存儲array,別的數據存儲較爲簡單,觸類旁通就行。python

在TFrecord中的數據都須要進行一個轉化的過程,這個轉化分紅三種數據結構

  • int64
  • float
  • bytes

通常來說咱們的圖片讀進來之後是兩種形式,多線程

  1. tf.image.decode_jpeg 解碼圖片讀取成 (width,height,channels)的矩陣,這個讀取的方式和cv2.imread以及ndimage.imread同樣
  2. tf.image.convert_image_dtype會將讀進來的上面的矩陣歸一化,通常來說咱們都要進行這個歸一化的過程。歸一化的好處能夠去查。

可是存儲在TFrecord裏面的不能是array的形式,因此咱們須要利用tostring()將上面的矩陣轉化成字符串再經過tf.train.BytesList轉化成能夠存儲的形式。學習

下面給個實例代碼,你們看看就懂了ui

adjust_pic.py : 做用就是轉化Image大小編碼

# -*- coding: utf-8 -*-  
  
import tensorflow as tf  
  
def resize(img_data, width, high, method=0):  
    return tf.image.resize_images(img_data,[width, high], method)

pic2tfrecords.py :將圖片存成TFrecord.net

# -*- coding: utf-8 -*-  
# 將圖片保存成 TFRecord  
import os.path  
import matplotlib.image as mpimg  
import tensorflow as tf  
import adjust_pic as ap  
from PIL import Image  
  
  
SAVE_PATH = 'data/dataset.tfrecords'  
  
  
def _int64_feature(value):  
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))  
  
def _bytes_feature(value):  
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))  
  
def load_data(datafile, width, high, method=0, save=False):  
    train_list = open(datafile,'r')  
    # 準備一個 writer 用來寫 TFRecord 文件  
    writer = tf.python_io.TFRecordWriter(SAVE_PATH)  
  
    with tf.Session() as sess:  
        for line in train_list:  
            # 得到圖片的路徑和類型  
            tmp = line.strip().split(' ')  
            img_path = tmp[0]  
            label = int(tmp[1])  
  
            # 讀取圖片  
            image = tf.gfile.FastGFile(img_path, 'r').read()  
            # 解碼圖片(若是是 png 格式就使用 decode_png)  
            image = tf.image.decode_jpeg(image)  
            # 轉換數據類型  
            # 由於爲了將圖片數據可以保存到 TFRecord 結構體中,因此須要將其圖片矩陣轉換成 string,因此爲了在使用時可以轉換回來,這裏肯定下數據格式爲 tf.float32  
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)  
            # 既然都將圖片保存成 TFRecord 了,那就先把圖片轉換成但願的大小吧  
            image = ap.resize(image, width, high)  
            # 執行 op: image  
            image = sess.run(image)  
              
            # 將其圖片矩陣轉換成 string  
            image_raw = image.tostring()  
            # 將數據整理成 TFRecord 須要的數據結構  
            example = tf.train.Example(features=tf.train.Features(feature={  
                'image_raw': _bytes_feature(image_raw),  
                'label': _int64_feature(label),  
                }))  
  
            # 寫 TFRecord  
            writer.write(example.SerializeToString())  
  
    writer.close()  
  
  
load_data('train_list.txt_bak', 224, 224)

tfrecords2data.py :讀取Tfrecord裏的內容線程

# -*- coding: utf-8 -*-  
# 從 TFRecord 中讀取並保存圖片  
import tensorflow as tf  
import numpy as np  
  
  
SAVE_PATH = 'data/dataset.tfrecords'  
  
  
def load_data(width, high):  
    reader = tf.TFRecordReader()  
    filename_queue = tf.train.string_input_producer([SAVE_PATH])  
  
    # 從 TFRecord 讀取內容並保存到 serialized_example 中  
    _, serialized_example = reader.read(filename_queue)  
    # 讀取 serialized_example 的格式  
    features = tf.parse_single_example(  
        serialized_example,  
        features={  
            'image_raw': tf.FixedLenFeature([], tf.string),  
            'label': tf.FixedLenFeature([], tf.int64),  
        })  
  
    # 解析從 serialized_example 讀取到的內容  
    images = tf.decode_raw(features['image_raw'], tf.uint8)  
    labels = tf.cast(features['label'], tf.int64)  
  
    with tf.Session() as sess:  
        # 啓動多線程  
        coord = tf.train.Coordinator()  
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)  
  
        # 由於我這裏只有 2 張圖片,因此下面循環 2 次  
        for i in range(2):  
            # 獲取一張圖片和其對應的類型  
            label, image = sess.run([labels, images])  
            # 這裏特別說明下:  
            #   由於要想把圖片保存成 TFRecord,那就必須先將圖片矩陣轉換成 string,即:  
            #       pic2tfrecords.py 中 image_raw = image.tostring() 這行  
            #   因此這裏須要執行下面這行將 string 轉換回來,不然會沒法 reshape 成圖片矩陣,請看下面的小例子:  
            #       a = np.array([[1, 2], [3, 4]], dtype=np.int64) # 2*2 的矩陣  
            #       b = a.tostring()  
            #       # 下面這行的輸出是 32,即: 2*2 以後還要再乘 8  
            #       # 若是 tostring 以後的長度是 2*2=4 的話,那能夠將 b 直接 reshape([2, 2]),但如今的長度是 2*2*8 = 32,因此沒法直接 reshape  
            #       # 同理若是你的圖片是 500*500*3 的話,那 tostring() 以後的長度是 500*500*3 後再乘上一個數  
            #       print len(b)  
            #  
            #   但在網上有不少提供的代碼裏都沒有下面這一行,大家那真的能 reshape ?  
            image = np.fromstring(image, dtype=np.float32)  
            # reshape 成圖片矩陣  
            image = tf.reshape(image, [224, 224, 3])  
            # 由於要保存圖片,因此將其轉換成 uint8  
            image = tf.image.convert_image_dtype(image, dtype=tf.uint8)  
            # 按照 jpeg 格式編碼  
            image = tf.image.encode_jpeg(image)  
            # 保存圖片  
            with tf.gfile.GFile('pic_%d.jpg' % label, 'wb') as f:  
                f.write(sess.run(image))  
  
  
load_data(224, 224)

以上代碼摘自TFRecord 的使用,以爲挺好的,沒改原樣照搬,我本身作實驗時改了不少,由於我是在im2txt的基礎上寫的。code

相關文章
相關標籤/搜索