Tensorflow 數據導入

導入數據

藉助 tf.data API,您能夠根據簡單的可重用片斷構建複雜的輸入管道。例如,圖片模型的管道可能會匯聚分佈式文件系統中的文件中的數據、對每一個圖片應用隨機擾動,並將隨機選擇的圖片合併成用於訓練的批次。文本模型的管道可能包括從原始文本數據中提取符號、根據對照表將其轉換爲嵌入標識符,以及將不一樣長度的序列組合成批次數據。使用 tf.data API 能夠輕鬆處理大量數據、不一樣的數據格式以及複雜的轉換。html

tf.data API 在 TensorFlow 中引入了兩個新的抽象類:python

  • tf.data.Dataset 表示一系列元素,其中每一個元素包含一個或多個 Tensor 對象。例如,在圖像管道中,元素多是單個訓練樣本,具備一對錶示圖像數據和標籤的張量。能夠經過兩種不一樣的方式來建立數據集:算法

    • 建立來源(例如 Dataset.from_tensor_slices()),以經過一個或多個 tf.Tensor 對象構建數據集。編程

    • 應用轉換(例如 Dataset.batch()),以經過一個或多個 tf.data.Dataset 對象構建數據集。api

  • tf.data.Iterator 提供了從數據集中提取元素的主要方法。Iterator.get_next() 返回的操做會在執行時生成 Dataset 的下一個元素,而且此操做一般充當輸入管道代碼和模型之間的接口。最簡單的迭代器是「單次迭代器」,它與特定的 Dataset 相關聯,並對其進行一次迭代。要實現更復雜的用途,您能夠經過 Iterator.initializer 操做使用不一樣的數據集從新初始化和參數化迭代器,這樣一來,您就能夠在同一個程序中對訓練和驗證數據進行屢次迭代(舉例而言)。數組

基本機制

本指南的這一部分介紹了建立不一樣種類的 DatasetIterator 對象的基礎知識,以及如何從這些對象中提取數據。網絡

要啓動輸入管道,您必須定義來源。例如,要經過內存中的某些張量構建 Dataset,您可使用 tf.data.Dataset.from_tensors()tf.data.Dataset.from_tensor_slices()。或者,若是輸入數據以推薦的 TFRecord 格式存儲在磁盤上,那麼您能夠構建 tf.data.TFRecordDatasetapp

一旦有了 Dataset 對象,能夠將其轉換爲新的 Dataset,方法是連接tf.data.Dataset 對象上的方法調用。例如,您能夠應用單元素轉換,例如 Dataset.map()(爲每一個元素應用一個函數),也能夠應用多元素轉換(例如 Dataset.batch())。要了解轉換的完整列表,請參閱 tf.data.Dataset 的文檔。框架

消耗 Dataset 中值的最多見方法是構建迭代器對象。經過此對象,能夠一次訪問數據集中的一個元素(例如經過調用 Dataset.make_one_shot_iterator())。tf.data.Iterator 提供了兩個操做:Iterator.initializer,您能夠經過此操做(從新)初始化迭代器的狀態;以及 Iterator.get_next(),此操做返回對應於有符號下一個元素的 tf.Tensor 對象。根據您的使用情形,您能夠選擇不一樣類型的迭代器,下文介紹了具體選項。dom

數據集結構

一個數據集包含多個元素,每一個元素的結構都相同。一個元素包含一個或多個 tf.Tensor 對象,這些對象稱爲組件。每一個組件都有一個 tf.DType,表示張量中元素的類型;以及一個 tf.TensorShape,表示每一個元素(可能部分指定)的靜態形狀。您能夠經過 Dataset.output_typesDataset.output_shapes 屬性檢查數據集元素各個組件的推理類型和形狀。這些屬性的嵌套結構映射到元素的結構,此元素能夠是單個張量、張量元組,也能夠是張量的嵌套元組。例如:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types)  # ==> "tf.float32"
print(dataset1.output_shapes)  # ==> "(10,)"

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random_uniform([4]),
    tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types)  # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes)  # ==> "((), (100,))"

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types)  # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes)  # ==> "(10, ((), (100,)))"

爲元素的每一個組件命名一般會帶來便利性,例如,若是它們表示訓練樣本的不一樣特徵。除了元組以外,還可使用 collections.namedtuple 或將字符串映射到張量的字典來表示 Dataset 的單個元素。

