TFRecord 的使用

什麼是 TFRecord 

            PS:這段內容摘自 http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.htmlhtml

            一種保存記錄的方法能夠容許你講任意的數據轉換爲TensorFlow所支持的格式, 這種方法可使TensorFlow的數據集更容易與網絡應用架構相匹配。這種建議的方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Example 協議內存塊(protocol buffer)(協議內存塊包含了字段 Features)。你能夠寫一段代碼獲取你的數據, 將數據填入到Example協議內存塊(protocolbuffer),將協議內存塊序列化爲一個字符串, 而且經過tf.python_io.TFRecordWriterclass寫入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py就是這樣的一個例子。
            從TFRecords文件中讀取數據, 可使用tf.TFRecordReader的tf.parse_single_example解析器。這個parse_single_example操做能夠將Example協議內存塊(protocolbuffer)解析爲張量。 MNIST的例子就使用了convert_to_records 所構建的數據。 請參看tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py, python

 

代碼

            adjust_pic.py網絡

                單純的轉換圖片大小數據結構

 

[python] view plain copy
 
  1. # -*- coding: utf-8 -*-  
  2.   
  3. import tensorflow as tf  
  4.   
  5. def resize(img_data, width, high, method=0):  
  6.     return tf.image.resize_images(img_data,[width, high], method)  

 

 

                pic2tfrecords.py多線程

                將圖片保存成TFRecord架構

 

[python] view plain copy
 
  1. # -*- coding: utf-8 -*-  
  2. # 將圖片保存成 TFRecord  
  3. import os.path  
  4. import matplotlib.image as mpimg  
  5. import tensorflow as tf  
  6. import adjust_pic as ap  
  7. from PIL import Image  
  8.   
  9.   
  10. SAVE_PATH = 'data/dataset.tfrecords'  
  11.   
  12.   
  13. def _int64_feature(value):  
  14.     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))  
  15.   
  16. def _bytes_feature(value):  
  17.     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))  
  18.   
  19. def load_data(datafile, width, high, method=0, save=False):  
  20.     train_list = open(datafile,'r')  
  21.     # 準備一個 writer 用來寫 TFRecord 文件  
  22.     writer = tf.python_io.TFRecordWriter(SAVE_PATH)  
  23.   
  24.     with tf.Session() as sess:  
  25.         for line in train_list:  
  26.             # 得到圖片的路徑和類型  
  27.             tmp = line.strip().split(' ')  
  28.             img_path = tmp[0]  
  29.             label = int(tmp[1])  
  30.   
  31.             # 讀取圖片  
  32.             image = tf.gfile.FastGFile(img_path, 'r').read()  
  33.             # 解碼圖片(若是是 png 格式就使用 decode_png)  
  34.             image = tf.image.decode_jpeg(image)  
  35.             # 轉換數據類型  
  36.             # 由於爲了將圖片數據可以保存到 TFRecord 結構體中,因此須要將其圖片矩陣轉換成 string,因此爲了在使用時可以轉換回來,這裏肯定下數據格式爲 tf.float32  
  37.             image = tf.image.convert_image_dtype(image, dtype=tf.float32)  
  38.             # 既然都將圖片保存成 TFRecord 了,那就先把圖片轉換成但願的大小吧  
  39.             image = ap.resize(image, width, high)  
  40.             # 執行 op: image  
  41.             image = sess.run(image)  
  42.               
  43.             # 將其圖片矩陣轉換成 string  
  44.             image_raw = image.tostring()  
  45.             # 將數據整理成 TFRecord 須要的數據結構  
  46.             example = tf.train.Example(features=tf.train.Features(feature={  
  47.                 'image_raw': _bytes_feature(image_raw),  
  48.                 'label': _int64_feature(label),  
  49.                 }))  
  50.   
  51.             # 寫 TFRecord  
  52.             writer.write(example.SerializeToString())  
  53.   
  54.     writer.close()  
  55.   
  56.   
  57. load_data('train_list.txt_bak', 224, 224)  

 

 

                tfrecords2data.py網站

                從TFRecord中讀取並保存成圖片ui

 

