Tensorflow數據讀取機制

展現如何將數據輸入到計算圖中python

Dataset能夠看做是相同類型「元素」的有序列表,在實際使用時,單個元素能夠是向量、字符串、圖片甚至是tuple或dict。數組

數據集對象實例化:dom

dataset=tf.data.Dataset.from_tensor_slice(<data>)

迭代器對象實例化:機器學習

iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()

讀取結束異常:若是一個dataset中的元素被讀取完畢,再嘗試sess.run(one_element)的話,會拋出tf.errors.OutOfRangeError異常,這個行爲與使用隊列方式讀取數據是一致的。函數

高維數據集的使用

tf.data.Dataset.from_tensor_slices真正做用是切分傳入Tensor的第一個維度,生成相應的dataset,即第一維代表數據集中數據的數量,以後切分batch等操做均以第一維爲基礎。學習

dataset=tf.data.Dataset.from_tensor_slices(np.random.uniform((5,2)))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session(config=config) as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError as e:
        print('end~')

輸出:3d

[0.1,0.2]
[0.3,0.2]
[0.1,0.6]
[0.4,0.3]
[0.5,0.2]

tuple組合數據

dataset=tf.data.Dataset.from_tensor_slices((np.array([1.,2.,3.,4.,5.]),
                                            np.random.uniform(size=(5,2))))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError:
        print('end~')

輸出:code

(1.,array(0.1,0.3))
(2.,array(0.2,0.4))
...

數據集處理方法

Dataset支持一類特殊操做:Transformation。一個Dataset經過Transformation變成一個新的Dataset。經常使用的Transformationorm

  • map
  • batch
  • shuffle
  • repeat

其中,對象

  • map和python中的map一致,接受一個函數,Dataset中的每一個元素都會做爲這個函數的輸入,並將函數返回值做爲新的Dataset

    dataset=dataset.map(lambda x:x+1)

    注意:map函數可使用num_parallel_calls參數並行化

  • batch就是將多個元素組成batch。

    dataset=tf.data.Dataset.from_tensor_slices(
    {
        'a':np.array([1.,2.,3.,4.,5.]),
        'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.batch(2)  # batch_size=2
    ###
    iterator=dataset.make_one_shot_iterator()
    one_element=iterator.get_next()
    with tf.Session() as sess:
        try:
            while True:
                print(one_element)
        except tf.errors.OutOfRangeError:
            print('end~')

    輸出:

    {'a':array([1.,2.]),'b':array([[1.,2.],[3.,4.]])}
    {'a':array([3.,4.]),'b':array([[5.,6.],[7.,8.]])}
  • shuffle的功能是打亂dataset中的元素,它有個參數buffer_size,表示打亂時使用的buffer的大小,不該設置太小,推薦值1000.

    dataset=tf.data.Dataset.from_tensor_slices(
    {
        'a':np.array([1.,2.,3.,4.,5.]),
        'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.shuffle(buffer_size=5)
    ###
    iterator=dataset.make_one_shot_iterator()
    one_element=iterator.get_next()
    with tf.Session() as sess:
        try:
            while True:
                print(one_element)
        except tf.errors.OutOfRangeError:
            print('end~')
  • repeat的功能就是將整個序列重複屢次,主要用來處理機器學習中的epoch。假設原先的數據是一個epoch,使用repeat(2)可使之變成2個epoch.

    dataset=tf.data.Dataset.from_tensor_slices({
        'a':np.array([1.,2.,3.,4.,5.]),
        'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.repeat(2)  # 2epoch
    ###
    # iterator, one_element...

    注意:若是直接調用repeat()函數的話,生成的序列會無限重複下去,沒有結果,所以不會拋出tf.errors.OutOfRangeError異常。

模擬讀入磁盤圖片及其Label示例

def _parse_function(filename,label):  # 接受單個元素,轉換爲目標
    img_string=tf.read_file(filename)
    img_decoded=tf.image.decode_images(img_string)
    img_resized=tf.image.resize_images(image_decoded,[28,28])
    return image_resized,label

filenames=tf.constant(['data/img1.jpg','data/img2.jpg',...])
labels=tf.constant([1,3,...])
dataset=tf.data.Dataset.from_tensor_slices((filenames,labels))
dataset=dataset.map(_parse_function)  # num_parallel_calls 並行
dataset=dataset.shuffle(buffer_size=1000).batch_size(32).repeat(10)

更多Dataset建立方法

  • tf.data.TextLineDataset():函數輸入一個文件列表,輸出一個Dataset。dataset中的每個元素對應文件中的一行,可使用該方法讀入csv文件。
  • tf.data.FixedLengthRecordDataset():函數輸入一個文件列表和record_bytes參數,dataset中每個元素是文件中固定字節數record_bytes的內容,可用來讀取二進制保存的文件,如CIFAR10。
  • tf.data.TFRecordDataset():讀取TFRecord文件,dataset中每個元素是一個TFExample。

更多Iterator建立方法

最簡單的建立Iterator方法是經過dataset.make_one_shot_iterator()建立一個iterator。

除了這種iterator以外,還有更復雜的Iterator:

  • initializable iterator
  • reinitializable iterator
  • feedable iterator

其中,initializable iterator方法要在使用前經過sess.run()進行初始化,initializable iterator還可用於讀入較大數組。在使用tf.data.Dataset.from_tensor_slices(array)時,實際上發生的事情是將array做爲一個tf.constants保存到了計算圖中,當array很大時,會致使計算圖變得很大,給傳輸保存帶來不便,這時可使用一個placeholder取代這裏的array,並使用initializable iterator,只在須要時將array傳進去,這樣便可避免將大數組保存在圖裏。

features_placeholder=tf.placeholder(<features.dtype>,<features.shape>)
labels_placeholder=tf.placeholder(<labels.dtype>,<labels.shape>)
dataset=tf.data.Dataset.from_tensor_slices((features_placeholder,labels_placeholder))
iterator=dataset.make_initializable_iterator()
next_element=iterator.get_next()
sess.run(iterator.initializer,feed_dict={features_placeholder:features,labels_placeholder:labels})

Tensorflow內部讀取機制

對於文件名隊列,使用tf.train.string_input_producer()函數,tf.train.string_input_producer()還有兩個重要參數,num_epochesshuffle

內存隊列不須要咱們創建,只須要使用reader對象從文件名隊列中讀取數據便可,使用tf.train.start_queue_runners()函數啓動隊列,填充兩個隊列的數據。

with tf.Session() as sess:
    filenames=['A.jpg','B.jpg','C.jpg']
    filename_queue=tf.train.string_input_producer(filenames,shuffle=True,num_epoch=5)
    reader=tf.WholeFileReader()
    key,value=reader.read(filename_queue)
    # tf.train.string_input_producer()定義了一個epoch變量,須要對其進行初始化
    tf.local_variables_initializer().run()
    threads=tf.train.start_queue_runners(sess=sess)
    i=0
    while True:
        i+=1
        image_data=sess.run(value)
        with open('reader/test_%d.jpg'%i,'wb') as f:
            f.write(image_data)
相關文章
相關標籤/搜索