TensorFlow 的 How-Tos,講解了這麼幾點:python
1. 變量:建立,初始化,保存,加載,共享;api
2. TensorFlow 的可視化學習,(r0.12版本後,加入了Embedding Visualization)網絡
3. 數據的讀取;分佈式
4. 線程和隊列;ide
5. 分佈式的TensorFlow;函數
6. 增長新的Ops;post
7. 自定義數據讀取;學習
因爲各類緣由,本人只看了前5個部分,剩下的2個部分還沒來得及看,時間緊任務重,因此匆匆發車了,之後若是有用到的地方,再回過頭來研究。學習過程當中深感官方文檔的繁雜冗餘極多多,特別是第三部分數據讀取,又臭又長,花了我很久時間,因此我想把第三部分整理以下,方便乘客們。ui
TensorFlow 有三種方法讀取數據:1)供給數據,用placeholder;2)從文件讀取;3)用常量或者是變量來預加載數據,適用於數據規模比較小的狀況。供給數據沒什麼好說的,前面已經見過了,不難理解,咱們就簡單的說一下從文件讀取數據。spa
官方的文檔裏,從文件讀取數據是一段很長的描述,連接層出不窮,看完這個連接還沒看幾個字,就出現了下一個連接。
本身花了好久才認識路,因此想把這部分總結一下,帶帶個人乘客們。
首先要知道你要讀取的文件的格式,選擇對應的文件讀取器;
而後,定位到數據文件夾下,用
["file0", "file1"] # or [("file%d" % i) for i in range(2)]) # or tf.train.match_filenames_once
選擇要讀取的文件的名字,用 tf.train.string_input_producer 函數來生成文件名隊列,這個函數能夠設置shuffle = Ture,來打亂隊列,能夠設置epoch = 5,過5遍訓練數據。
最後,選擇的文件讀取器,讀取文件名隊列並解碼,輸入 tf.train.shuffle_batch 函數中,生成 batch 隊列,傳遞給下一層。
1)假如你要讀取的文件是像 CSV 那樣的文本文件,用的文件讀取器和解碼器就是 TextLineReader 和 decode_csv 。
2)假如你要讀取的數據是像 cifar10 那樣的 .bin 格式的二進制文件,就用 tf.FixedLengthRecordReader 和 tf.decode_raw 讀取固定長度的文件讀取器和解碼器。以下列出了個人參考代碼,後面會有詳細的解釋,這邊先大體瞭解一下:
class cifar10_data(object): def __init__(self, filename_queue): self.height = 32 self.width = 32 self.depth = 3 self.label_bytes = 1 self.image_bytes = self.height * self.width * self.depth self.record_bytes = self.label_bytes + self.image_bytes self.label, self.image = self.read_cifar10(filename_queue) def read_cifar10(self, filename_queue): reader = tf.FixedLengthRecordReader(record_bytes = self.record_bytes) key, value = reader.read(filename_queue) record_bytes = tf.decode_raw(value, tf.uint8) label = tf.cast(tf.slice(record_bytes, [0], [self.label_bytes]), tf.int32) image_raw = tf.slice(record_bytes, [self.label_bytes], [self.image_bytes]) image_raw = tf.reshape(image_raw, [self.depth, self.height, self.width]) image = tf.transpose(image_raw, (1,2,0)) image = tf.cast(image, tf.float32) return label, image def inputs(data_dir, batch_size, train = True, name = 'input'): with tf.name_scope(name): if train: filenames = [os.path.join(data_dir,'data_batch_%d.bin' % ii) for ii in range(1,6)] for f in filenames: if not tf.gfile.Exists(f): raise ValueError('Failed to find file: ' + f) filename_queue = tf.train.string_input_producer(filenames) read_input = cifar10_data(filename_queue) images = read_input.image images = tf.image.per_image_whitening(images) labels = read_input.label num_preprocess_threads = 16 image, label = tf.train.shuffle_batch( [images,labels], batch_size = batch_size, num_threads = num_preprocess_threads, min_after_dequeue = 20000, capacity = 20192) return image, tf.reshape(label, [batch_size]) else: filenames = [os.path.join(data_dir,'test_batch.bin')] for f in filenames: if not tf.gfile.Exists(f): raise ValueError('Failed to find file: ' + f) filename_queue = tf.train.string_input_producer(filenames) read_input = cifar10_data(filename_queue) images = read_input.image images = tf.image.per_image_whitening(images) labels = read_input.label num_preprocess_threads = 16 image, label = tf.train.shuffle_batch( [images,labels], batch_size = batch_size, num_threads = num_preprocess_threads, min_after_dequeue = 20000, capacity = 20192) return image, tf.reshape(label, [batch_size])
3)若是你要讀取的數據是圖片,或者是其餘類型的格式,那麼能夠先把數據轉換成 TensorFlow 的標準支持格式 tfrecords ,它實際上是一種二進制文件,經過修改 tf.train.Example 的Features,將 protocol buffer 序列化爲一個字符串,再經過 tf.python_io.TFRecordWriter 將序列化的字符串寫入 tfrecords,而後再用跟上面同樣的方式讀取tfrecords,只是讀取器變成了tf.TFRecordReader,以後經過一個解析器tf.parse_single_example ,而後用解碼器 tf.decode_raw 解碼。
例如,對於生成式對抗網絡GAN,我採用了這個形式進行輸入,部分代碼以下,後面會有詳細解釋,這邊先大體瞭解一下:
def _int64_feature(value): return tf.train.Feature(int64_list = tf.train.Int64List(value = [value])) def _bytes_feature(value): return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value])) def convert_to(data_path, name): """ Converts s dataset to tfrecords """ rows = 64 cols = 64 depth = DEPTH for ii in range(12): writer = tf.python_io.TFRecordWriter(name + str(ii) + '.tfrecords') for img_name in os.listdir(data_path)[ii*16384 : (ii+1)*16384]: img_path = data_path + img_name img = Image.open(img_path) h, w = img.size[:2] j, k = (h - OUTPUT_SIZE) / 2, (w - OUTPUT_SIZE) / 2 box = (j, k, j + OUTPUT_SIZE, k+ OUTPUT_SIZE) img = img.crop(box = box) img = img.resize((rows,cols)) img_raw = img.tobytes() example = tf.train.Example(features = tf.train.Features(feature = { 'height': _int64_feature(rows), 'weight': _int64_feature(cols), 'depth': _int64_feature(depth), 'image_raw': _bytes_feature(img_raw)})) writer.write(example.SerializeToString()) writer.close() def read_and_decode(filename_queue): """ read and decode tfrecords """ # filename_queue = tf.train.string_input_producer([filename_queue]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example,features = { 'image_raw':tf.FixedLenFeature([], tf.string)}) image = tf.decode_raw(features['image_raw'], tf.uint8) return image
這裏,個人data_path下面有16384*12張圖,經過12次寫入Example操做,把圖片數據轉化成了12個tfrecords,每一個tfrecords裏面有16384張圖。
4)若是想定義本身的讀取數據操做,請參考https://www.tensorflow.org/how_tos/new_data_formats/。
好了,今天的車到站了,請帶好隨身物品準備下車,明天老司機還有一趟車,請記得準時乘坐,車不等人。
參考文獻:
1. https://www.tensorflow.org/how_tos/
2. 沒了