使用ImageNet在faster-rcnn上訓練本身的分類網絡

具體代碼見https://github.com/zhiyishou/py-faster-rcnnnode


這是我對cup, glasses訓練的識別git

faster-rcnn在fast-rcnn的基礎上加了rpn來將整個訓練都置於GPU內,以用來提升效率,這裏咱們將使用ImageNet的數據集來在faster-rcnn上來訓練本身的分類器。從ImageNet上可下載到不少類別的Image與bounding box annotation來進行訓練(每個類別下的annotation都少於等於image的個數,因此咱們從annotation來創建索引)。github

lib/dataset/factory.py中提供了coco與voc的數據集獲取方法,而咱們要作的就是在這裏加上咱們本身的ImageNet獲取方法,咱們先來創建ImageNet數據獲取主文件。coco與pascal_voc的獲取都是繼承於父類imdb,因此咱們可根據pascal_voc的獲取方法來作模板修改完成咱們的ImageNet類。緩存

建立ImageNet類

因爲在faster-rcnn裏使用rpn來代替了selective_search,因此咱們能夠在使用時直接略過有關selective_search的方法,根據pascal_voc類作模板,咱們須要留下的方法有:網絡

__init__ //初始化
image_path_at //根據數據集列表的index來取圖片絕對地址
image_path_from_index //配合上面
_load_image_set_index //獲取數據集列表
_gt_roidb //獲取ground-truth數據
rpn_roidb //獲取region proposal數據
_load_rpn_roidb //根據gt_roidb生成rpn_roidb數據併合成
_load_psacal_annotation //加載annotation文件並對bounding box進行數據整理

__init__:數據結構

def __init__(self, image_set):
        imdb.__init__(self, 'imagenet')
        self._image_set = image_set
        self._data_path = os.path.join(cfg.DATA_DIR, "imagenet")
        #類別與對應的wnid,能夠修改爲本身要訓練的類別
        self._class_wnids = {
            'cup': 'n03147509',
            'glasses': 'n04272054'
        }

        #類別,修改類別時同時要修改這裏
        self._classes = ('__background__', self._class_wnids['cup'], self._class_wnids['glasses'])
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        #bounding box annotation 文件的目錄
        self._xml_path = os.path.join(self._data_path, "Annotations")
        self._image_ext = '.JPEG'
        #咱們使用xml文件名來作數據集的索引
        # the xml file name and each one corresponding to image file name
        self._image_index = self._load_xml_filenames()
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        self.config = {'cleanup'     : True,
                       'use_salt'    : True,
                       'use_diff'    : False,
                       'matlab_eval' : False,
                       'rpn_file'    : None,
                       'min_size'    : 2}

        assert os.path.exists(self._data_path), \
                'Path does not exist: {}'.format(self._data_path)

image_path_atapp

def image_path_at(self, i):
        #使用index來從xml_filenames取到filename,生成絕對路徑
        return self.image_path_from_image_filename(self._image_index[i])

image_path_from_image_filename(相似pascal_voc中的image_path_from_index)dom

def image_path_from_image_filename(self, image_filename):
        image_path = os.path.join(self._data_path, 'Images',
                                  image_filename + self._image_ext)
        assert os.path.exists(image_path), \
                'Path does not exist: {}'.format(image_path)
        return image_path

_load_xml_filenames(相似pascal_voc中的_load_image_set_index)學習

def _load_xml_filenames(self):
        #從Annotations文件夾中拿取到bounding box annotation文件名
        #用來作數據集的索引
        xml_folder_path = os.path.join(self._data_path, "Annotations")
        assert os.path.exists(xml_folder_path), \
            'Path does not exist: {}'.format(xml_folder_path)

        for dirpath, dirnames, filenames in os.walk(xml_folder_path):
                xml_filenames = [xml_filename.split(".")[0] for xml_filename in filenames]

        return xml_filenames

gt_roidbui

def gt_roidb(self):
        #Ground-Truth 數據緩存
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print '{} gt roidb loaded from {}'.format(self.name, cache_file)
            return roidb

        #從xml中獲取Ground-Truth數據
        gt_roidb = [self._load_imagenet_annotation(xml_filename)
                    for xml_filename in self._image_index]
        with open(cache_file, 'wb') as fid:
            cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print 'wrote gt roidb to {}'.format(cache_file)

        return gt_roidb

rpn_roidb

def rpn_roidb(self):
        #根據gt_roidb生成rpn_roidb,並進行合併           
        gt_roidb = self.gt_roidb()
        rpn_roidb = self._load_rpn_roidb(gt_roidb)
        roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)

        return roidb

_load_rpn_roidb

def _load_rpn_roidb(self, gt_roidb):
        filename = self.config['rpn_file']
        print 'loading {}'.format(filename)
        assert os.path.exists(filename), \
               'rpn data not found at: {}'.format(filename)
        with open(filename, 'rb') as f:
            box_list = cPickle.load(f)
        return self.create_roidb_from_box_list(box_list, gt_roidb)

