tensorflow 批次讀取文件內的數據,並將順序隨機化處理. --[python]

使用tensorflow批次的讀取預處理以後的文本數據,並將其分爲一個迭代器批次:python

好比此刻,我有一個處理以後的數據包: data.csv  shape =(8,10),其中這個結構中,前五個列爲feature , 後五列爲labelspa

1,2,3,4,5,6,7,8,9,10
11,12,13,14,15,16,17,18,19,20
21,22,23,24,25,26,27,28,29,30
31,32,33,34,35,36,37,38,39,40
41,42,43,44,45,46,47,48,49,50
51,52,53,54,55,56,57,58,59,60
1,1,1,1,1,2,2,2,2,2
3,3,3,3,3,4,4,4,4,4

如今我須要將其分爲4個批次: 也就是每一個批次batch的大小爲2code

而後我可能須要將其順序打亂,因此這裏提供了兩種方式,順序和隨機blog

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'xijun1'
import tensorflow as tf
import numpy as np

# data = np.arange(1, 100 + 1)
# print ",".join( [str(i) for i in data])
# data_input = tf.constant(data)
filename_queue = tf.train.string_input_producer(["data.csv"])
reader = tf.TextLineReader(skip_header_lines=0)
key, value = reader.read(filename_queue)
# decode_csv will convert a Tensor from type string (the text line) in
# a tuple of tensor columns with the specified defaults, which also
# sets the data type for each column
words_size = 5  # 每一行數據的長度
decoded = tf.decode_csv(
    value,
    field_delim=',',
    record_defaults=[[0] for i in range(words_size * 2)])

batch_size = 2 # 每個批次的大小
# 隨機
batch_shuffle = tf.train.shuffle_batch(decoded, batch_size=batch_size,
                                       capacity=batch_size * words_size,
                                       min_after_dequeue=batch_size)
#順序
batch_no_shuffle = tf.train.batch(decoded, batch_size=batch_size, capacity=batch_size * words_size,
                                  allow_smaller_final_batch=batch_size)
shuffle_features = tf.transpose(tf.stack(batch_shuffle[0:words_size]))
shuffle_label = tf.transpose(tf.stack(batch_shuffle[words_size:]))
features = tf.transpose(tf.stack(batch_no_shuffle[0:words_size]))
label = tf.transpose(tf.stack(batch_no_shuffle[words_size:]))

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(8/batch_size):
        print (i+10, sess.run([shuffle_features, shuffle_label]))
        print (i, sess.run([features, label]))
    coord.request_stop()
    coord.join(threads)

當咱們運行的時候,咱們能夠獲得這個結果:ip

(10, [array([[ 1,  2,  3,  4,  5],
       [31, 32, 33, 34, 35]], dtype=int32), array([[ 6,  7,  8,  9, 10],
       [36, 37, 38, 39, 40]], dtype=int32)])
(0, [array([[11, 12, 13, 14, 15],
       [21, 22, 23, 24, 25]], dtype=int32), array([[16, 17, 18, 19, 20],
       [26, 27, 28, 29, 30]], dtype=int32)])
(11, [array([[51, 52, 53, 54, 55],
       [ 3,  3,  3,  3,  3]], dtype=int32), array([[56, 57, 58, 59, 60],
       [ 4,  4,  4,  4,  4]], dtype=int32)])
(1, [array([[41, 42, 43, 44, 45],
       [ 1,  1,  1,  1,  1]], dtype=int32), array([[46, 47, 48, 49, 50],
       [ 2,  2,  2,  2,  2]], dtype=int32)])
(12, [array([[ 3,  3,  3,  3,  3],
       [11, 12, 13, 14, 15]], dtype=int32), array([[ 4,  4,  4,  4,  4],
       [16, 17, 18, 19, 20]], dtype=int32)])
(2, [array([[ 1,  2,  3,  4,  5],
       [21, 22, 23, 24, 25]], dtype=int32), array([[ 6,  7,  8,  9, 10],
       [26, 27, 28, 29, 30]], dtype=int32)])
(13, [array([[31, 32, 33, 34, 35],
       [ 1,  1,  1,  1,  1]], dtype=int32), array([[36, 37, 38, 39, 40],
       [ 2,  2,  2,  2,  2]], dtype=int32)])
(3, [array([[41, 42, 43, 44, 45],
       [ 1,  1,  1,  1,  1]], dtype=int32), array([[46, 47, 48, 49, 50],
       [ 2,  2,  2,  2,  2]], dtype=int32)])
相關文章
相關標籤/搜索