TensorFlow TFRecord: writing and reading

  1 #  write in tfrecord
  2 import tensorflow as tf
  3 import os
  4 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  5 
  6 
  7 FLAGS = tf.app.flags.FLAGS
  8 tf.app.flags.DEFINE_string("tfrecords_dir", "./tfrecords/captcha.tfrecords", "驗證碼tfrecords文件")
  9 tf.app.flags.DEFINE_string("captcha_dir", "../data/Genpics/", "驗證碼圖片路徑")
 10 tf.app.flags.DEFINE_string("letter", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "驗證碼字符的種類")
 11 
 12 
 13 def dealwithlabel(label_str):
 14 
 15     # 構建字符索引 {0:'A', 1:'B'......}
 16     num_letter = dict(enumerate(list(FLAGS.letter)))
 17 
 18     # 鍵值對反轉 {'A':0, 'B':1......}
 19     letter_num = dict(zip(num_letter.values(), num_letter.keys()))
 20 
 21     print(letter_num)
 22 
 23     # 構建標籤的列表
 24     array = []
 25 
 26     # 給標籤數據進行處理[[b"NZPP"]......]
 27     for string in label_str:
 28 
 29         letter_list = []# [1,2,3,4]
 30 
 31         # 修改編碼,bytes --> string
 32         for letter in string.decode('utf-8'):
 33             letter_list.append(letter_num[letter])
 34 
 35         array.append(letter_list)
 36 
 37     # [[13, 25, 15, 15], [22, 10, 7, 10], [22, 15, 18, 9], [16, 6, 13, 10], [1, 0, 8, 17], [0, 9, 24, 14].....]
 38     print(array)
 39 
 40     # 將array轉換成tensor類型
 41     label = tf.constant(array)
 42 
 43     return label
 44 
 45 
 46 def get_captcha_image():
 47     """
 48     獲取驗證碼圖片數據
 49     :param file_list: 路徑+文件名列表
 50     :return: image
 51     """
 52     # 構造文件名
 53     filename = []
 54 
 55     for i in range(6000):
 56         string = str(i) + ".jpg"
 57         filename.append(string)
 58 
 59     # 構造路徑+文件
 60     file_list = [os.path.join(FLAGS.captcha_dir, file) for file in filename]
 61 
 62     # 構造文件隊列
 63     file_queue = tf.train.string_input_producer(file_list, shuffle=False)
 64 
 65     # 構造閱讀器
 66     reader = tf.WholeFileReader()
 67 
 68     # 讀取圖片數據內容
 69     key, value = reader.read(file_queue)
 70 
 71     # 解碼圖片數據
 72     image = tf.image.decode_jpeg(value)
 73 
 74     image.set_shape([20, 80, 3])
 75 
 76     # 批處理數據 [6000, 20, 80, 3]
 77     image_batch = tf.train.batch([image], batch_size=6000, num_threads=1, capacity=6000)
 78 
 79     return image_batch
 80 
 81 
 82 def get_captcha_label():
 83     """
 84     讀取驗證碼圖片標籤數據
 85     :return: label
 86     """
 87     file_queue = tf.train.string_input_producer(["../data/Genpics/labels.csv"], shuffle=False)
 88 
 89     reader = tf.TextLineReader()
 90 
 91     key, value = reader.read(file_queue)
 92 
 93     records = [[1], ["None"]]
 94 
 95     number, label = tf.decode_csv(value, record_defaults=records)
 96 
 97     # [["NZPP"], ["WKHK"], ["ASDY"]]
 98     label_batch = tf.train.batch([label], batch_size=6000, num_threads=1, capacity=6000)
 99 
100     return label_batch
101 
102 
def write_to_tfrecords(image_batch, label_batch):
    """Write captcha images and labels to the ``FLAGS.tfrecords_dir`` file.

    Each sample becomes one ``tf.train.Example`` with two bytes features:
    ``image`` (the raw image array bytes) and ``label`` (the label indices
    cast to uint8, as raw bytes).

    Must be called while a default session is active (e.g. inside
    ``with tf.Session() as sess:``), with queue runners started, because
    ``image_batch`` is produced by an input queue.

    :param image_batch: image tensor, one image per row
    :param label_batch: tensor of label indices, one label per row
    :return: None
    :raises ValueError: if there is no default session, or the two batches
        have different lengths.
    """
    sess = tf.get_default_session()
    if sess is None:
        raise ValueError("write_to_tfrecords needs an active default session")

    # Fetch both batches with a single run. Evaluating image_batch[i] per
    # sample would re-run the whole batch dequeue (re-reading every image)
    # and add new ops to the graph on each iteration.
    images, labels = sess.run([image_batch, tf.cast(label_batch, tf.uint8)])
    if len(images) != len(labels):
        raise ValueError("got %d images but %d labels" % (len(images), len(labels)))

    # TFRecordWriter does not create missing parent directories.
    out_dir = os.path.dirname(FLAGS.tfrecords_dir)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # The context manager closes the file even if a write fails.
    with tf.python_io.TFRecordWriter(FLAGS.tfrecords_dir) as writer:
        for image, label in zip(images, labels):
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label.tobytes()])),
            }))
            writer.write(example.SerializeToString())

    return None