_load_imagenet_annotation(相似於pascal_voc中的_load_pascal_annotation)

def _load_imagenet_annotation(self, xml_filename):
        #從annotation的xml文件中拿取bounding box數據
        filepath = os.path.join(self._data_path, 'Annotations', xml_filename + '.xml')
        #這裏使用了ap,是我寫的一個annotation parser,在後面貼出代碼
        #它會返回這個xml文件的wnid, 圖像文件名,以及裏面包含的註解物體
        wnid, image_name, objects = ap.parse(filepath)
        num_objs = len(objects)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        seg_areas = np.zeros((num_objs), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objects):
            box = obj["box"]
            x1 = box['xmin']
            y1 = box['ymin']
            x2 = box['xmax']
            y2 = box['ymax']
            # 若是這個bounding box並非咱們想要學習的類別,那則跳過
            # go next if the wnid not exist in declared classes
            try:
                cls = self._class_to_ind[obj["wnid"]]
            except KeyError:
                print "wnid %s isn't show in given"%obj["wnid"]
                continue
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

        overlaps = scipy.sparse.csr_matrix(overlaps)

        return {'boxes' : boxes,
                'gt_classes': gt_classes,
                'gt_overlaps' : overlaps,
                'flipped' : False,
                'seg_areas' : seg_areas}

annotation_parser.py文件

import os
import xml.dom.minidom

def getText(node):
    return node.firstChild.nodeValue

def getWnid(node):
    return getText(node.getElementsByTagName("name")[0])

def getImageName(node):
    return getText(node.getElementsByTagName("filename")[0])

def getObjects(node):
    objects = []
    for obj in node.getElementsByTagName("object"):
        objects.append({
            "wnid": getText(obj.getElementsByTagName("name")[0]),
            "box":{
                "xmin": int(getText(obj.getElementsByTagName("xmin")[0])),
                "ymin": int(getText(obj.getElementsByTagName("ymin")[0])),
                "xmax": int(getText(obj.getElementsByTagName("xmax")[0])),
                "ymax": int(getText(obj.getElementsByTagName("ymax")[0])),
            }
        })
    return objects

def parse(filepath):
    dom = xml.dom.minidom.parse(filepath)
    root = dom.documentElement
    image_name = getImageName(root)
    wnid = getWnid(root)
    objects = getObjects(root)
    
    return wnid, image_name, objects

則對數據結構的要求是:

|---data
  |---imagenet
    |---Annotations
       |---n03147509
          |---n03147509_*.xml
          |---...
       |---n04272054
          |---n04272054_*.xml
          |---...
    |---Images
       |---n03147508_*.JPEG
       |---...
       |---n04272054_*.JPEG
       |---...

同時我在github上也提供了draw方法,能夠用來將bounding box畫於Image文件上,用來甄別該annotation的正確性

訓練

這樣,咱們的ImageNet類則是生成好了,下面咱們則能夠訓練咱們的數據,可是在開始以前,還有一件事情,那就是修改prototxt中的與類別數目有關的值,我將models/pascal_voc拷貝到了models/imagenet進行修改,好比我想要訓練ZF,若是使用的是train_faster_rcnn_alt_opt.py,則須要修改models/imagenet/ZF/faster_rcnn_alt_opt/下的全部pt文件裏的內容,用以下的法則去替換:

//num爲類別的個數
input-data->num_classes = num
class_score->num_output = num
bbox_pred->num_output   = num*4

我這裏使用train_faster_rcnn_alt_opt.py進行的訓練,這樣的話則須要把添加的models/imagenet做爲可選項

//pt_type 則是添加的選擇項,默認使用psacal_voc的models
./tools/train_faster_rcnn_alt_opt.py --gpu 0 \
--net_name ZF \
--weights data/imagenet_models/ZF.v2.caffemodel[optional] \
--imdb imagenet \
--cfg experiments/cfgs/faster_rcnn_alt_opt.yml \
--pt_type imagenet

識別

這裏咱們則須要使用剛訓練出來的模型進行識別

#就像demo.py同樣,可是使用訓練的models,我建立了tools/classify.py來單獨識別
prototxt = os.path.join(cfg.ROOT_DIR, 'models/imagenet', NETS[args.demo_net][0], 'faster_rcnn_alt_opt', 'faster_rcnn_test.pt')
caffemodel = os.path.join(cfg.ROOT_DIR, 'output/faster_rcnn_alt_opt/imagenet/'+ NETS[args.demo_net][0] +'_faster_rcnn_final.caffemodel')

一樣,在識別前咱們要對識別方法裏的Classes進行修改,修改爲你本身訓練的類別後

執行

./tools/classify.py --net zf

則可對data/demo下的圖片文件使用訓練的zf網絡進行識別

Have fun

相關文章
相關標籤/搜索