day21 TFRecord格式轉換MNIST並顯示

       首先簡要介紹了下TFRecord格式以及內部實現protobuf協議,而後基於TFRecord格式,對MNIST數據集轉換成TFRecord格式,寫入本地磁盤文件,再從磁盤文件讀取,經過pyplot模塊現實在界面上,效果圖以下:python

TFRecord和Protobuf協議簡介

        TFRecord是谷歌專門爲Tensorflow打造的一種存儲格式,基於protobuf協議實現,也是谷歌推薦的,一個主要緣由是作到訓練和驗證數據格式的統一,有助於不一樣開發者快速遷移模型。Google Protocol Buffer( 簡稱 Protobuf) 是 Google 公司內部的混合語言數據標準,目前已經正在使用的有超過 48,162 種報文格式定義和超過 12,183 個 .proto 文件。他們用於 RPC 系統和持續數據存儲系統。Protocol Buffers 是一種輕便高效的結構化數據存儲格式,能夠用於結構化數據串行化,或者說序列化。它很適合作數據存儲或 RPC 數據交換格式。可用於通信協議、數據存儲等領域的語言無關、平臺無關、可擴展的序列化結構數據格式。目前提供了 C++、Java、Python 三種語言的 API。json

 

優勢數據結構

缺點工具

Protobuf學習

一、Protobuf 有如 XML,不過它更小、更快(幾十倍於XML和JOSON、也更簡單ui

二、「向後」兼容性好spa

三、 Protobuf 語義更清晰,無需相似 XML 解析器的東西code

四、使用 Protobuf 無需學習複雜的文檔對象模型xml

一、 功能簡單,沒法用來表示複雜的概念對象

二、Protobuf 只是 Google 公司內部使用的工具,在通用性上還差不少

三、因爲文本並不適合用來描述數據結構,因此 Protobuf 也不適合用來對基於文本的標記文檔(如 HTML)

四、除非你有 .proto 定義,不然你無法直接讀出 Protobuf 的任何內容

        下面舉個簡單的例子,從數據的存儲格式的角度進行對比,假如要存儲一個鍵值對:{price:150}

  protobuf的表示方式以下,protobuf的物理存儲:08 96 01,就3個字節。採用key-value的方式存放,第一個字節是key,它是field_number << 3 | wire_type構成。因此field number是1,wire type是0,即varint,有了這個wire type就能夠用來解析96 01了。

message  Test {
    optional int32 price = 1;
}

       xml的存儲表示以下,大約須要36字節。

<some>
  <name>price</name>
  <value>150</value>
</some>

         json的存儲表示以下,大約須要11字節。

{price:150}

        綜上所述,protobuf相比於json和xml,對象序列化時能夠節省很是大的空間,從而帶來很是快的傳輸速度。

利用TFRecord格式存儲、讀取和現實MNIST數據集

        在TensorFlow中,TFRecord格式是經過tf.train.Example Protocol Buffer協議的存儲的,如下代碼給出了tf.train.Example的定義:

message Example {
  Features features = 1;
};

message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
};

message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

message BytesList {
  repeated bytes value = 1;
}

message FloatList {
  repeated float value = 1 [packed = true];
}

message Int64List {
  repeated int64 value = 1 [packed = true];
}

       下面給出兩個代碼實例:一個程序ToTFRecord.py從MNIST數據集中讀取圖像和標籤集,而後經過TFRecord格式文件中,另外一個程序FromTFRecord.py從文件中讀取TFRecord格式圖像,而後經過pylot模塊顯示在界面上。

        ToTFRecord.py:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets import mnist
import numpy as np

def __int64_feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value=[value]))

def __bytes_feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value=[value]))

# 讀取圖像和標籤
mnist_data = mnist.read_data_sets(train_dir='MNIST_data/',dtype=tf.uint8,one_hot=True)
images = mnist_data.train.images
labels = mnist_data.train.labels
pixels = images.shape[1]
num_examples = mnist_data.train.num_examples

filename = "MNIST_TFRecord/output.tfrecords"
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
    image_raw = images[index].tostring()
    example = tf.train.Example(features=tf.train.Features(feature={
        'pixels':__int64_feature(pixels),
        'label':__int64_feature(np.argmax(labels[index])),
        'image_raw':__bytes_feature(image_raw)
    }))
    writer.write(example.SerializeToString())
writer.close()

        FromTFRecord.py:

import tensorflow as tf

# 讀取一個樣例
reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer(["MNIST_TFRecord/output.tfrecords"])
_,serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized=serialized_example,features={
    'image_raw':tf.FixedLenFeature([],tf.string),
    'pixels':tf.FixedLenFeature([],tf.int64),
    'label':tf.FixedLenFeature([],tf.int64)
})

# 從樣例中解析數據
images = tf.decode_raw(features['image_raw'],tf.uint8)
labels = tf.cast(features['label'],tf.int32)
pixels = tf.cast(features['pixels'],tf.int32)

from matplotlib import pyplot as plt
import time
import datetime
from six.moves import xrange  # pylint: disable=redefined-builtin
fig, ax = plt.subplots(2, 5, figsize=[3, 3])
plt.ion()
plt.axis('off')
print("%s :建立10個窗口成功..."%(datetime.datetime.now()))

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess,coord)
    try:
        for step in xrange(10):
            if coord.should_stop():
                break
            for i in range(2):
                for j in range(5):
                    cur_pic = i * 5 + j
                    image, label, pixel = sess.run([images, labels, pixels])
                    image = image.reshape([28, 28])
                    print(image, label, pixel)
                    ax[i, j].imshow(image, cmap=plt.cm.gray)
            plt.show()
            plt.pause(2)
    except Exception:
        # Report exceptions to the coordinator.
        coord.request_stop()

    # Terminate as usual.  It is innocuous to request stop twice.
    coord.request_stop()
    coord.join(threads)
相關文章
相關標籤/搜索