TensorFlow:使用 tf.data.TFRecordDataset() 讀取 TFRecord 文件

許多輸入管道都從 TFRecord 格式的文件中提取 tf.train.Example 協議緩衝區消息(例如,這種文件使用 tf.python_io.TFRecordWriter 編寫而成)。每一個 tf.train.Example 記錄都包含一個或多個「特徵」,輸入管道一般會將這些特徵轉換爲張量。
def input_layer():
    """Build a re-initializable input pipeline over TFRecord files.

    The pipeline reads serialized records from the files fed into the
    returned placeholder, parses each one with ``_parse_example`` and
    iterates over them once (no ``repeat()``).

    Returns:
        A 3-tuple ``(filenames, iterator, next_element)`` where
        ``filenames`` is a 1-D ``tf.string`` placeholder that must be fed the
        list of TFRecord paths when running ``iterator.initializer``,
        ``iterator`` is the initializable iterator over the parsed dataset,
        and ``next_element`` is ``iterator.get_next()`` — the tuple of
        tensors that ``_parse_example`` produces for one record.
    """
    # The file list is a placeholder so the same graph can be re-initialized
    # against different TFRecord files without being rebuilt.
    filename_ph = tf.placeholder(tf.string, shape=[None])
    parsed = tf.data.TFRecordDataset(filename_ph).map(_parse_example)
    it = parsed.make_initializable_iterator()
    next_element = it.get_next()
    return filename_ph, it, next_element
複製代碼
def _parse_example(example):
    """Parse one serialized ``tf.train.Example`` into an image triplet.

    The record is expected to hold three scalar string features —
    ``'img_query'``, ``'img_positive'`` and ``'img_negative'`` — each
    containing JPEG-encoded bytes.

    Args:
        example: scalar string tensor holding one serialized
            ``tf.train.Example``.

    Returns:
        Tuple ``(img_query, img_positive, img_negative)`` of float32 image
        tensors whose uint8 pixel values have been rescaled to [0.0, 1.0].
    """
    feature_names = ('img_query', 'img_positive', 'img_negative')
    # One fixed-length (scalar) string feature per image.
    keys_to_feature = {
        name: tf.FixedLenFeature((), tf.string) for name in feature_names
    }
    feat_tensor_maps = tf.parse_single_example(example, keys_to_feature)

    def _process_img(img_bytes):
        """Decode JPEG bytes and rescale uint8 pixels to float32 in [0, 1]."""
        # channels is left at its default (0: keep the file's own channel
        # count), so the channel dimension is not known statically and
        # grayscale files yield a single channel.
        img = tf.image.decode_jpeg(img_bytes)
        # True division replaces the deprecated tf.div; for float32
        # operands the result is identical.
        return tf.cast(img, tf.float32) / 255.0

    # Decoded in the same fixed order the caller unpacks:
    # (query, positive, negative).
    return tuple(_process_img(feat_tensor_maps[name]) for name in feature_names)
複製代碼
# NOTE(review): neither `tf` nor `plt` is imported anywhere in this file; the
# code below assumes `import tensorflow as tf` (TF 1.x graph-mode API) and
# `import matplotlib.pyplot as plt` are in scope — add them if missing.
tfrecord_path = '/media/ubuntu/FED8DCB6D8DC6E81/stuff/deep_ranking_tfrecord/train.record'
filenames_tensor, iterator, ele_tensor = input_layer()
with tf.Session() as sess:
    # The filenames placeholder only feeds the iterator's initializer; once
    # initialized, fetching the next element needs no feed.
    sess.run(iterator.initializer, feed_dict={filenames_tensor: [tfrecord_path]})
    img_eval = sess.run(ele_tensor)
img_query, img_positive, img_negative = img_eval

# Show the first (query, positive, negative) triplet, one figure per image.
for img in (img_query, img_positive, img_negative):
    # decode_jpeg keeps the file's channel count, so a grayscale JPEG arrives
    # as (H, W, 1); drop that axis because some matplotlib versions reject it.
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    plt.imshow(img)
    plt.show()
複製代碼