TensorFlow數據集(一)——數據集的基本使用方法

參考書

《TensorFlow:實戰Google深度學習框架》(第2版)python

例子:從一個張量建立一個數據集,遍歷這個數據集,並對每一個輸入輸出y = x^2 的值。數組

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# coding=utf-8 

"""
@author: Li Tian
@contact: 694317828@qq.com
@software: pycharm
@file: dataset_test1.py
@time: 2019/2/10 10:52
@desc: 例子:從一個張量建立一個數據集,遍歷這個數據集,並對每一個輸入輸出y = x^2 的值。
"""

import tensorflow as tf

# 從一個數組建立數據集。
input_data = [1, 2, 3, 5, 8]
dataset = tf.data.Dataset.from_tensor_slices(input_data)

# 定義一個迭代器用於遍歷數據集。由於上面定義的數據集沒有用placeholder做爲輸入參數
# 因此這裏能夠使用最簡單的one_shot_iterator
iterator = dataset.make_one_shot_iterator()
# get_next() 返回表明一個輸入數據的張量,相似於隊列的dequeue()。
x = iterator.get_next()
y = x * x

with tf.Session() as sess:
    for i in range(len(input_data)):
        print(sess.run(y))

運行結果:


 

數據是文本文件:建立數據集。數據結構

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# coding=utf-8 

"""
@author: Li Tian
@contact: 694317828@qq.com
@software: pycharm
@file: dataset_test2.py
@time: 2019/2/10 11:03
@desc: 數據是文本文件
"""

import tensorflow as tf

# 從文本文件建立數據集。假定每行文字是一個訓練例子。注意這裏能夠提供多個文件。
input_files = ['./input_file11', './input_file22']
dataset = tf.data.TextLineDataset(input_files)

# 定義迭代器用於遍歷數據集
iterator = dataset.make_one_shot_iterator()
# 這裏get_next()返回一個字符串類型的張量,表明文件中的一行。
x = iterator.get_next()
with tf.Session() as sess:
    for i in range(4):
        print(sess.run(x))

運行結果:


 

數據是TFRecord文件:建立TFRecord測試文件。app

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# coding=utf-8 

"""
@author: Li Tian
@contact: 694317828@qq.com
@software: pycharm
@file: dataset_createdata.py
@time: 2019/2/10 13:59
@desc: 建立樣例文件
"""

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import time


# 生成整數型的屬性。
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# 生成字符串型的屬性。
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


a = [11, 21, 31, 41, 51]
b = [22, 33, 44, 55, 66]


# 輸出TFRecord文件的地址
filename = './input_file2'
# 建立一個writer來寫TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
for index in range(len(a)):
    aa = a[index]
    bb = b[index]
    # 將一個樣例轉化爲Example Protocol Buffer,並將全部的信息寫入這個數據結構。
    example = tf.train.Example(features=tf.train.Features(feature={
        'feat1': _int64_feature(aa),
        'feat2': _int64_feature(bb)
    }))

    # 將一個Example寫入TFRecord文件中。
    writer.write(example.SerializeToString())
writer.close()

運行結果:


 

數據是TFRecord文件:建立數據集。(使用最簡單的one_hot_iterator來遍歷數據集)框架

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# coding=utf-8 

"""
@author: Li Tian
@contact: 694317828@qq.com
@software: pycharm
@file: dataset_test3.py
@time: 2019/2/10 13:16
@desc: 數據是TFRecord文件
"""

import tensorflow as tf


# 解析一個TFRecord的方法。record是從文件中讀取的一個樣例。前面介紹瞭如何解析TFRecord樣例。
def parser(record):
    # 解析讀入的一個樣例
    features = tf.parse_single_example(
        record,
        features={
            'feat1': tf.FixedLenFeature([], tf.int64),
            'feat2': tf.FixedLenFeature([], tf.int64),
        }
    )
    return features['feat1'], features['feat2']


# 從TFRecord文件建立數據集。
input_files = ['./input_file1', './input_file2']
dataset = tf.data.TFRecordDataset(input_files)

# map()函數表示對數據集中的每一條數據進行調用相應方法。使用TFRecordDataset讀出的是二進制的數據。
# 這裏須要經過map()函數來調用parser()對二進制數據進行解析。相似的,map()函數也能夠用來完成其餘的數據預處理工做。
dataset = dataset.map(parser)

# 定義遍歷數據集的迭代器
iterator = dataset.make_one_shot_iterator()

# feat1, feat2是parser()返回的一維int64型張量,能夠做爲輸入用於進一步的計算。
feat1, feat2 = iterator.get_next()

with tf.Session() as sess:
    for i in range(10):
        f1, f2 = sess.run([feat1, feat2])
        print(f1, f2)

運行結果:


 

數據是TFRecord文件:建立數據集。(使用placeholder和initializable_iterator來動態初始化數據集) 函數

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# coding=utf-8 

"""
@author: Li Tian
@contact: 694317828@qq.com
@software: pycharm
@file: dataset_test4.py
@time: 2019/2/10 13:44
@desc: 用initializable_iterator來動態初始化數據集的例子
"""

import tensorflow as tf
from figuredata_deal.dataset_test3 import parser


# 解析一個TFRecord的方法。與上面的例子相同再也不重複。
# 從TFRecord文件建立數據集,具體文件路徑是一個placeholder,稍後再提供具體路徑。
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(parser)

# 定義遍歷dataset的initializable_iterator
iterator = dataset.make_initializable_iterator()
feat1, feat2 = iterator.get_next()

with tf.Session() as sess:
    # 首先初始化iterator,並給出input_files的值。
    sess.run(iterator.initializer, feed_dict={input_files: ['./input_file1', './input_file2']})

    # 遍歷全部數據一個epoch,當遍歷結束時,程序會拋出OutOfRangeError
    while True:
        try:
            sess.run([feat1, feat2])
        except tf.errors.OutOfRangeError:
            break

運行結果:

相關文章
相關標籤/搜索