[python] view plain copy
 
  1. # -*- coding: utf-8 -*-  
  2. # 從 TFRecord 中讀取並保存圖片  
  3. import tensorflow as tf  
  4. import numpy as np  
  5.   
  6.   
  7. SAVE_PATH = 'data/dataset.tfrecords'  
  8.   
  9.   
  10. def load_data(width, high):  
  11.     reader = tf.TFRecordReader()  
  12.     filename_queue = tf.train.string_input_producer([SAVE_PATH])  
  13.   
  14.     # 從 TFRecord 讀取內容並保存到 serialized_example 中  
  15.     _, serialized_example = reader.read(filename_queue)  
  16.     # 讀取 serialized_example 的格式  
  17.     features = tf.parse_single_example(  
  18.         serialized_example,  
  19.         features={  
  20.             'image_raw': tf.FixedLenFeature([], tf.string),  
  21.             'label': tf.FixedLenFeature([], tf.int64),  
  22.         })  
  23.   
  24.     # 解析從 serialized_example 讀取到的內容  
  25.     images = tf.decode_raw(features['image_raw'], tf.uint8)  
  26.     labels = tf.cast(features['label'], tf.int64)  
  27.   
  28.     with tf.Session() as sess:  
  29.         # 啓動多線程  
  30.         coord = tf.train.Coordinator()  
  31.         threads = tf.train.start_queue_runners(sess=sess, coord=coord)  
  32.   
  33.         # 由於我這裏只有 2 張圖片,因此下面循環 2 次  
  34.         for i in range(2):  
  35.             # 獲取一張圖片和其對應的類型  
  36.             label, image = sess.run([labels, images])  
  37.             # 這裏特別說明下:  
  38.             #   由於要想把圖片保存成 TFRecord,那就必須先將圖片矩陣轉換成 string,即:  
  39.             #       pic2tfrecords.py 中 image_raw = image.tostring() 這行  
  40.             #   因此這裏須要執行下面這行將 string 轉換回來,不然會沒法 reshape 成圖片矩陣,請看下面的小例子:  
  41.             #       a = np.array([[1, 2], [3, 4]], dtype=np.int64) # 2*2 的矩陣  
  42.             #       b = a.tostring()  
  43.             #       # 下面這行的輸出是 32,即: 2*2 以後還要再乘 8  
  44.             #       # 若是 tostring 以後的長度是 2*2=4 的話,那能夠將 b 直接 reshape([2, 2]),但如今的長度是 2*2*8 = 32,因此沒法直接 reshape  
  45.             #       # 同理若是你的圖片是 500*500*3 的話,那 tostring() 以後的長度是 500*500*3 後再乘上一個數  
  46.             #       print len(b)  
  47.             #  
  48.             #   但在網上有不少提供的代碼裏都沒有下面這一行,大家那真的能 reshape ?  
  49.             image = np.fromstring(image, dtype=np.float32)  
  50.             # reshape 成圖片矩陣  
  51.             image = tf.reshape(image, [224, 224, 3])  
  52.             # 由於要保存圖片,因此將其轉換成 uint8  
  53.             image = tf.image.convert_image_dtype(image, dtype=tf.uint8)  
  54.             # 按照 jpeg 格式編碼  
  55.             image = tf.image.encode_jpeg(image)  
  56.             # 保存圖片  
  57.             with tf.gfile.GFile('pic_%d.jpg' % label, 'wb') as f:  
  58.                 f.write(sess.run(image))  
  59.   
  60.   
  61. load_data(224, 224)  


train_list.txt_bak 中的內容以下:編碼

 

image_1093.jpg 13
image_0805.jpg 10spa

相關文章
相關標籤/搜索