將數據劃分成若干批次的數據,可以使用tf.train或者tf.data.Dataset中的方法。
# --- Batch partitioning of the training data ---
batch_size = 10

# Pair up the training features with their labels; from_tensor_slices slices
# the tensors along the first axis into individual (feature, label) examples.
# FIX: the original read `tfdata.Dataset...`, a typo for `tf.data.Dataset`
# (consistent with the tf.data API named in the surrounding text; assumes
# `import tensorflow as tf` at the top of the file).
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# shuffle() randomizes the example order; without it batches follow the
# original order. buffer_size should be >= the number of examples to get a
# uniform shuffle.
# dataset = dataset.shuffle(buffer_size=num_examples)

# batch() groups the dataset into batches of batch_size. With the default
# drop_remainder=False the final short batch is kept; with True it is dropped.
dataset = dataset.batch(batch_size)

# repeat() repeats the sequence `count` times; the default (None) means the
# data sequence continues indefinitely.
# dataset = dataset.batch(batch_size).repeat()
# --- Output ---
# To print the list of every batch at once:
# print(list(dataset.as_numpy_iterator()))

# Print just the first batch — two equivalent approaches; way2 is the one
# recommended by the official docs.
print("way1")
data_iter = iter(dataset)
for feats, labels in data_iter:
    print(feats, labels)
    break

print("way2")
# idx is the batch number; any identifier name works here.
for idx, (feats, labels) in enumerate(dataset):
    print((feats, labels))
    break
batch把dataset按照batch_size分批次,獲得一個list集合。默認drop_remainder=False時,保留不足批次的部分,如果是True,就捨去。
用list(dataset.as_numpy_iterator())方法能夠輸出全部batch的list集合。
def batch(self, batch_size, drop_remainder=False):
    """Combines consecutive elements of this dataset into batches.

    >>> dataset = tf.data.Dataset.range(8)
    >>> dataset = dataset.batch(3)
    >>> list(dataset.as_numpy_iterator())  # prints the list of all batches
    [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]

    >>> dataset = tf.data.Dataset.range(8)
    >>> dataset = dataset.batch(3, drop_remainder=True)
    >>> list(dataset.as_numpy_iterator())
    [array([0, 1, 2]), array([3, 4, 5])]
    """
    # NOTE(review): this is a quoted excerpt of tf.data.Dataset.batch; the
    # paste was truncated mid-docstring (closing quotes and implementation
    # lost). The closing quotes are restored so the snippet parses; the real
    # implementation lives in the TensorFlow source.
# Excerpt of tf.data.Dataset.repeat quoted from the TensorFlow source for
# reference; the body shown is only the docstring (the real implementation
# is omitted in the original post).
def repeat(self, count=None):
    """Repeats this dataset so each original value is seen `count` times.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset = dataset.repeat(3)
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 3, 1, 2, 3, 1, 2, 3]

    Note: If this dataset is a function of global state (e.g. a random number
    generator), then different repetitions may produce different elements.

    Args:
      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        number of times the dataset should be repeated. The default behavior
        (if `count` is `None` or `-1`) is for the dataset be repeated
        indefinitely.

    Returns:
      Dataset: A `Dataset`.
    """