Parsing TFRecord files with TensorFlow


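TensorFlow reads TFRecord files with tf.data.TFRecordDataset(). Many input pipelines extract tf.train.Example protocol buffer messages from files in the TFRecord format (written, for example, with tf.python_io.TFRecordWriter). Each tf.train.Example record contains one or more "features", and the input pipeline typically converts these features into tensors.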
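To make "files in the TFRecord format" more concrete, here is a minimal pure-Python sketch of the on-disk record framing as I understand it: each record is a little-endian uint64 length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. The helper names (crc32c, masked_crc, frame_record, iter_records) are made up for illustration and are not TensorFlow APIs. In a real file each payload would be a serialized tf.train.Example; plain byte strings stand in for it here.

```python
import struct

_MASK_DELTA = 0xA282EAD8  # constant used by the CRC masking step


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli). Slow, but needs no dependencies."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, modulo 2**32."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + _MASK_DELTA) & 0xFFFFFFFF


def frame_record(payload: bytes) -> bytes:
    """Wrap one payload as: length, crc(length), payload, crc(payload)."""
    header = struct.pack("<Q", len(payload))
    return (header + struct.pack("<I", masked_crc(header))
            + payload + struct.pack("<I", masked_crc(payload)))


def iter_records(blob: bytes):
    """Yield the payloads stored in a byte string of framed records."""
    offset = 0
    while offset < len(blob):
        header = blob[offset:offset + 8]
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", blob[offset + 8:offset + 12])
        if length_crc != masked_crc(header):
            raise ValueError("corrupt length field")
        start = offset + 12
        payload = blob[start:start + length]
        (data_crc,) = struct.unpack(
            "<I", blob[start + length:start + length + 4])
        if data_crc != masked_crc(payload):
            raise ValueError("corrupt payload")
        yield payload
        offset = start + length + 4


# Round trip: frame three stand-in payloads, then read them back.
blob = b"".join(frame_record(p) for p in (b"a", b"", b"hello"))
records = list(iter_records(blob))
```

In practice tf.data.TFRecordDataset handles this framing for you; the sketch only shows why each element it yields is a single serialized record that still has to be parsed.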

tf.data.TFRecordDataset()

import tensorflow as tf  # TensorFlow 1.x API (placeholders, sessions)

def input_layer():
    # Fed at run time with a list of TFRecord file paths.
    filenames = tf.placeholder(tf.string, shape=[None])
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_example)  # Parse each record into tensors.
    # dataset = dataset.repeat()  # Uncomment to repeat the input indefinitely.
    iterator = dataset.make_initializable_iterator()

    return filenames, iterator, iterator.get_next()

Parsing tf.Example protocol buffer messages

def _parse_example(example):
    # Each record stores three JPEG-encoded images as scalar byte strings.
    keys_to_feature = {'img_query': tf.FixedLenFeature((), tf.string),
                       'img_positive': tf.FixedLenFeature((), tf.string),
                       'img_negative': tf.FixedLenFeature((), tf.string)
                       }
    feat_tensor_maps = tf.parse_single_example(example, keys_to_feature)

    def _process_img(img_bytes):
        # Decode the JPEG bytes and scale pixel values to [0, 1].
        img = tf.image.decode_jpeg(img_bytes)
        img = tf.cast(img, tf.float32) / 255.0
        return img

    img_query = _process_img(feat_tensor_maps['img_query'])
    img_positive = _process_img(feat_tensor_maps['img_positive'])
    img_negative = _process_img(feat_tensor_maps['img_negative'])

    return img_query, img_positive, img_negative

Inspecting the data

import matplotlib.pyplot as plt

tfrecord_path = '/media/ubuntu/FED8DCB6D8DC6E81/stuff/deep_ranking_tfrecord/train.record'
filenames_tensor, iterator, ele_tensor = input_layer()
with tf.Session() as sess:
    # The file list is only needed when the iterator is initialized.
    sess.run(iterator.initializer, feed_dict={filenames_tensor: [tfrecord_path]})
    # Fetch the first (query, positive, negative) triplet as NumPy arrays.
    img_query, img_positive, img_negative = sess.run(ele_tensor)

    plt.imshow(img_query)
    plt.show()
    plt.imshow(img_positive)
    plt.show()
    plt.imshow(img_negative)
    plt.show()