機器學習中數據讀取是很重要的一個環節,TensorFlow也提供了不少實用的方法,爲了不之後時間久了又忘記,因此寫下筆記以備往後查看。python
首先咱們看看最普通的狀況:session
# 建立0-10的數據集,每一個batch取個數。 dataset = tf.data.Dataset.range(10).batch(6) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: for i in range(2): value = sess.run(next_element) print(value)
輸出結果機器學習
[0 1 2 3 4 5] [6 7 8 9]
由結果咱們能夠知道TensorFlow能很好地幫咱們自動處理最後一個batch的數據。學習
可是若是上面for循環次數超過2會怎麼樣呢?也就是說若是 **循環次數*批數量 > 數據集數量** 會怎麼樣?咱們試試看:spa
dataset = tf.data.Dataset.range(10).batch(6) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: >>==for i in range(3):==<< value = sess.run(next_element) print(value)
輸出結果code
[0 1 2 3 4 5] [6 7 8 9] --------------------------------------------------------------------------- OutOfRangeError Traceback (most recent call last) D:\Continuum\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args) 1277 try: ... ...省略若干信息... ... OutOfRangeError (see above for traceback): End of sequence [[Node: IteratorGetNext_64 = IteratorGetNext[output_shapes=[[?]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_28)]]
能夠知道超過範圍了,因此報錯了。element
爲了解決上述問題,repeat方法登場。仍是直接看例子吧:資源
dataset = tf.data.Dataset.range(10).batch(6) dataset = dataset.repeat(2) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: for i in range(4): value = sess.run(next_element) print(value)
輸出結果get
[0 1 2 3 4 5] [6 7 8 9] [0 1 2 3 4 5] [6 7 8 9]
能夠知道repeat其實就是將數據集重複了指定次數,上面代碼將數據集重複了2次,因此此次即便for循環次數是4也依舊能正常讀取數據,而且都能完整把數據讀取出來。同理,若是把for循環次數設置爲大於4,那麼也仍是會報錯,這麼一來,我每次還得算repeat的次數,豈不是很心累?因此更簡便的辦法就是對repeat方法不設置重複次數,效果見以下:it
dataset = tf.data.Dataset.range(10).batch(6) dataset = dataset.repeat() iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: for i in range(6): value = sess.run(next_element) print(value)
輸出結果:
[0 1 2 3 4 5] [6 7 8 9] [0 1 2 3 4 5] [6 7 8 9] [0 1 2 3 4 5] [6 7 8 9]
此時不管for循環多少次都不怕啦~~
仔細看能夠知道上面全部輸出結果都是有序的,這在機器學習中用來訓練模型是浪費資源且沒有意義的,因此咱們須要將數據打亂,這樣每批次訓練的時候所用到的數據集是不同的,這樣啊能夠提升模型訓練效果。
另外shuffle前須要設置buffer_size:
dataset = tf.data.Dataset.range(10).shuffle(2).batch(6) dataset = dataset.repeat(2) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: for i in range(4): value = sess.run(next_element) print(value)
輸出結果:
[1 0 2 4 3 5] [7 8 9 6] [1 2 3 4 0 6] [7 8 9 5]
注意:shuffle的順序很重要,通常建議是最開始執行shuffle操做,由於若是是先執行batch操做的話,那麼此時就只是對batch進行shuffle,而batch裏面的數據順序依舊是有序的,那麼隨機程度會減弱。不信你看:
dataset = tf.data.Dataset.range(10).batch(6).shuffle(10) dataset = dataset.repeat(2) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: for i in range(4): value = sess.run(next_element) print(value)
輸出結果:
[0 1 2 3 4 5] [6 7 8 9] [0 1 2 3 4 5] [6 7 8 9]