dataset = tf.data.Dataset.from_tensor_slices(
   {"a": tf.random_uniform([4]),
    "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types)  # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes)  # ==> "{'a': (), 'b': (100,)}"

Dataset 轉換支持任何結構的數據集。在使用 Dataset.map()Dataset.flat_map()Dataset.filter() 轉換時(這些轉換會對每一個元素應用一個函數),元素結構決定了函數的參數:

dataset1 = dataset1.map(lambda x: ...)

dataset2 = dataset2.flat_map(lambda x, y: ...)

# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)

建立迭代器

構建了表示輸入數據的 Dataset 後,下一步就是建立 Iterator 來訪問該數據集中的元素。tf.data API 目前支持下列迭代器,複雜程度逐漸增大:

  • 單次
  • 可初始化
  • 可從新初始化,以及
  • 可饋送

單次迭代器是最簡單的迭代器形式,僅支持對數據集進行一次迭代,不須要顯式初始化。單次迭代器能夠處理基於隊列的現有輸入管道支持的幾乎全部狀況,但它們不支持參數化。以 Dataset.range() 爲例:

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
  value = sess.run(next_element)
  assert i == value

注意:目前,單次迭代器是惟一易於與 Estimator 搭配使用的類型。

您須要先運行顯式 iterator.initializer 操做,而後才能使用可初始化迭代器。雖然有些不便,但它容許您使用一個或多個 tf.placeholder() 張量(可在初始化迭代器時饋送)參數化數據集的定義。繼續以 Dataset.range() 爲例:

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
  value = sess.run(next_element)
  assert i == value

# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
  value = sess.run(next_element)
  assert i == value

可從新初始化迭代器能夠經過多個不一樣的 Dataset 對象進行初始化。例如,您可能有一個訓練輸入管道,它會對輸入圖片進行隨機擾動來改善泛化;還有一個驗證輸入管道,它會評估對未修改數據的預測。這些管道一般會使用不一樣的 Dataset 對象,這些對象具備相同的結構(即每一個組件具備相同類型和兼容形狀)。

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                           training_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(100):
    sess.run(next_element)

  # Initialize an iterator over the validation dataset.
  sess.run(validation_init_op)
  for _ in range(50):
    sess.run(next_element)

可饋送迭代器能夠與 tf.placeholder 一塊兒使用,以選擇所使用的 Iterator(在每次調用 tf.Session.run 時)(經過熟悉的 feed_dict 機制)。它提供的功能與可從新初始化迭代器的相同,但在迭代器之間切換時不須要從數據集的開頭初始化迭代器。例如,以上面的同一訓練和驗證數據集爲例,您可使用 tf.data.Iterator.from_string_handle 定義一個可以讓您在兩個數據集之間切換的可饋送迭代器:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

# Loop forever, alternating between training and validation.
while True:
  # Run 200 steps using the training dataset. Note that the training dataset is
  # infinite, and we resume from where we left off in the previous `while` loop
  # iteration.
  for _ in range(200):
    sess.run(next_element, feed_dict={handle: training_handle})

  # Run one pass over the validation dataset.
  sess.run(validation_iterator.initializer)
  for _ in range(50):
    sess.run(next_element, feed_dict={handle: validation_handle})

消耗迭代器中的值

Iterator.get_next() 方法返回一個或多個 tf.Tensor 對象,這些對象對應於迭代器有符號的下一個元素。每次評估這些張量時,它們都會獲取底層數據集中下一個元素的值。(請注意,與 TensorFlow 中的其餘有狀態對象同樣,調用 Iterator.get_next() 並不會當即使迭代器進入下個狀態。您必須在 TensorFlow 表達式中使用此函數返回的 tf.Tensor 對象,並將該表達式的結果傳遞到 tf.Session.run(),以獲取下一個元素並使迭代器進入下個狀態。)

若是迭代器到達數據集的末尾,則執行 Iterator.get_next() 操做會產生 tf.errors.OutOfRangeError。在此以後,迭代器將處於不可用狀態;若是須要繼續使用,則必須對其從新初始化。

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Typically `result` will be the output of a model, or an optimizer's
# training operation.
result = tf.add(next_element, next_element)

sess.run(iterator.initializer)
print(sess.run(result))  # ==> "0"
print(sess.run(result))  # ==> "2"
print(sess.run(result))  # ==> "4"
print(sess.run(result))  # ==> "6"
print(sess.run(result))  # ==> "8"
try:
  sess.run(result)
except tf.errors.OutOfRangeError:
  print("End of dataset")  # ==> "End of dataset"

一種常見模式是將「訓練循環」封裝在 try-except 塊中:

sess.run(iterator.initializer)
while True:
  try:
    sess.run(result)
  except tf.errors.OutOfRangeError:
    break

若是數據集的每一個元素都具備嵌套結構,則 Iterator.get_next() 的返回值將是一個或多個 tf.Tensor 對象,這些對象具備相同的嵌套結構:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

iterator = dataset3.make_initializable_iterator()

sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()

請注意,next1next2next3 是由同一個操做/節點(經過 Iterator.get_next() 建立)生成的張量。所以,評估其中任何一個張量都會使全部組件的迭代器進入下個狀態。典型的迭代器消耗方會在一個表達式中包含全部組件。

保存迭代器狀態

tf.contrib.data.make_saveable_from_iterator 函數經過迭代器建立一個 SaveableObject,該對象可用於保存和恢復迭代器(其實是整個輸入管道)的當前狀態。以這種方式建立的可保存對象能夠添加到 tf.train.Saver 變量列表或 tf.GraphKeys.SAVEABLE_OBJECTS 集合中,以便採用與 tf.Variable 相同的方式進行保存和恢復。請參閱保存和恢復,詳細瞭解如何保存和恢復變量。

# Create saveable object from iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)

# Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()

with tf.Session() as sess:

  if should_checkpoint:
    saver.save(path_to_checkpoint)

# Restore the iterator state.
with tf.Session() as sess:
  saver.restore(sess, path_to_checkpoint)

讀取輸入數據

消耗 NumPy 數組

若是您的全部輸入數據都適合存儲在內存中,則根據輸入數據建立 Dataset 的最簡單方法是將它們轉換爲 tf.Tensor 對象,並使用 Dataset.from_tensor_slices()

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]

dataset = tf.data.Dataset.from_tensor_slices((features, labels))

請注意,上面的代碼段會將 featureslabels 數組做爲 tf.constant() 指令嵌入在 TensorFlow 圖中。這樣很是適合小型數據集,但會浪費內存,由於會屢次複製數組的內容,並可能會達到 tf.GraphDef 協議緩衝區的 2GB 上限。

做爲替代方案,您能夠根據 tf.placeholder() 張量定義 Dataset,並在對數據集初始化 Iterator 時饋送 NumPy 數組。

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})

