【TensorFlow Series】【8】Object Detection: Preprocessing the Pascal VOC Dataset

This article covers the following three topics:

1. How to convert the Pascal VOC dataset into TensorFlow tfrecord files?

2. How to parse xml files with lxml?

3. How to draw object bounding boxes on images with OpenCV?

【Part 1】Converting the Pascal VOC dataset into TensorFlow tfrecord files

The Pascal VOC dataset can be downloaded from:

http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html

The file that maps class names to numeric labels is part of the TensorFlow models repository:

https://github.com/tensorflow/models

It is located at: models-master\research\object_detection\data\pascal_label_map.pbtxt
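
Each item in pascal_label_map.pbtxt has the following form (first entry shown; the parse_map_file function below relies on exactly this layout):

item {
  id: 1
  name: 'aeroplane'
}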

The whole conversion consists of three steps:

1. Parse the name-to-label map file pascal_label_map.pbtxt into a dictionary, i.e. name ---> label.

2. Read each xml annotation file with lxml and parse it into a dictionary (a sample annotation is shown after this list).

3. Convert the raw image data and the annotation data into tfrecord format.
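
For reference, a Pascal VOC annotation file looks roughly like this (abridged to the fields the parser below actually reads):

<annotation>
    <filename>2007_000027.jpg</filename>
    <size>
        <width>486</width>
        <height>500</height>
        <depth>3</depth>
    </size>
    <object>
        <name>person</name>
        <bndbox>
            <xmin>174</xmin>
            <ymin>101</ymin>
            <xmax>349</xmax>
            <ymax>351</ymax>
        </bndbox>
    </object>
</annotation>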

The code is as follows:

import tensorflow as tf
from lxml import etree
import os
from PIL import Image

# Parse a single annotation xml file into a flat dictionary
def parse_xml(xml_path, name_label_map):
    tree = etree.parse(xml_path)
    data = {}
    for x in tree.xpath('//filename'):
        # <filename> should be a leaf node; a child element means the file is malformed
        if len(x):
            print("error")
        else:
            data["image_" + x.tag] = x.text
    for x in tree.xpath('//size'):
        for x1 in x.getchildren():
            data["image_" + x1.tag] = x1.text
    object_numbers = 0
    # There may be several object nodes, i.e. several targets in one image
    for obj in tree.xpath('//object'):
        # Walk the children of the current object node
        for x in obj.getchildren():
            # Check whether node x has children of its own
            if len(x):
                if x.tag == 'bndbox':
                    for bbox in x.getchildren():
                        data['object_' + str(object_numbers) + '/bndbbox/' + bbox.tag] = bbox.text
                else:
                    pass
            elif x.tag == 'name':
                # Store both the name and the numeric id in the dictionary
                data["object_" + str(object_numbers) + "/" + x.tag] = x.text
                data["object_" + str(object_numbers) + "/id"] = name_label_map[x.text]
            else:
                pass
        object_numbers += 1
    data['object_number'] = object_numbers
    return data
# Parse the name-to-label map file into a dictionary
# name <----> id
def parse_map_file(map_file):
    name_label_map = {}
    with open(map_file) as f:
        label_id = 0
        for line in f.readlines():
            if len(line) > 1:
                if line.find('id') != -1:
                    # e.g. "  id: 1"  ->  1
                    line = line.strip('\n')
                    line = line.strip(' ')
                    colon = line.index(':') + 1
                    line_id = line[colon:].strip(' ')
                    label_id = int(line_id)
                elif line.find('name') != -1:
                    # e.g. "  name: 'aeroplane'"  ->  aeroplane
                    line = line.strip('\n').strip(' ')
                    first = line.index("'")
                    last = line.rindex("'")
                    line_name = line[first + 1:last]
                    name_label_map[line_name] = label_id
                    label_id = 0
                else:
                    pass
            else:
                # skip empty lines
                pass
    return name_label_map

MAP_FILE = r"D:\models-master\research\object_detection\data\pascal_label_map.pbtxt"
BASE_PATH= r"E:\VOCtrainval_11-May-2012\VOCdevkit\VOC2012\Annotations"
BASE_JPEG_PATH = r"E:\VOCtrainval_11-May-2012\VOCdevkit\VOC2012\JPEGImages"
name_label_map = parse_map_file(MAP_FILE)
xml_file_list = os.listdir(BASE_PATH)
train_list = []
test_list = []
# Every sixth annotation file goes to the test set, the rest to the training set
for i in range(len(xml_file_list)):
    if i % 6 == 0:
        test_list.append(xml_file_list[i])
    else:
        train_list.append(xml_file_list[i])
with tf.python_io.TFRecordWriter(path=r"E:\VOCtrainval_11-May-2012\train.tfrecords") as tf_writer:
    for i in range(len(train_list)):
        file_path = os.path.join(BASE_PATH,train_list[i])
        if os.path.isfile(file_path):
            # Parse the xml file into a dictionary
            xml_dict = parse_xml(file_path,name_label_map)
            image = Image.open(os.path.join(BASE_JPEG_PATH,xml_dict['image_filename']))
            image_bytes = image.tobytes()
            features = {}
            features["image"] = tf.train.Feature(bytes_list=tf.train.BytesList(value = [image_bytes]))
            features['image_width'] = tf.train.Feature(int64_list=tf.train.Int64List(value = [int(xml_dict['image_width'])]))
            features['image_height'] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[int(xml_dict['image_height'])]))
            features['image_depth'] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[int(xml_dict['image_depth'])]))
            features['image/object_number'] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[int(xml_dict['object_number'])]))
            xmin = []
            xmax = []
            ymin = []
            ymax = []
            object_id = []
            object_name = []
            object_number = xml_dict['object_number']
            for j in range(object_number):
                object_i = 'object_'+str(j)
                #print(xml_dict[object_i+'/name'])
                #print(type(xml_dict[object_i+'/name']))
                object_name.append(bytes(xml_dict[object_i+'/name'],'utf-8'))
                object_id.append(int(xml_dict[object_i+'/id']))
                xmin.append(float(xml_dict[object_i+'/bndbbox/xmin']))
                xmax.append(float(xml_dict[object_i + '/bndbbox/xmax']))
                ymin.append(float(xml_dict[object_i + '/bndbbox/ymin']))
                ymax.append(float(xml_dict[object_i + '/bndbbox/ymax']))
            # Variable-length data is stored as lists
            features["image/object/names"] = tf.train.Feature(bytes_list=tf.train.BytesList(value=object_name))
            features['image/object/id'] = tf.train.Feature(int64_list=tf.train.Int64List(value=object_id))
            features['image/object/xmin'] = tf.train.Feature(float_list=tf.train.FloatList(value=xmin))
            features['image/object/xmax'] = tf.train.Feature(float_list=tf.train.FloatList(value=xmax))
            features['image/object/ymin'] = tf.train.Feature(float_list=tf.train.FloatList(value=ymin))
            features['image/object/ymax'] = tf.train.Feature(float_list=tf.train.FloatList(value=ymax))
            tf_features = tf.train.Features(feature=features)
            tf_example = tf.train.Example(features=tf_features)
            tf_serialized = tf_example.SerializeToString()
            tf_writer.write(tf_serialized)
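
A minimal sanity check, assuming the same TF 1.x API used above: iterate over the serialized records and count them to confirm the file was written completely.

count = 0
for _ in tf.python_io.tf_record_iterator(r"E:\VOCtrainval_11-May-2012\train.tfrecords"):
    count += 1
print("records written:", count)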

【Part 2】Reading the object-detection tfrecord data and drawing bounding boxes on the images with OpenCV

The whole process consists of the following two steps:

1. Write the tfrecord parsing function, i.e. the deserialization function.

2. Fetch the image and annotation data, then draw the bounding boxes with OpenCV.

The code is as follows:

import tensorflow as tf
import numpy as np
import cv2

def parse_tf(example_proto):
    dics = {}
    # Fixed-length features
    dics['image'] = tf.FixedLenFeature(shape=[],dtype=tf.string)
    dics['image_width'] = tf.FixedLenFeature(shape=[], dtype=tf.int64)
    dics['image_height'] = tf.FixedLenFeature(shape=[], dtype=tf.int64)
    dics['image_depth'] = tf.FixedLenFeature(shape=[], dtype=tf.int64)
    dics['image/object_number']= tf.FixedLenFeature(shape=[], dtype=tf.int64)

    # Variable-length (list) features
    dics["image/object/names"] = tf.VarLenFeature(tf.string)
    dics['image/object/id'] = tf.VarLenFeature(tf.int64)
    dics['image/object/xmin'] = tf.VarLenFeature(tf.float32)
    dics['image/object/xmax'] = tf.VarLenFeature(tf.float32)
    dics['image/object/ymin'] = tf.VarLenFeature(tf.float32)
    dics['image/object/ymax'] = tf.VarLenFeature(tf.float32)
    parse_example = tf.parse_single_example(serialized=example_proto,features=dics)
    object_number = parse_example["image/object_number"]
    xmin = parse_example['image/object/xmin']
    xmax = parse_example['image/object/xmax']
    ymin = parse_example['image/object/ymin']
    ymax = parse_example['image/object/ymax']
    image = tf.decode_raw(parse_example['image'],out_type=tf.uint8)
    w = parse_example['image_width']
    h = parse_example['image_height']
    c = parse_example['image_depth']
    return image,w,h,c,object_number,xmin,xmax,ymin,ymax

dataset = tf.data.TFRecordDataset(r"E:\VOCtrainval_11-May-2012\train.tfrecords")
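# map applies parse_tf to every record; batch(1) turns the VarLenFeature box
# lists into SparseTensors (hence the .values access below); repeat(1) makes one pass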
dataset = dataset.map(parse_tf).batch(1).repeat(1)

iterator = dataset.make_one_shot_iterator()

next_element = iterator.get_next()
with tf.Session() as session:
    image, w, h, c, object_number, xmin, xmax, ymin, ymax = session.run(fetches=next_element)
    image = np.reshape(image, newshape=[h[0], w[0], c[0]])
    # PIL stored the pixels in RGB order, while OpenCV expects BGR, so swap the channels
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # Draw the bounding boxes with OpenCV
    for i in range(object_number[0]):
        # Top-left and bottom-right corners; cv2.rectangle expects integer coordinates
        cv2.rectangle(image, (int(xmin.values[i]), int(ymin.values[i])),
                      (int(xmax.values[i]), int(ymax.values[i])), color=(0, 255, 0))
    cv2.imshow("s", image)
    cv2.waitKey(0)
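
Note that the snippet above fetches only the first image. A minimal sketch for walking the whole dataset instead: keep calling session.run until the finite dataset raises tf.errors.OutOfRangeError.

with tf.Session() as session:
    while True:
        try:
            image, w, h, c, object_number, xmin, xmax, ymin, ymax = session.run(fetches=next_element)
            # ... reshape, draw and display each image as above ...
        except tf.errors.OutOfRangeError:
            # raised once the single pass over the data (repeat(1)) is exhausted
            break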

The result looks like the figure below:

(Figure: a VOC sample image with green bounding boxes drawn around the annotated objects.)

PS: When parsing tfrecord files, everything you get inside the parse function is a tensor rather than concrete data, and the session cannot be passed into the parse function. Many preprocessing operations therefore cannot be carried out inside the parse function; instead, fetch the data first and then preprocess it outside with numpy and similar tools, as sketched below.
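
For example, a minimal sketch of such post-fetch preprocessing (scaling pixel values to [0, 1] is just an illustrative choice, not something the pipeline above requires):

image, w, h, c, object_number, xmin, xmax, ymin, ymax = session.run(fetches=next_element)
image = np.reshape(image, newshape=[h[0], w[0], c[0]])
# numpy-side preprocessing, applied after the raw bytes leave the graph
image = image.astype(np.float32) / 255.0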
