Splitting a dataset into batches with TensorFlow

To split data into batches, you can use the methods in tf.train or in tf.data.Dataset.

1. tf.data.Dataset

(1) How to batch

# Splitting the data into batches

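    import tensorflow as tf

    # features and labels are assumed to be arrays or tensors with the same first dimension
    batch_size = 10
    # Pair the training features with their labels. from_tensor_slices slices along
    # the first dimension, so each dataset element is one (feature, label) pair
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    # shuffle() randomizes the order; without it, batches follow the original order.
    # For a full uniform shuffle, buffer_size should be >= the number of examples
    # dataset = dataset.shuffle(buffer_size=num_examples)
    # batch() groups consecutive elements into batches of batch_size. With the default
    # drop_remainder=False the final, smaller batch is kept; with True it is dropped
    dataset = dataset.batch(batch_size)
    # repeat(count) repeats the dataset count times; the default count=None repeats it indefinitely
    # dataset = dataset.batch(batch_size).repeat()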
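The snippet above leaves features, labels and num_examples undefined. Below is a minimal self-contained sketch of the same pipeline; the toy arrays (25 examples with 2 features each) and the shuffle seed are assumptions made up for illustration, not part of the original data.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 25 examples, 2 features each, one integer label per example.
num_examples = 25
features = np.arange(num_examples * 2, dtype=np.float32).reshape(num_examples, 2)
labels = np.arange(num_examples, dtype=np.int64)

batch_size = 10
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# buffer_size == num_examples gives a full shuffle of this small dataset.
dataset = dataset.shuffle(buffer_size=num_examples, seed=0)
dataset = dataset.batch(batch_size)

# Walk the dataset once, recording each batch shape and every label seen.
batch_shapes = []
seen_labels = []
for X, y in dataset.as_numpy_iterator():
    batch_shapes.append(X.shape)
    seen_labels.extend(y.tolist())

print(batch_shapes)         # two full batches, then one batch of the 5 leftover examples
print(sorted(seen_labels))  # every label appears exactly once
```

With 25 examples and batch_size = 10, the last batch holds only 5 examples because drop_remainder is left at its default of False.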

 

# Output

    # Print every batch as a list.
    # print(list(dataset.as_numpy_iterator()))

    # Print just one batch, two ways; way 2 is the officially recommended one.
    print("way1")
    data_iter = iter(dataset)
    for X, y in data_iter:
        print(X, y)
        break
    print("way2")
    for (batch_num, (X, y)) in enumerate(dataset):
        print((X, y))  # batch_num is the batch index; any other name works too
        break
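As a variation on the two loops above, a single batch can also be pulled out without a break statement. This is only a sketch on a small Dataset.range example (chosen here for illustration), using next(iter(...)) and Dataset.take:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(8).batch(3)

# Option A: ask a fresh iterator for its first element.
first_batch = next(iter(dataset))
print(first_batch.numpy())

# Option B: take(1) builds a new dataset that contains only the first batch.
taken = [batch.numpy().tolist() for batch in dataset.take(1)]
print(taken)
```

Both options read the first batch only; take(n) generalizes to the first n batches.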

 

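(2) Notes on dataset.batch()

batch() groups the dataset into batches of batch_size elements and returns a new Dataset whose elements are those batches. With the default drop_remainder=False, a final batch smaller than batch_size is kept; with drop_remainder=True it is dropped.
list(dataset.as_numpy_iterator()) returns all batches as a list of NumPy arrays.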
  def batch(self, batch_size, drop_remainder=False):
    """Combines consecutive elements of this dataset into batches.

    >>> dataset = tf.data.Dataset.range(8)
    >>> dataset = dataset.batch(3)
    >>> list(dataset.as_numpy_iterator())  # returns all batches as a list
    [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]

    >>> dataset = tf.data.Dataset.range(8)
    >>> dataset = dataset.batch(3, drop_remainder=True)
    >>> list(dataset.as_numpy_iterator())
    [array([0, 1, 2]), array([3, 4, 5])]
    ...
    """
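To make the effect of drop_remainder concrete, the sketch below counts the batches and inspects the static batch shape for the same 8-element example. It relies on Dataset.cardinality() and Dataset.element_spec, which are available in recent TensorFlow 2.x releases; the variable names are illustrative.

```python
import math
import tensorflow as tf

num_elements, batch_size = 8, 3

keep = tf.data.Dataset.range(num_elements).batch(batch_size)
drop = tf.data.Dataset.range(num_elements).batch(batch_size, drop_remainder=True)

# Number of batches: ceil(n / batch_size) when the remainder is kept,
# floor(n / batch_size) when it is dropped.
n_keep = int(keep.cardinality().numpy())
n_drop = int(drop.cardinality().numpy())
print(n_keep, math.ceil(num_elements / batch_size))
print(n_drop, num_elements // batch_size)

# The static batch dimension is unknown (None) unless the remainder is dropped.
keep_shape = keep.element_spec.shape.as_list()
drop_shape = drop.element_spec.shape.as_list()
print(keep_shape, drop_shape)
```

A fixed batch dimension can matter for code that needs fully static shapes, which is one common reason to set drop_remainder=True.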

(3) Notes on dataset.repeat()

  def repeat(self, count=None):
    """Repeats this dataset so each original value is seen `count` times.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset = dataset.repeat(3)
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 3, 1, 2, 3, 1, 2, 3]

    Note: If this dataset is a function of global state (e.g. a random number
    generator), then different repetitions may produce different elements.

    Args:
      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        number of times the dataset should be repeated. The default behavior (if
        `count` is `None` or `-1`) is for the dataset be repeated indefinitely.

    Returns:
      Dataset: A `Dataset`.
    """

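The order in which batch() and repeat() are chained changes the result. The following sketch, on a 5-element range picked for illustration, compares the two orders and shows how take() bounds an endlessly repeated dataset:

```python
import tensorflow as tf

base = tf.data.Dataset.range(5)

# batch() first: every pass ends with its own short batch.
batch_then_repeat = [b.tolist() for b in base.batch(2).repeat(2).as_numpy_iterator()]

# repeat() first: batches can straddle the boundary between two passes.
repeat_then_batch = [b.tolist() for b in base.repeat(2).batch(2).as_numpy_iterator()]

# repeat() without a count never ends, so limit it with take() before listing it.
endless = base.batch(2).repeat()
first_four = [b.tolist() for b in endless.take(4).as_numpy_iterator()]

print(batch_then_repeat)
print(repeat_then_batch)
print(first_four)
```

Batching before repeating keeps each pass over the data separate, while repeating first lets one batch mix the end of one pass with the start of the next.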
 

2. tf.train

Reference: https://www.cnblogs.com/jfl-xx/p/9945967.html