138 
139 
if __name__ == "__main__":

    # Build the input pipelines (graph ops only; nothing is read yet).
    image_batch = get_captcha_image()
    label = get_captcha_label()

    print(image_batch, label)

    with tf.Session() as sess:

        coord = tf.train.Coordinator()

        # Start the threads that fill the queues behind the two batches.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            # Fetch the label strings, then convert them to index tensors.
            label_str = sess.run(label)
            print(label_str)

            label_batch = dealwithlabel(label_str)
            print(label_batch)

            write_to_tfrecords(image_batch, label_batch)
        except Exception as e:
            # Report the error so coord.join() below re-raises it after the
            # queue threads have stopped.
            coord.request_stop(e)
        finally:
            # Always stop and join the queue threads, even on failure.
            coord.request_stop()
            coord.join(threads)
 1 # read tfrecords
 2 def read_and_decode():
 3     """
 4     讀取驗證碼數據API
 5     :return: image_batch, label_batch
 6     """
 7     # 一、構建文件隊列
 8     file_queue = tf.train.string_input_producer([FLAGS.captcha_dir])
 9 
10     # 二、構建閱讀器,讀取文件內容,默認一個樣本
11     reader = tf.TFRecordReader()
12 
13     # 讀取內容
14     key, value = reader.read(file_queue)
15 
16     # tfrecords格式example,須要解析
17     features = tf.parse_single_example(value, features={
18         "image": tf.FixedLenFeature([], tf.string),
19         "label": tf.FixedLenFeature([], tf.string),
20     })
21 
22     # 解碼內容,字符串內容
23     # 一、先解析圖片的特徵值
24     image = tf.decode_raw(features["image"], tf.uint8)
25     # 一、先解析圖片的目標值
26     label = tf.decode_raw(features["label"], tf.uint8)
27 
28     # print(image, label)
29 
30     # 改變形狀
31     image_reshape = tf.reshape(image, [20, 80, 3])
32 
33     label_reshape = tf.reshape(label, [4])
34 
35     print(image_reshape, label_reshape)
36 
37     # 進行批處理,每批次讀取的樣本數 100, 也就是每次訓練時候的樣本
38     image_batch, label_btach = tf.train.batch([image_reshape, label_reshape], batch_size=FLAGS.batch_size, num_threads=1, capacity=FLAGS.batch_size)
39 
40     print(image_batch, label_btach)
41     return image_batch, label_btach

# NOTE: the two groups below belong to two SEPARATE scripts (the TFRecords
# writer and the reader). They cannot be defined in one module: "captcha_dir"
# is defined in both groups with different meanings (image directory vs.
# TFRecords file), and tf.app.flags (absl) raises a duplicate-flag error when
# the same flag name is defined twice.

# write flags
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("tfrecords_dir", "./tfrecords/captcha.tfrecords", "驗證碼tfrecords文件")
tf.app.flags.DEFINE_string("captcha_dir", "../data/Genpics/", "驗證碼圖片路徑")
tf.app.flags.DEFINE_string("letter", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "驗證碼字符的種類")
# read flags
# Path of the TFRecords file to read.
tf.app.flags.DEFINE_string("captcha_dir", "./tfrecords/captcha.tfrecords", "驗證碼數據的路徑")
# Number of samples per training batch.
tf.app.flags.DEFINE_integer("batch_size", 100, "每批次訓練的樣本數")
# Number of characters (target values) per captcha label.
tf.app.flags.DEFINE_integer("label_num", 4, "每一個樣本的目標值數量")
# Number of possible letters for each character position.
tf.app.flags.DEFINE_integer("letter_num", 26, "每一個目標值取的字母的可能性個數")
相關文章
相關標籤/搜索