TensorFlow 讀取 CSV 文件（轉載）

常用的直接讀取方法實例：
# Load packages
import os

import tensorflow as tf

# Set the working directory: replace the placeholder with the directory
# that contains Iris-train.csv and Iris-test.csv.
os.chdir("你本身的目錄")
# Show the current working directory.
print(os.getcwd())


def read_data(file_queue):
    """Read and decode one line of the Iris CSV file from *file_queue*.

    Args:
        file_queue: A filename queue, e.g. one returned by
            ``tf.train.string_input_producer``.

    Returns:
        A ``(features, label)`` pair. ``features`` stacks the four
        measurement columns (sepal length/width, petal length/width) into a
        float32 tensor of shape ``(4,)``; ``label`` is an int32 scalar:
        0 for Iris-setosa, 1 for Iris-versicolor, 2 for Iris-virginica and
        -1 for any other species string.
    """
    # skip_header_lines=1 drops the CSV header row.
    reader = tf.TextLineReader(skip_header_lines=1)
    # reader.read returns (key, value); only the line text is needed.
    _, value = reader.read(file_queue)

    # Per-column defaults; they also fix each column's dtype:
    # int Id, four float measurements, string Species.
    defaults = [[0], [0.], [0.], [0.], [0.], ['']]
    (_row_id, sepal_length, sepal_width,
     petal_length, petal_width, species) = tf.decode_csv(value, defaults)

    # Map the species name to an integer class id.  An ordered list of
    # (predicate, fn) pairs is used instead of a dict: dict input to
    # tf.case is deprecated and is not accepted by TF 2's tf.case.
    label = tf.case(
        [
            (tf.equal(species, tf.constant('Iris-setosa')),
             lambda: tf.constant(0)),
            (tf.equal(species, tf.constant('Iris-versicolor')),
             lambda: tf.constant(1)),
            (tf.equal(species, tf.constant('Iris-virginica')),
             lambda: tf.constant(2)),
        ],
        default=lambda: tf.constant(-1),
        exclusive=True,
    )

    # Stack the four measurements into a single feature vector.
    features = tf.stack([sepal_length, sepal_width, petal_length, petal_width])
    return features, label


def create_pipeline(filename, batch_size, num_epochs=None,
                    min_after_dequeue=1000):
    """Build a shuffled, batched input pipeline over one Iris CSV file.

    Args:
        filename: Path of the CSV file to read.
        batch_size: Number of examples per batch.
        num_epochs: Number of passes over the file before the queue raises
            ``tf.errors.OutOfRangeError``; ``None`` cycles forever.  A
            non-None value creates a local epoch counter, so run
            ``tf.local_variables_initializer()`` before starting the queue
            runners.
        min_after_dequeue: Minimum number of elements left in the shuffle
            queue after a dequeue.  Larger values shuffle better but use
            more memory and delay the first batch.  Defaults to 1000.

    Returns:
        ``(example_batch, label_batch)``: symbolic tensors of shape
        ``(batch_size, 4)`` (float32) and ``(batch_size,)`` (int32).
    """
    file_queue = tf.train.string_input_producer([filename],
                                                num_epochs=num_epochs)
    example, label = read_data(file_queue)
    # The queue capacity must exceed min_after_dequeue so that a full
    # batch can still be dequeued.
    capacity = min_after_dequeue + batch_size
    example_batch, label_batch = tf.train.shuffle_batch(
        [example, label],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
    )
    return example_batch, label_batch


x_train_batch, y_train_batch = create_pipeline('Iris-train.csv', 50,
                                               num_epochs=1000)
x_test, y_test = create_pipeline('Iris-test.csv', 60)
# These are symbolic tensors only; actual rows are read after the queue
# runners are started inside a tf.Session (tf.train.start_queue_runners).
print(x_train_batch, y_train_batch)

 

結果:
Tensor("shuffle_batch_2:0", shape=(50, 4), dtype=float32) Tensor("shuffle_batch_2:1", shape=(50,), dtype=int32)

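從輸出張量的形狀可知，讀取管道已正確構建，數據可以讀入（每批 50 條樣本、每條 4 個特徵）。注意這些只是符號張量：實際數據要在 Session 中執行 `tf.train.start_queue_runners()` 啟動隊列後才會真正被讀入。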

一個完整的例子見 GitHub：https://github.com/zhangdm/machine-learning-summary/tree/master/tensorflow/tensorflow_iris_nn