只有光頭才能變強。php
文本已收錄至個人GitHub倉庫,歡迎Star:https://github.com/ZhongFuCheng3y/3ypython
回顧前面:git
衆所周知,要訓練出一個模型,首先咱們得有數據。咱們第一個例子中,直接使用dataset的api去加載mnist的數據。(minst的數據要麼咱們是提早下載好,放在對應的目錄上,要麼就根據他給的url直接從網上下載)。github
通常來講,咱們使用TensorFlow是從TFRecord文件中讀取數據的。api
TFRecord 文件格式是一種面向記錄的簡單二進制格式,不少 TensorFlow 應用採用此格式來訓練數據網絡
因此,這篇文章來聊聊怎麼讀取TFRecord文件的數據。session
首先,咱們來體驗一下怎麼造一個TFRecord文件,怎麼從TFRecord文件中讀取數據,遍歷(消費)這些數據。數據結構
如今,咱們尚未TFRecord文件,咱們能夠本身簡單寫一個:機器學習
def write_sample_to_tfrecord(): gmv_values = np.arange(10) click_values = np.arange(10) label_values = np.arange(10) with tf.python_io.TFRecordWriter("/Users/zhongfucheng/data/fashin/demo.tfrecord", options=None) as writer: for _ in range(10): feature_internal = { "gmv": tf.train.Feature(float_list=tf.train.FloatList(value=[gmv_values[_]])), "click": tf.train.Feature(int64_list=tf.train.Int64List(value=[click_values[_]])), "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label_values[_]])) } features_extern = tf.train.Features(feature=feature_internal) # 使用tf.train.Example將features編碼數據封裝成特定的PB協議格式 # example = tf.train.Example(features=tf.train.Features(feature=features_extern)) example = tf.train.Example(features=features_extern) # 將example數據系列化爲字符串 example_str = example.SerializeToString() # 將系列化爲字符串的example數據寫入協議緩衝區 writer.write(example_str) if __name__ == '__main__': write_sample_to_tfrecord()
我相信你們代碼應該是可以看得懂的,其實就是分了幾步:函數
參考資料:
ok,如今咱們就有了一個TFRecord文件啦。
其實就是經過tf.data.TFRecordDataset
這個api來讀取到TFRecord文件,生成處dataset對象
對dataset進行處理(shape處理,格式處理...等等)
使用迭代器對dataset進行消費(遍歷)
demo代碼以下:
import tensorflow as tf def read_tensorflow_tfrecord_files(): # 定義消費緩衝區協議的parser,做爲dataset.map()方法中傳入的lambda: def _parse_function(single_sample): features = { "gmv": tf.FixedLenFeature([1], tf.float32), "click": tf.FixedLenFeature([1], tf.int64), # ()或者[]沒啥影響 "label": tf.FixedLenFeature([1], tf.int64) } parsed_features = tf.parse_single_example(single_sample, features=features) # 對parsed 以後的值進行cast. gmv = tf.cast(parsed_features["gmv"], tf.float64) click = tf.cast(parsed_features["click"], tf.float64) label = tf.cast(parsed_features["label"], tf.float64) return gmv, click, label # 開始定義dataset以及解析tfrecord格式 filenames = tf.placeholder(tf.string, shape=[None]) # 定義dataset 和 一些列trasformation method dataset = tf.data.TFRecordDataset(filenames) parsed_dataset = dataset.map(_parse_function) # 消費緩衝區須要定義在dataset 的map 函數中 batchd_dataset = parsed_dataset.batch(3) # 建立Iterator sample_iter = batchd_dataset.make_initializable_iterator() # 獲取next_sample gmv, click, label = sample_iter.get_next() training_filenames = [ "/Users/zhongfucheng/data/fashin/demo.tfrecord"] with tf.Session() as session: # 初始化帶參數的Iterator session.run(sample_iter.initializer, feed_dict={filenames: training_filenames}) # 讀取文件 print(session.run(gmv)) if __name__ == '__main__': read_tensorflow_tfrecord_files()
無心外的話,咱們能夠輸出這樣的結果:
[[0.] [1.] [2.]]
ok,如今咱們已經大概知道怎麼寫一個TFRecord文件,以及怎麼讀取TFRecord文件的數據,而且消費這些數據了。
我在學習TensorFlow翻閱資料時,常常看到一些機器學習的術語,因爲本身沒啥機器學習的基礎,因此不少時候看到一些專業名詞就開始懵逼了。
當一個完整的數據集經過了神經網絡一次而且返回了一次,這個過程稱爲一個epoch。
這可能使咱們跟dataset.repeat()
方法聯繫起來,這個方法可使當前數據集重複一遍。好比說,原有的數據集是[1,2,3,4,5]
,若是我調用dataset.repeat(2)
的話,那麼咱們的數據集就變成了[1,2,3,4,5],[1,2,3,4,5]
通常來講咱們的數據集都是比較大的,沒法一次性將整個數據集的數據喂進神經網絡中,因此咱們會將數據集分紅好幾個部分。每次喂多少條樣本進神經網絡,這個叫作batchSize。
在TensorFlow也提供了方法給咱們設置:dataset.batch()
,在API中是這樣介紹batchSize的:
representing the number of consecutive elements of this dataset to combine in a single batch
咱們通常在每次訓練以前,會將整個數據集的順序打亂,提升咱們模型訓練的效果。這裏咱們用到的api是:dataset.shffle();
我從官網的介紹中截了一個dataset的方法圖(部分):
dataset的功能主要有如下三種:
map(),flat_map(),zip(),repeat()
等等迭代器能夠分爲四種:
tf.data.Iterator.from_structure
來進行初始化簡單總結:
string handler(可饋送的 Iterator)這種方式是最常使用的,我當時也寫了一個Demo來使用了一下,代碼以下:
def read_tensorflow_tfrecord_files(): # 開始定義dataset以及解析tfrecord格式. train_filenames = tf.placeholder(tf.string, shape=[None]) vali_filenames = tf.placeholder(tf.string, shape=[None]) # 加載train_dataset batch_inputs這個方法每一個人都不同的,這個方法我就不給了。 train_dataset = batch_inputs([ train_filenames], batch_size=5, type=False, num_epochs=2, num_preprocess_threads=3) # 加載validation_dataset batch_inputs這個方法每一個人都不同的,這個方法我就不給了。 validation_dataset = batch_inputs([vali_filenames ], batch_size=5, type=False, num_epochs=2, num_preprocess_threads=3) # 建立出string_handler()的迭代器(經過相同數據結構的dataset來構建) handle = tf.placeholder(tf.string, shape=[]) iterator = tf.data.Iterator.from_string_handle( handle, train_dataset.output_types, train_dataset.output_shapes) # 有了迭代器就能夠調用next方法了。 itemid = iterator.get_next() # 指定哪一種具體的迭代器,有單次迭代的,有初始化的。 training_iterator = train_dataset.make_initializable_iterator() validation_iterator = validation_dataset.make_initializable_iterator() # 定義出placeholder的值 training_filenames = [ "/Users/zhongfucheng/tfrecord_test/data01aa"] validation_filenames = ["/Users/zhongfucheng/tfrecord_validation/part-r-00766"] with tf.Session() as sess: # 初始化迭代器 training_handle = sess.run(training_iterator.string_handle()) validation_handle = sess.run(validation_iterator.string_handle()) for _ in range(2): sess.run(training_iterator.initializer, feed_dict={train_filenames: training_filenames}) print("this is training iterator ----") for _ in range(5): print(sess.run(itemid, feed_dict={handle: training_handle})) sess.run(validation_iterator.initializer, feed_dict={vali_filenames: validation_filenames}) print("this is validation iterator ") for _ in range(5): print(sess.run(itemid, feed_dict={vali_filenames: validation_filenames, handle: validation_handle})) if __name__ == '__main__': read_tensorflow_tfrecord_files()
參考資料:
在翻閱資料時,發現寫得不錯的一些博客:
樂於輸出乾貨的Java技術公衆號:Java3y。公衆號內有200多篇原創技術文章、海量視頻資源、精美腦圖,不妨來關注一下!
下一篇文章打算講講如何理解axis~
以爲個人文章寫得不錯,不妨點一下贊!