tensorflow 讀取csv文件

1.建立文件列表多線程

path='./data/csvdata/'
file_names=os.listdir(path)
file_list=[os.path.join(path,file_name)for file_name in file_names]

2.讀文件文件列表到文件隊列線程

file_queue=tf.train.string_input_producer(file_list)

3.構建文件閱讀器讀取隊列文件code

reader=tf.TextLineReader()
key,value=reader.read(file_queue)

4.解碼隊列

recodes=[['None'],['None']]
example,label=tf.decoder_csv(value,record_defaults=records)

5.批處理圖片

tf.train.batch([example,label],batch_size=9,num_threads=2,capacity=100)

6.開啓tf會話多線程進行處理ci

coord=tf.train.Coordinator()
threads=tf.start_queue_runners(sess,coord=coord)
coord.request_stop()
coord.join(threads)

完整代碼input

import tensorflow as  tf
import os

'''
 1.構建文件隊列
 2.讀取隊列內容,,默認讀取一個樣本
    1.csv文件,讀取一行
    2.二進制文件,指定一個樣本的bytes讀取
    3.圖片文件,默認讀取一張一張讀取
 3.解碼
 4.批處理讀取文件
 5.主線程取樣本數據訓練
'''


def csv_read(file_list):
    # 1.構造文件隊列
    file_queue = tf.train.string_input_producer(file_list)
    # 2.構造閱讀器,讀取文件
    reader = tf.TextLineReader()
    key, value = reader.read(file_queue)
    # 3.進行文件解碼
    record = [["None"], ["None"]]
    example, label = tf.decode_csv(value, record_defaults=record)
    # 4.批處理
    batch_example, batch_label = tf.train.batch([example, label], batch_size=9, num_threads=1, capacity=9)
    return batch_example, batch_label


if __name__ == '__main__':
    file_names = os.listdir("./data/csvdata")

    file_list = [os.path.join('./data/csvdata', file) for file in file_names]
    # print(file_list)
    example, label = csv_read(file_list)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)
        print(sess.run([example, label]))
        coord.request_stop()
        coord.join(threads)
相關文章
相關標籤/搜索