消耗 TFRecord 數據

tf.data API 支持多種文件格式,所以您能夠處理那些不適合存儲在內存中的大型數據集。例如,TFRecord 文件格式是一種面向記錄的簡單二進制格式,不少 TensorFlow 應用採用此格式來訓練數據。經過 tf.data.TFRecordDataset 類,您能夠將一個或多個 TFRecord 文件的內容做爲輸入管道的一部分進行流式傳輸。

# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)

TFRecordDataset 初始化程序的 filenames 參數能夠是字符串、字符串列表,也能夠是字符串 tf.Tensor。所以,若是您有兩組分別用於訓練和驗證的文件,則可使用 tf.placeholder(tf.string) 來表示文件名,並使用適當的文件名初始化迭代器:

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)  # Parse the record into tensors.
dataset = dataset.repeat()  # Repeat the input indefinitely.
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()

# You can feed the initializer with the appropriate filenames for the current
# phase of execution, e.g. training vs. validation.

# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})

消耗文本數據

不少數據集都是做爲一個或多個文本文件分佈的。tf.data.TextLineDataset 提供了一種從一個或多個文本文件中提取行的簡單方法。給定一個或多個文件名,TextLineDataset 會爲這些文件的每行生成一個字符串值元素。像 TFRecordDataset 同樣,TextLineDataset 將接受 filenames(做爲 tf.Tensor),所以您能夠經過傳遞 tf.placeholder(tf.string) 進行參數化。

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)

