Tensorflow datasets.shuffle repeat batch方法

機器學習中數據讀取是很重要的一個環節,TensorFlow也提供了不少實用的方法,爲了不之後時間久了又忘記,因此寫下筆記以備往後查看。python

最普通的正常狀況

首先咱們看看最普通的狀況:session

# 建立0-10的數據集,每一個batch取個數。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

輸出結果機器學習

[0 1 2 3 4 5]
[6 7 8 9]

由結果咱們能夠知道TensorFlow能很好地幫咱們自動處理最後一個batch的數據。學習

datasets.batch(batch_size)與迭代次數的關係

可是若是上面for循環次數超過2會怎麼樣呢?也就是說若是 **循環次數*批數量 > 數據集數量** 會怎麼樣?咱們試試看:spa

dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    >>==for i in range(3):==<<
        value = sess.run(next_element)
        print(value)

輸出結果code

[0 1 2 3 4 5]
[6 7 8 9]
---------------------------------------------------------------------------
OutOfRangeError                           Traceback (most recent call last)
D:\Continuum\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
   1277     try:
   
  ...
  ...省略若干信息...
  ...
  
OutOfRangeError (see above for traceback): End of sequence
     [[Node: IteratorGetNext_64 = IteratorGetNext[output_shapes=[[?]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_28)]]

能夠知道超過範圍了,因此報錯了。element

datasets.repeat()

爲了解決上述問題,repeat方法登場。仍是直接看例子吧:資源

dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

輸出結果get

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

能夠知道repeat其實就是將數據集重複了指定次數,上面代碼將數據集重複了2次,因此此次即便for循環次數是4也依舊能正常讀取數據,而且都能完整把數據讀取出來。同理,若是把for循環次數設置爲大於4,那麼也仍是會報錯,這麼一來,我每次還得算repeat的次數,豈不是很心累?因此更簡便的辦法就是對repeat方法不設置重複次數,效果見以下:it

dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(6):
        value = sess.run(next_element)
        print(value)

輸出結果:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

此時不管for循環多少次都不怕啦~~

datasets.shuffle(buffer_size)

仔細看能夠知道上面全部輸出結果都是有序的,這在機器學習中用來訓練模型是浪費資源且沒有意義的,因此咱們須要將數據打亂,這樣每批次訓練的時候所用到的數據集是不同的,這樣啊能夠提升模型訓練效果。

另外shuffle前須要設置buffer_size:

  • 不設置會報錯,
  • buffer_size=1:不打亂順序,既保持原序
  • buffer_size越大,打亂程度越大,演示效果見以下代碼:
dataset = tf.data.Dataset.range(10).shuffle(2).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

輸出結果:

[1 0 2 4 3 5]
[7 8 9 6]
[1 2 3 4 0 6]
[7 8 9 5]

注意:shuffle的順序很重要,通常建議是最開始執行shuffle操做,由於若是是先執行batch操做的話,那麼此時就只是對batch進行shuffle,而batch裏面的數據順序依舊是有序的,那麼隨機程度會減弱。不信你看:

dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

輸出結果:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]



MARSGGBO原創




2018-8-5

相關文章
相關標籤/搜索