[TensorFlow] Object Detection with the ssd_mobilenet Model

  A recent work project used TensorFlow's object detection capabilities: by training on our own sample set, we obtained a model that recognizes objects in a game. This post summarizes the whole process.

  This post explains how to train on your own dataset on Windows using TensorFlow's Object Detection API. The model used here is ssd_mobilenet; other models such as ssd_inception, faster_rcnn, and rfcn_resnet can be used the same way. Of these, the SSD models offer the best speed for their size (at some cost in accuracy), which makes ssd_mobilenet a good fit for this task.

Setting Up the Environment

  1. Download the required models repository from GitHub: https://github.com/tensorflow/models

  2. Install pillow, jupyter, matplotlib, and lxml. Open an Anaconda Prompt and run the following commands, making sure each installs successfully:

pip install pillow
pip install jupyter
pip install matplotlib
pip install lxml

  3. Compile the protobuf files. The Object Detection API uses protobuf to configure the model and training parameters, so the .proto files must be compiled first. Download protoc from https://github.com/google/protobuf/releases; a detailed walkthrough of this step is at https://blog.csdn.net/dy_guox/article/details/79081499.
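
For reference, the compilation is usually run from the models\research directory, along these lines (the protoc location is an assumption about where you unpacked it, and on some Windows shells the wildcard must be expanded one .proto file at a time):

protoc object_detection/protos/*.proto --python_out=.

It also helps to put models\research and models\research\slim on PYTHONPATH so that object_detection and its dependencies can be imported (replace the placeholder with your own models path):

set PYTHONPATH=%PYTHONPATH%;&lt;path-to-models&gt;\research;&lt;path-to-models&gt;\research\slim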

Creating Your Own Dataset

  1. Download labelImg and use it to annotate the images you collected. Each annotation is saved automatically as an XML file like this one:

<annotation>
    <folder>images1</folder>
    <filename>0.png</filename>
    <path>C:\Users\White\Desktop\images1\0.png</path>
    <source>
        <database>Unknown</database>
    </source>
    <size>
        <width>1080</width>
        <height>1920</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>box</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>345</xmin>
            <ymin>673</ymin>
            <xmax>475</xmax>
            <ymax>825</ymax>
        </bndbox>
    </object>
    <object>
        <name>box</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>609</xmin>
            <ymin>1095</ymin>
            <xmax>759</xmax>
            <ymax>1253</ymax>
        </bndbox>
    </object>
</annotation>

  2. Create the following directories in the project folder: put all the sample images into the images folder, and save the annotation XML files into the merged_xml folder. The resulting layout is sketched below.

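A rough sketch of the project layout assumed by the scripts in the rest of this post (the annotations and data folders are created in the following steps):

object-detection-Game-yellow/
    images/         all sample images
    merged_xml/     all labelImg annotation xml files
    annotations/    train / test / validation splits of the xml files
    data/           csv files, tfrecord files, label_map.pbtxt, training output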

Converting the Samples to TFRecord Format

  1. Create train_test_split.py to divide the XML annotations into train, test, and validation parts, stored under the annotations folder: train is 76.5% of the data (0.9 × 0.85), test is 10%, and validation is 13.5%. The code is as follows:

import os
import random
import time
import shutil

xmlfilepath = r'merged_xml'
saveBasePath = r'./annotations'

# 90% of the files go to train+validation and the rest to test; within
# train+validation, 85% go to train -> 76.5% / 13.5% / 10% overall.
trainval_percent = 0.9
train_percent = 0.85

total_xml = os.listdir(xmlfilepath)
num = len(total_xml)
indices = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(indices, tv)
train = random.sample(trainval, tr)
print("train and val size", tv)
print("train size", tr)

start = time.time()
counts = {'train': 0, 'validation': 0, 'test': 0}

def copy_to(directory, name):
    # Create annotations/<directory> on demand and copy the xml file into it.
    xml_path = os.path.join(os.getcwd(), 'annotations/{}'.format(directory))
    if not os.path.exists(xml_path):
        os.mkdir(xml_path)
    filePath = os.path.join(xmlfilepath, name)
    newfile = os.path.join(saveBasePath, directory, name)
    shutil.copyfile(filePath, newfile)
    counts[directory] += 1

for i in indices:
    name = total_xml[i]
    if i in trainval:  # train or validation set
        copy_to('train' if i in train else 'validation', name)
    else:              # test set
        copy_to('test', name)

end = time.time()
seconds = end - start
print("train total : " + str(counts['train']))
print("validation total : " + str(counts['validation']))
print("test total : " + str(counts['test']))
print("total number : " + str(sum(counts.values())))
print("Time taken : {0} seconds".format(seconds))

  2. Convert the XML files to CSV with a new script, xml_to_csv.py. Before running it, create a data directory to hold the generated CSV files. The code is as follows:

import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET


def xml_to_csv(path):
    # Collect one row per annotated object across all xml files in `path`.
    xml_list = []
    for xml_file in glob.glob(path + '/*.xml'):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        print(root.find('filename').text)
        for member in root.findall('object'):
            size = root.find('size')
            bndbox = member.find('bndbox')
            value = (root.find('filename').text,
                     int(size.find('width').text),
                     int(size.find('height').text),
                     member.find('name').text,
                     int(float(bndbox.find('xmin').text)),
                     int(float(bndbox.find('ymin').text)),
                     int(float(bndbox.find('xmax').text)),
                     int(float(bndbox.find('ymax').text)))
            xml_list.append(value)
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    xml_df = pd.DataFrame(xml_list, columns=column_name)
    return xml_df


def main():
    # One csv file per split, written into the data directory.
    for directory in ['train', 'test', 'validation']:
        xml_path = os.path.join(os.getcwd(), 'annotations/{}'.format(directory))
        xml_df = xml_to_csv(xml_path)
        xml_df.to_csv('data/whsyxt_{}_labels.csv'.format(directory), index=None)
        print('Successfully converted xml to csv.')


if __name__ == '__main__':
    main()

The script prints each processed filename as it runs, and the files whsyxt_train_labels.csv, whsyxt_test_labels.csv, and whsyxt_validation_labels.csv are generated in the data folder.

  3. Generate the TFRecord files with a script named generate_tfrecord.py:

from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os
import io
import pandas as pd
import tensorflow as tf

from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple

flags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS

# Label-to-id mapping. It must match both the names used when annotating in
# labelImg and the ids in label_map.pbtxt; replace it with your own classes.
LABEL_DICT = {
    'person': 1,
    'investigator': 2,
    'collector': 3,
    'wolf': 4,
    'skull': 5,
    'inferno': 6,
    'stone_blame': 7,
    'green_jelly': 8,
    'blue_jelly': 9,
    'box': 10,
    'golden_box': 11,
    'silver_box': 12,
    'jar': 13,
    'purple_jar': 14,
    'purple_weapon': 15,
    'blue_weapon': 16,
    'blue_shoe': 17,
    'blue_barde': 18,
    'blue_ring': 19,
    'badge': 20,
    'dragon_stone': 21,
    'lawn': 22,
    'mine': 23,
    'portal': 24,
    'tower': 25,
    'hero_stone': 26,
    'oracle_stone': 27,
    'arena': 28,
    'gold_ore': 29,
    'relic': 30,
    'ancient': 31,
    'house': 32,
}

def class_text_to_int(row_label, filename):
    label = LABEL_DICT.get(row_label)
    if label is None:
        # An unknown label usually means a typo in the annotation.
        print("------------------nonetype:", filename)
    return label

def split(df, group):
    # Group the csv rows by filename so each image becomes one example.
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    return [data(filename, gb.get_group(filename)) for filename in gb.groups]

def create_tf_example(group, path):
    # Read the raw PNG-encoded image bytes.
    with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
        encoded_png = fid.read()
    encoded_png_io = io.BytesIO(encoded_png)
    image = Image.open(encoded_png_io)
    width, height = image.size

    filename = group.filename.encode('utf8')
    image_format = b'png'
    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes_text, classes = [], []

    # Normalize the box coordinates to [0, 1] as the API expects.
    for index, row in group.object.iterrows():
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class'], group.filename))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_png),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example

def main(_):
    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
    path = os.path.join(os.getcwd(), 'images')
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')
    num = 0
    for group in grouped:
        num += 1
        tf_example = create_tf_example(group, path)
        writer.write(tf_example.SerializeToString())
        if num % 100 == 0:  # print progress every 100 conversions
            print(num)

    writer.close()
    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))

if __name__ == '__main__':
    tf.app.run()

The LABEL_DICT mapping must be replaced with the classes annotated in your own dataset; I have 32 classes in total, and each key string must exactly match the label name used in labelImg.

First convert the training set to TFRecord format with the following command:

python generate_tfrecord.py --csv_input=data/whsyxt_train_labels.csv --output_path=data/whsyxt_train.tfrecord  

Similarly, the validation and test sets can be converted with these commands:

python generate_tfrecord.py --csv_input=data/whsyxt_validation_labels.csv --output_path=data/whsyxt_validation.tfrecord 
python generate_tfrecord.py --csv_input=data/whsyxt_test_labels.csv --output_path=data/whsyxt_test.tfrecord

Once all three commands succeed, the data folder contains whsyxt_train.tfrecord, whsyxt_validation.tfrecord, and whsyxt_test.tfrecord.
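
If you want to sanity-check a generated file, you can count its records with a short snippet (an optional check, using the same TF 1.x API the scripts above already rely on):

import tensorflow as tf

# Count the serialized examples in a generated TFRecord file.
count = sum(1 for _ in tf.python_io.tf_record_iterator('data/whsyxt_train.tfrecord'))
print('examples in whsyxt_train.tfrecord:', count)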

Training

  1. In the project's data directory, create the label map file (label_map.pbtxt); create one item with a unique id for each kind of target you want to detect:

item {
  id: 1 # ids start from 1
  name: 'person'
}
item {
  id: 2
  name: 'investigator'
}
item {
  id: 3
  name: 'collector'
}
item {
  id: 4
  name: 'wolf'
}
item {
  id: 5
  name: 'skull'
}
item {
  id: 6
  name: 'inferno'
}
item {
  id: 7
  name: 'stone_blame'
}
item {
  id: 8
  name: 'green_jelly'
}
item {
  id: 9
  name: 'blue_jelly'
}
item {
  id: 10
  name: 'box'
}
item {
  id: 11
  name: 'golden_box'
}
item {
  id: 12
  name: 'silver_box'
}
item {
  id: 13
  name: 'jar'
}
item {
  id: 14
  name: 'purple_jar'
}
item {
  id: 15
  name: 'purple_weapon'
}
item {
  id: 16
  name: 'blue_weapon'
}
item {
  id: 17
  name: 'blue_shoe'
}
item {
  id: 18
  name: 'blue_barde'
}
item {
  id: 19
  name: 'blue_ring'
}
item {
  id: 20
  name: 'badge'
}
item {
  id: 21
  name: 'dragon_stone'
}
item {
  id: 22
  name: 'lawn'
}
item {
  id: 23
  name: 'mine'
}
item {
  id: 24
  name: 'portal'
}
item {
  id: 25
  name: 'tower'
}
item {
  id: 26
  name: 'hero_stone'
}
item {
  id: 27
  name: 'oracle_stone'
}
item {
  id: 28
  name: 'arena'
}
item {
  id: 29
  name: 'gold_ore'
}
item {
  id: 30
  name: 'relic'
}
item {
  id: 31
  name: 'ancient'
}
item {
  id: 32
  name: 'house'
}
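
Keeping label_map.pbtxt and the class mapping in generate_tfrecord.py consistent by hand is error-prone. A small helper can generate the file from the same dictionary (a hypothetical script, not part of the original workflow; it assumes the LABEL_DICT shown in generate_tfrecord.py above):

# generate_label_map.py -- hypothetical helper, assumes LABEL_DICT from above
from generate_tfrecord import LABEL_DICT

with open('data/label_map.pbtxt', 'w') as f:
    # Write one item per class, sorted by id (ids start from 1).
    for name, idx in sorted(LABEL_DICT.items(), key=lambda kv: kv[1]):
        f.write("item {\n  id: %d\n  name: '%s'\n}\n" % (idx, name))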

  2. Configure the training pipeline. Find the file models\research\object_detection\samples\configs\ssd_mobilenet_v1_pets.config, copy it into the data folder, and modify it; the edited version looks like this:

  1 # SSD with Mobilenet v1, configured for Oxford-IIIT Pets Dataset.
  2 # Users should configure the fine_tune_checkpoint field in the train config as
  3 # well as the label_map_path and input_path fields in the train_input_reader and
  4 # eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
  5 # should be configured.
  6 
  7 model {
  8   ssd {
  9     num_classes: 32
 10     box_coder {
 11       faster_rcnn_box_coder {
 12         y_scale: 10.0
 13         x_scale: 10.0
 14         height_scale: 5.0
 15         width_scale: 5.0
 16       }
 17     }
 18     matcher {
 19       argmax_matcher {
 20         matched_threshold: 0.45
 21         unmatched_threshold: 0.35
 22         ignore_thresholds: false
 23         negatives_lower_than_unmatched: true
 24         force_match_for_each_row: true
 25       }
 26     }
 27     similarity_calculator {
 28       iou_similarity {
 29       }
 30     }
 31     anchor_generator {
 32       ssd_anchor_generator {
 33         num_layers: 6
 34         min_scale: 0.2
 35         max_scale: 0.95
 36         aspect_ratios: 1.0
 37         aspect_ratios: 2.0
 38         aspect_ratios: 0.5
 39         aspect_ratios: 3.0
 40         aspect_ratios: 0.3333
 41       }
 42     }
 43     image_resizer {
 44       fixed_shape_resizer {
 45         height: 300
 46         width: 300
 47       }
 48     }
 49     box_predictor {
 50       convolutional_box_predictor {
 51         min_depth: 0
 52         max_depth: 0
 53         num_layers_before_predictor: 0
 54         use_dropout: false
 55         dropout_keep_probability: 0.8
 56         kernel_size: 1
 57         box_code_size: 4
 58         apply_sigmoid_to_scores: false
 59         conv_hyperparams {
 60           activation: RELU_6,
 61           regularizer {
 62             l2_regularizer {
 63               weight: 0.00004
 64             }
 65           }
 66           initializer {
 67             truncated_normal_initializer {
 68               stddev: 0.03
 69               mean: 0.0
 70             }
 71           }
 72           batch_norm {
 73             train: true,
 74             scale: true,
 75             center: true,
 76             decay: 0.9997,
 77             epsilon: 0.001,
 78           }
 79         }
 80       }
 81     }
 82     feature_extractor {
 83       type: 'ssd_mobilenet_v1'
 84       min_depth: 16
 85       depth_multiplier: 1.0
 86       conv_hyperparams {
 87         activation: RELU_6,
 88         regularizer {
 89           l2_regularizer {
 90             weight: 0.00004
 91           }
 92         }
 93         initializer {
 94           truncated_normal_initializer {
 95             stddev: 0.03
 96             mean: 0.0
 97           }
 98         }
 99         batch_norm {
100           train: true,
101           scale: true,
102           center: true,
103           decay: 0.9997,
104           epsilon: 0.001,
105         }
106       }
107     }
108     loss {
109       classification_loss {
110         weighted_sigmoid {
111         }
112       }
113       localization_loss {
114         weighted_smooth_l1 {
115         }
116       }
117       hard_example_miner {
118         num_hard_examples: 3000
119         iou_threshold: 0.99
120         loss_type: CLASSIFICATION
121         max_negatives_per_positive: 3
122         min_negatives_per_image: 0
123       }
124       classification_weight: 1.0
125       localization_weight: 1.0
126     }
127     normalize_loss_by_num_matches: true
128     post_processing {
129       batch_non_max_suppression {
130         score_threshold: 1e-8
131         iou_threshold: 0.6
132         max_detections_per_class: 100
133         max_total_detections: 100
134       }
135       score_converter: SIGMOID
136     }
137   }
138 }
139 
140 train_config: {
141   batch_size: 24
142   optimizer {
143     rms_prop_optimizer: {
144       learning_rate: {
145         exponential_decay_learning_rate {
146           initial_learning_rate: 0.004
147           decay_steps: 1000
148           decay_factor: 0.95
149         }
150       }
151       momentum_optimizer_value: 0.9
152       decay: 0.9
153       epsilon: 1.0
154     }
155   }
156   #fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
157   from_detection_checkpoint: false
158   # Note: The below line limits the training process to 200K steps, which we
159   # empirically found to be sufficient enough to train the pets dataset. This
160   # effectively bypasses the learning rate schedule (the learning rate will
161   # never decay). Remove the below line to train indefinitely.
162   num_steps: 40000
163   data_augmentation_options {
164     random_horizontal_flip {
165     }
166   }
167   data_augmentation_options {
168     ssd_random_crop {
169     }
170   }
171 }
172 
173 train_input_reader: {
174   tf_record_input_reader {
175     input_path: "E:/Project/object-detection-Game-yellow/data/whsyxt_train.tfrecord"
176   }
177   label_map_path: "E:/Project/object-detection-Game-yellow/data/label_map.pbtxt"
178 }
179 
180 eval_config: {
181   num_examples: 2000
182   # Note: The below line limits the evaluation process to 10 evaluations.
183   # Remove the below line to evaluate indefinitely.
184   max_evals: 10
185 }
186 
187 eval_input_reader: {
188   tf_record_input_reader {
189     input_path: "E:/Project/object-detection-Game-yellow/data/whsyxt_validation.tfrecord"
190   }
191   label_map_path: "E:/Project/object-detection-Game-yellow/data/label_map.pbtxt"
192   shuffle: false
193   num_readers: 1
194 }

The places that need modification are: line 9, the total number of classes you annotated; line 175, the path to the training tfrecord file; lines 177 and 191, the paths to your label_map; and line 189, the path to the validation tfrecord file.

Various learning parameters of the network can also be set in this pipeline file: line 141 sets the batch size, lines 145–148 set the initial learning rate and its decay, line 162 sets the total number of training steps, and so on.
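
Note also that fine_tune_checkpoint is commented out on line 156, so training here starts from scratch. If you download a pre-trained ssd_mobilenet_v1 checkpoint from the model zoo, you can fine-tune from it instead, which usually converges much faster; a sketch of the train_config edit (the checkpoint path is an assumption about where you saved it):

fine_tune_checkpoint: "E:/Project/object-detection-Game-yellow/data/model.ckpt"
from_detection_checkpoint: true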

  3. Start training. Copy object_detection\train.py into the project directory and run it with the following command:

python train.py --logtostderr --pipeline_config_path=E:/Project/object-detection-Game-yellow/data/ssd_mobilenet_v1_pets.config --train_dir=E:/Project/object-detection-Game-yellow/data

  If no errors appear, training begins; wait for it to run to completion.

Monitoring with TensorBoard

  1. After the training command is issued, files are generated under the data folder that record the intermediate state of training; TensorBoard can display them graphically.


  2. Open a new command prompt window, activate the TensorFlow environment, and run the following command:

tensorboard --logdir=training:your_log_dir --host=127.0.0.1

Here your_log_dir is the folder in the project directory that holds the training output; copy that directory's path and paste it in its place.

  3. Open a browser and enter localhost:6006 in the address bar to display the TensorBoard dashboard.

Exporting the Training Results

  1. During training, a series of model.ckpt-* checkpoint files is generated in the training directory (the --train_dir used above).

Choose the checkpoint at the desired step and export a frozen graph (.pb) with export_inference_graph.py, found in the object_detection directory:

python export_inference_graph.py --pipeline_config_path=E:\Project\object-detection-Game-yellow\data\ssd_mobilenet_v1_pets.config --trained_checkpoint_prefix ./data/model.ckpt-30000 --output_directory ./data/exported_model_directory

After the command runs, a folder named exported_model_directory is created under the project's data directory. Among its contents, frozen_inference_graph.pb is the model file we will use from here on.

Obtaining Test Images

  1. Create a test_images folder and a get_testImages.py file with the following code:


from PIL import Image
import os

# Paths to the test-split annotations, the folder with all images,
# and the destination folder for the test images.
annotations_test_dir = "E:\\Project\\object-detection-Game-yellow\\annotations\\test\\"
Images_dir = "E:\\Project\\object-detection-Game-yellow\\Images"
test_images_dir = "E:\\Project\\object-detection-Game-yellow\\test_images"

i = 0
for xmlfile in os.listdir(annotations_test_dir):
    xmlname = os.path.splitext(xmlfile)[0]
    for pngfile in os.listdir(Images_dir):
        pngname = os.path.splitext(pngfile)[0]
        if pngname == xmlname:
            # This image belongs to the test split: copy it over.
            img = Image.open(os.path.join(Images_dir, pngname + ".png"))
            img.save(os.path.join(test_images_dir, os.path.basename(pngfile)))
            print(pngname)
            i += 1
print(i)


The three paths at the top — annotations_test_dir, Images_dir, and test_images_dir — point to the annotations\test folder, the Images folder, and the test_images folder respectively. Running the script copies every image that has a test annotation into the test_images folder.

Batch-Saving the Test Results

  1. Create a results folder and a get_allTestResults.py file in the project directory and add the code below. It runs the trained model on every image in test_images and saves the visualized detections into results:

# -*- coding: utf-8 -*-
import os
import time

import numpy as np
import tensorflow as tf
from PIL import Image
from matplotlib import pyplot as plt
# plt.switch_backend('Agg')

from utils import label_map_util
from utils import visualization_utils as vis_util

PATH_TO_TEST_IMAGES = "E:\\Project\\object-detection-Game-yellow\\test_images\\"
MODEL_NAME = 'E:/Project/object-detection-Game-yellow/data'
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_LABELS = MODEL_NAME + '/label_map.pbtxt'
NUM_CLASSES = 32
PATH_TO_RESULTS = "E:\\Project\\object-detection-Game-yellow\\results\\"


def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)


def save_object_detection_result():
    IMAGE_SIZE = (12, 8)
    # Load the frozen TensorFlow model into memory.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    # Load the label map and build the id -> class-name index.
    label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)

    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            start = time.time()
            for test_image in os.listdir(PATH_TO_TEST_IMAGES):
                image = Image.open(PATH_TO_TEST_IMAGES + test_image)
                # The array-based representation of the image is used later to
                # prepare the result image with boxes and labels drawn on it.
                image_np = load_image_into_numpy_array(image)
                # Expand dimensions: the model expects shape [1, None, None, 3].
                image_np_expanded = np.expand_dims(image_np, axis=0)
                image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
                # Each box is a region of the image where an object was detected.
                boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
                # Each score is the confidence for a detection; it is shown on
                # the result image together with the class label.
                scores = detection_graph.get_tensor_by_name('detection_scores:0')
                classes = detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = detection_graph.get_tensor_by_name('num_detections:0')
                # Run the actual detection.
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
                # Draw the detection results onto the image array.
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image_np,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)

                # Count detections with confidence above 0.5.
                final_score = np.squeeze(scores)
                count = 0
                for i in range(100):
                    if final_score[i] > 0.5:
                        count = count + 1
                print()
                print("the count of objects is: ", count)
                # Convert the normalized box coordinates back to pixels and
                # report each detected object's class and center point.
                (im_width, im_height) = image.size
                for i in range(count):
                    y_min = boxes[0][i][0] * im_height
                    x_min = boxes[0][i][1] * im_width
                    y_max = boxes[0][i][2] * im_height
                    x_max = boxes[0][i][3] * im_width
                    x = int((x_min + x_max) / 2)
                    y = int((y_min + y_max) / 2)
                    if category_index[classes[0][i]]['name'] == "tower":
                        print("this image has a tower!")
                        y = int((y_max - y_min) / 4 * 3 + y_min)
                    print("object{0}: {1}".format(i, category_index[classes[0][i]]['name']),
                          ',Center_X:', x, ',Center_Y:', y)
                # Save the visualized result under the same file name.
                plt.figure(figsize=IMAGE_SIZE)
                plt.imshow(image_np)
                picName = test_image.split('/')[-1]
                plt.savefig(PATH_TO_RESULTS + picName)
                plt.close()  # free the figure to avoid accumulating memory
                print(test_image + ' succeed')

            end = time.time()
            seconds = end - start
            print("Time taken : {0} seconds".format(seconds))


save_object_detection_result()


Finally, the outputs in results can be used to measure accuracy, judge how well training went, and guide later optimization.
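
A common starting point for that accuracy calculation is to count a detection as correct when it overlaps a same-class ground-truth box with IoU above a threshold. A minimal sketch (hypothetical helper functions, assuming boxes as (xmin, ymin, xmax, ymax) in pixels):

def iou(box_a, box_b):
    # Intersection-over-union of two (xmin, ymin, xmax, ymax) boxes.
    xa, ya = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    xb, yb = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0, xb - xa) * max(0, yb - ya)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / float(area_a + area_b - inter)

def is_correct(pred_box, pred_class, gt_boxes, gt_classes, threshold=0.5):
    # A prediction counts as correct if it overlaps a same-class
    # ground-truth box with IoU above the threshold.
    return any(cls == pred_class and iou(pred_box, box) > threshold
               for box, cls in zip(gt_boxes, gt_classes))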

Summary

 

Please credit the source when reposting: http://www.javashuo.com/article/p-kxnwxjax-cr.html