默認狀況下,TextLineDataset 會生成每一個文件的每一行,這多是不可取的(例如,若是文件以標題行開頭或包含註釋)。可使用 Dataset.skip()Dataset.filter() 轉換來移除這些行。爲了將這些轉換分別應用於每一個文件,咱們使用 Dataset.flat_map() 爲每一個文件建立一個嵌套的 Dataset

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]

dataset = tf.data.Dataset.from_tensor_slices(filenames)

# Use `Dataset.flat_map()` to transform each file as a separate nested dataset,
# and then concatenate their contents sequentially into a single "flat" dataset.
# * Skip the first line (header row).
# * Filter out lines beginning with "#" (comments).
dataset = dataset.flat_map(
    lambda filename: (
        tf.data.TextLineDataset(filename)
        .skip(1)
        .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))

消耗 CSV 數據

CSV 文件格式是用於以純文本格式存儲表格數據的經常使用格式。tf.contrib.data.CsvDataset 類提供了一種從符合 RFC 4180 的一個或多個 CSV 文件中提取記錄的方法。給定一個或多個文件名以及默認值列表後,CsvDataset 將生成一個元素元組,元素類型對應於爲每一個 CSV 記錄提供的默認元素類型。像 TFRecordDatasetTextLineDataset 同樣,CsvDataset 將接受 filenames(做爲 tf.Tensor),所以您能夠經過傳遞 tf.placeholder(tf.string) 進行參數化。

# Creates a dataset that reads all of the records from two CSV files, each with
# eight float columns
filenames = ["/var/data/file1.csv", "/var/data/file2.csv"]
record_defaults = [tf.float32] * 8   # Eight required float columns
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)

若是某些列爲空,則能夠提供默認值而不是類型。

# Creates a dataset that reads all of the records from two CSV files, each with
# four float columns which may have missing values
record_defaults = [[0.0]] * 8
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)

默認狀況下,CsvDataset 生成文件的每一列或每一行,這多是不可取的;例如,若是文件以應忽略的標題行開頭,或若是輸入中不須要某些列。能夠分別使用 headerselect_cols 參數移除這些行和字段。

# Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [[0.0]] * 2  # Only provide defaults for the selected columns
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults, header=True, select_cols=[2,4])

使用 Dataset.map() 預處理數據

Dataset.map(f) 轉換經過將指定函數 f 應用於輸入數據集的每一個元素來生成新數據集。此轉換基於 map() 函數(一般應用於函數式編程語言中的列表和其餘結構)。函數 f 會接受表示輸入中單個元素的 tf.Tensor 對象,並返回表示新數據集中單個元素的 tf.Tensor 對象。此函數的實現使用標準的 TensorFlow 指令將一個元素轉換爲另外一個元素。

本部分介紹瞭如何使用 Dataset.map() 的常見示例。

解析 tf.Example 協議緩衝區消息

許多輸入管道都從 TFRecord 格式的文件中提取 tf.train.Example 協議緩衝區消息(例如這種文件使用 tf.python_io.TFRecordWriter 編寫而成)。每一個 tf.train.Example 記錄都包含一個或多個「特徵」,輸入管道一般會將這些特徵轉換爲張量。

# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)

解碼圖片數據並調整其大小

在用真實的圖片數據訓練神經網絡時,一般須要將不一樣大小的圖片轉換爲通用大小,這樣就能夠將它們批處理爲具備固定大小的數據。

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_jpeg(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])

# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)

使用 tf.py_func() 應用任意 Python 邏輯

爲了確保性能,咱們建議您儘量使用 TensorFlow 指令預處理數據。不過,在解析輸入數據時,調用外部 Python 庫有時頗有用。爲此,請在 Dataset.map() 轉換中調用 tf.py_func() 指令。

import cv2

# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
  image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
  return image_decoded, label

# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
  image_decoded.set_shape([None, None, None])
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
    lambda filename, label: tuple(tf.py_func(
        _read_py_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(_resize_function)

批處理數據集元素

簡單的批處理

最簡單的批處理形式是將數據集中的 n 個連續元素堆疊爲一個元素。Dataset.batch() 轉換正是這麼作的,它與 tf.stack() 運算符具備相同的限制(被應用於元素的每一個組件):即對於每一個組件 i,全部元素的張量形狀都必須徹底相同。

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

print(sess.run(next_element))  # ==> ([0, 1, 2,   3],   [ 0, -1,  -2,  -3])
print(sess.run(next_element))  # ==> ([4, 5, 6,   7],   [-4, -5,  -6,  -7])
print(sess.run(next_element))  # ==> ([8, 9, 10, 11],   [-8, -9, -10, -11])

使用填充批處理張量

上述方法適用於具備相同大小的張量。不過,不少模型(例如序列模型)處理的輸入數據可能具備不一樣的大小(例如序列的長度不一樣)。爲了解決這種狀況,能夠經過 Dataset.padded_batch() 轉換來指定一個或多個會被填充的維度,從而批處理不一樣形狀的張量。

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

print(sess.run(next_element))  # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element))  # ==> [[4, 4, 4, 4, 0, 0, 0],
                               #      [5, 5, 5, 5, 5, 0, 0],
                               #      [6, 6, 6, 6, 6, 6, 0],
                               #      [7, 7, 7, 7, 7, 7, 7]]

您能夠經過 Dataset.padded_batch() 轉換爲每一個組件的每一個維度設置不一樣的填充,而且能夠採用可變長度(在上面的示例中用 None 表示)或恆定長度。也能夠替換填充值,默認設置爲 0。

訓練工做流程

處理多個週期

tf.data API 提供了兩種主要方式來處理同一數據的多個週期。

要迭代數據集多個週期,最簡單的方法是使用 Dataset.repeat() 轉換。例如,要建立一個將其輸入重複 10 個週期的數據集:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)

應用不帶參數的 Dataset.repeat() 轉換將無限次地重複輸入。Dataset.repeat() 轉換將其參數鏈接起來,而不會在一個週期結束和下一個週期開始時發出信號。

若是您想在每一個週期結束時收到信號,則能夠編寫在數據集結束時捕獲 tf.errors.OutOfRangeError 的訓練循環。此時,您能夠收集關於該週期的一些統計信息(例如驗證錯誤)。

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Compute for 100 epochs.
for _ in range(100):
  sess.run(iterator.initializer)
  while True:
    try:
      sess.run(next_element)
    except tf.errors.OutOfRangeError:
      break

  # [Perform end-of-epoch calculations here.]

隨機重排輸入數據

Dataset.shuffle() 轉換會使用相似於 tf.RandomShuffleQueue 的算法隨機重排輸入數據集:它會維持一個固定大小的緩衝區,並從該緩衝區統一地隨機選擇下一個元素。

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()

使用高階 API

tf.train.MonitoredTrainingSession API 簡化了在分佈式設置下運行 TensorFlow 的不少方面。MonitoredTrainingSession 使用 tf.errors.OutOfRangeError 表示訓練已完成,所以要將其與 tf.data API 結合使用,咱們建議使用 Dataset.make_one_shot_iterator()。例如:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()

next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)

training_op = tf.train.AdagradOptimizer(...).minimize(loss)

with tf.train.MonitoredTrainingSession(...) as sess:
  while not sess.should_stop():
    sess.run(training_op)

要在 input_fn 中使用 Dataset(input_fn 屬於 tf.estimator.Estimator),只需返回 Dataset 便可,框架將負責爲您建立和初始化迭代器。例如:

def dataset_input_fn():
  filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
  dataset = tf.data.TFRecordDataset(filenames)

  # Use `tf.parse_single_example()` to extract data from a `tf.Example`
  # protocol buffer, and perform any additional per-record preprocessing.
  def parser(record):
    keys_to_features = {
        "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
        "date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)

    # Perform additional preprocessing on the parsed data.
    image = tf.image.decode_jpeg(parsed["image_data"])
    image = tf.reshape(image, [299, 299, 1])
    label = tf.cast(parsed["label"], tf.int32)

    return {"image_data": image, "date_time": parsed["date_time"]}, label

  # Use `Dataset.map()` to build a pair of a feature dictionary and a label
  # tensor for each example.
  dataset = dataset.map(parser)
  dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.batch(32)
  dataset = dataset.repeat(num_epochs)

  # Each element of `dataset` is tuple containing a dictionary of features
  # (in which each value is a batch of values for that feature), and a batch of
  # labels.
  return dataset
相關文章
相關標籤/搜索