首先簡要介紹了下TFRecord格式以及內部實現protobuf協議,而後基於TFRecord格式,對MNIST數據集轉換成TFRecord格式,寫入本地磁盤文件,再從磁盤文件讀取,經過pyplot模塊現實在界面上,效果圖以下:python
TFRecord是谷歌專門爲Tensorflow打造的一種存儲格式,基於protobuf協議實現,也是谷歌推薦的,一個主要緣由是作到訓練和驗證數據格式的統一,有助於不一樣開發者快速遷移模型。Google Protocol Buffer( 簡稱 Protobuf) 是 Google 公司內部的混合語言數據標準,目前已經正在使用的有超過 48,162 種報文格式定義和超過 12,183 個 .proto 文件。他們用於 RPC 系統和持續數據存儲系統。Protocol Buffers 是一種輕便高效的結構化數據存儲格式,能夠用於結構化數據串行化,或者說序列化。它很適合作數據存儲或 RPC 數據交換格式。可用於通信協議、數據存儲等領域的語言無關、平臺無關、可擴展的序列化結構數據格式。目前提供了 C++、Java、Python 三種語言的 API。json
|
優勢數據結構 |
缺點工具 |
Protobuf學習 |
一、Protobuf 有如 XML,不過它更小、更快(幾十倍於XML和JOSON)、也更簡單ui 二、「向後」兼容性好spa 三、 Protobuf 語義更清晰,無需相似 XML 解析器的東西code 四、使用 Protobuf 無需學習複雜的文檔對象模型xml |
一、 功能簡單,沒法用來表示複雜的概念對象 二、Protobuf 只是 Google 公司內部使用的工具,在通用性上還差不少 三、因爲文本並不適合用來描述數據結構,因此 Protobuf 也不適合用來對基於文本的標記文檔(如 HTML) 四、除非你有 .proto 定義,不然你無法直接讀出 Protobuf 的任何內容 |
下面舉個簡單的例子,從數據的存儲格式的角度進行對比,假如要存儲一個鍵值對:{price:150}
protobuf的表示方式以下,protobuf的物理存儲:08 96 01,就3個字節。採用key-value的方式存放,第一個字節是key,它是field_number << 3 | wire_type構成。因此field number是1,wire type是0,即varint,有了這個wire type就能夠用來解析96 01了。
message Test { optional int32 price = 1; }
xml的存儲表示以下,大約須要36字節。
<some> <name>price</name> <value>150</value> </some>
json的存儲表示以下,大約須要11字節。
{price:150}
綜上所述,protobuf相比於json和xml,對象序列化時能夠節省很是大的空間,從而帶來很是快的傳輸速度。
在TensorFlow中,TFRecord格式是經過tf.train.Example Protocol Buffer協議的存儲的,如下代碼給出了tf.train.Example的定義:
message Example { Features features = 1; }; message Features { // Map from feature name to feature. map<string, Feature> feature = 1; }; message Feature { // Each feature can be exactly one kind. oneof kind { BytesList bytes_list = 1; FloatList float_list = 2; Int64List int64_list = 3; } }; message BytesList { repeated bytes value = 1; } message FloatList { repeated float value = 1 [packed = true]; } message Int64List { repeated int64 value = 1 [packed = true]; }
下面給出兩個代碼實例:一個程序ToTFRecord.py從MNIST數據集中讀取圖像和標籤集,而後經過TFRecord格式文件中,另外一個程序FromTFRecord.py從文件中讀取TFRecord格式圖像,而後經過pylot模塊顯示在界面上。
ToTFRecord.py:
import tensorflow as tf from tensorflow.contrib.learn.python.learn.datasets import mnist import numpy as np def __int64_feature(value): return tf.train.Feature(int64_list = tf.train.Int64List(value=[value])) def __bytes_feature(value): return tf.train.Feature(bytes_list = tf.train.BytesList(value=[value])) # 讀取圖像和標籤 mnist_data = mnist.read_data_sets(train_dir='MNIST_data/',dtype=tf.uint8,one_hot=True) images = mnist_data.train.images labels = mnist_data.train.labels pixels = images.shape[1] num_examples = mnist_data.train.num_examples filename = "MNIST_TFRecord/output.tfrecords" writer = tf.python_io.TFRecordWriter(filename) for index in range(num_examples): image_raw = images[index].tostring() example = tf.train.Example(features=tf.train.Features(feature={ 'pixels':__int64_feature(pixels), 'label':__int64_feature(np.argmax(labels[index])), 'image_raw':__bytes_feature(image_raw) })) writer.write(example.SerializeToString()) writer.close()
FromTFRecord.py:
import tensorflow as tf # 讀取一個樣例 reader = tf.TFRecordReader() filename_queue = tf.train.string_input_producer(["MNIST_TFRecord/output.tfrecords"]) _,serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized=serialized_example,features={ 'image_raw':tf.FixedLenFeature([],tf.string), 'pixels':tf.FixedLenFeature([],tf.int64), 'label':tf.FixedLenFeature([],tf.int64) }) # 從樣例中解析數據 images = tf.decode_raw(features['image_raw'],tf.uint8) labels = tf.cast(features['label'],tf.int32) pixels = tf.cast(features['pixels'],tf.int32) from matplotlib import pyplot as plt import time import datetime from six.moves import xrange # pylint: disable=redefined-builtin fig, ax = plt.subplots(2, 5, figsize=[3, 3]) plt.ion() plt.axis('off') print("%s :建立10個窗口成功..."%(datetime.datetime.now())) with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess,coord) try: for step in xrange(10): if coord.should_stop(): break for i in range(2): for j in range(5): cur_pic = i * 5 + j image, label, pixel = sess.run([images, labels, pixels]) image = image.reshape([28, 28]) print(image, label, pixel) ax[i, j].imshow(image, cmap=plt.cm.gray) plt.show() plt.pause(2) except Exception: # Report exceptions to the coordinator. coord.request_stop() # Terminate as usual. It is innocuous to request stop twice. coord.request_stop() coord.join(threads)