這裏樓主講解了如何修改Fast RCNN訓練本身的數據集,首先請確保你已經安裝好了Fast RCNN的環境,具體的編配編制操做請參考個人上一篇文章。首先能夠看到fast rcnn的工程目錄下有個Lib目錄
這裏下面存在3個目錄分別是:node
在這裏修改讀寫數據的接口主要是datasets目錄下,fast_rcnn下面主要存放的是python的訓練和測試腳本,以及訓練的配置文件,roi_data_layer下面存放的主要是一些ROI處理操做,utils下面存放的是一些通用操做好比非極大值nms,以及計算bounding box的重疊率等經常使用功能python
可有看到datasets目錄下主要有三個文件,分別是linux
factory.py 學過設計模式的應該知道這是個工廠類,用類生成imdb類而且返回數據庫共網絡訓練和測試使用
imdb.py 這裏是數據庫讀寫類的基類,分裝了許多db的操做,可是具體的一些文件讀寫須要繼承繼續讀寫
pascal_voc.py Ross在這裏用pascal_voc.py這個類來操做git
接下來我來介紹一下pasca_voc.py這個文件,咱們主要是基於這個文件進行修改,裏面有幾個重要的函數須要修改github
在個人檢測任務裏,我主要是從道路卡口數據中檢測車,所以我這裏只有background 和car兩類物體,爲了操做方便,我不像pascal_voc數據集裏面同樣每一個圖像用一個xml來標註多類,先說一下個人數據格式數據庫
這裏我要特別提醒一下你們,必定要注意座標格式,必定要注意座標格式,必定要注意座標格式,重要的事情說三遍!!!,要否則你會範不少錯誤都會是由於座標不一致引發的報錯
windows
這裏是原始的pascal_voc的init函數,在這裏,因爲咱們本身的數據集每每比voc的數據集要更簡單的一些,在做者額代碼裏面用了不少的路徑拼接,咱們不用去迎合他的格式,將這些操做簡單化便可,在這裏我會一一列舉每一個我修改過的函數。這裏按照文件中的順序排列。
原始初始化函數:設計模式
def __init__(self, image_set, year, devkit_path=None): datasets.imdb.__init__(self, 'voc_' + year + '_' + image_set) self._year = year self._image_set = image_set self._devkit_path = self._get_default_path() if devkit_path is None \ else devkit_path self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year) self._classes = ('__background__', # always index 0 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor') self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) self._image_ext = '.jpg' self._image_index = self._load_image_set_index() # Default to roidb handler self._roidb_handler = self.selective_search_roidb # PASCAL specific config options self.config = {'cleanup' : True, 'use_salt' : True, 'top_k' : 2000} assert os.path.exists(self._devkit_path), \ 'VOCdevkit path does not exist: {}'.format(self._devkit_path) assert os.path.exists(self._data_path), \ 'Path does not exist: {}'.format(self._data_path)
修改後的初始化函數:緩存
def __init__(self, image_set, devkit_path=None): datasets.imdb.__init__(self, image_set)#imageset 爲train test self._image_set = image_set self._devkit_path = devkit_path self._data_path = os.path.join(self._devkit_path) self._classes = ('__background__','car')#包含的類 self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))#構成字典{'__background__':'0','car':'1'} self._image_index = self._load_image_set_index('ImageList_Version_S_AddData.txt')#添加文件列表 # Default to roidb handler self._roidb_handler = self.selective_search_roidb # PASCAL specific config options self.config = {'cleanup' : True, 'use_salt' : True, 'top_k' : 2000} assert os.path.exists(self._devkit_path), \ 'VOCdevkit path does not exist: {}'.format(self._devkit_path) assert os.path.exists(self._data_path), \ 'Path does not exist: {}'.format(self._data_path)
原始的image_path_from_index:網絡
def image_path_from_index(self, index): """ Construct an image path from the image's "index" identifier. """ image_path = os.path.join(self._data_path, 'JPEGImages', index + self._image_ext) assert os.path.exists(image_path), \ 'Path does not exist: {}'.format(image_path) return image_path
修改後的image_path_from_index:
def image_path_from_index(self, index):#根據_image_index獲取圖像路徑 """ Construct an image path from the image's "index" identifier. """ image_path = os.path.join(self._data_path, index) assert os.path.exists(image_path), \ 'Path does not exist: {}'.format(image_path) return image_path
原始的 _load_image_set_index:
def _load_image_set_index(self): """ Load the indexes listed in this dataset's image set file. """ # Example path to image set file: # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main', self._image_set + '.txt') assert os.path.exists(image_set_file), \ 'Path does not exist: {}'.format(image_set_file) with open(image_set_file) as f: image_index = [x.strip() for x in f.readlines()] return image_index
修改後的 _load_image_set_index:
def _load_image_set_index(self, imagelist):#已經修改 """ Load the indexes listed in this dataset's image set file. """ # Example path to image set file: # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt #/home/chenjie/KakouTrainForFRCNN_1/DataSet/KakouTrainFRCNN_ImageList.txt image_set_file = os.path.join(self._data_path, imagelist)# load ImageList that only contain ImageFileName assert os.path.exists(image_set_file), \ 'Path does not exist: {}'.format(image_set_file) with open(image_set_file) as f: image_index = [x.strip() for x in f.readlines()] return image_index
函數 _get_default_path,我直接刪除了
原始的gt_roidb:
def gt_roidb(self): """ Return the database of ground-truth regions of interest. This function loads/saves from/to a cache file to speed up future calls. """ cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl') if os.path.exists(cache_file): with open(cache_file, 'rb') as fid: roidb = cPickle.load(fid) print '{} gt roidb loaded from {}'.format(self.name, cache_file) return roidb gt_roidb = [self._load_pascal_annotation(index) for index in self.image_index] with open(cache_file, 'wb') as fid: cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL) print 'wrote gt roidb to {}'.format(cache_file) return gt_roidb
修改後的gt_roidb:
def gt_roidb(self): """ Return the database of ground-truth regions of interest. This function loads/saves from/to a cache file to speed up future calls. """ cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl') if os.path.exists(cache_file):#若存在cache file則直接從cache file中讀取 with open(cache_file, 'rb') as fid: roidb = cPickle.load(fid) print '{} gt roidb loaded from {}'.format(self.name, cache_file) return roidb gt_roidb = self._load_annotation() #已經修改,直接讀入整個GT文件 with open(cache_file, 'wb') as fid: cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL) print 'wrote gt roidb to {}'.format(cache_file) return gt_roidb
原始的selective_search_roidb(self):
def selective_search_roidb(self): """ Return the database of selective search regions of interest. Ground-truth ROIs are also included. This function loads/saves from/to a cache file to speed up future calls. """ cache_file = os.path.join(self.cache_path, self.name + '_selective_search_roidb.pkl') if os.path.exists(cache_file): with open(cache_file, 'rb') as fid: roidb = cPickle.load(fid) print '{} ss roidb loaded from {}'.format(self.name, cache_file) return roidb if int(self._year) == 2007 or self._image_set != 'test': gt_roidb = self.gt_roidb() ss_roidb = self._load_selective_search_roidb(gt_roidb) roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb) else: roidb = self._load_selective_search_roidb(None) with open(cache_file, 'wb') as fid: cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL) print 'wrote ss roidb to {}'.format(cache_file) return roidb
修改後的selective_search_roidb(self):
這裏有個pkl文件我須要特別說明一下,若是你再次訓練的時候修改了數據庫,好比添加或者刪除了一些樣本,可是你的數據庫名字函數原來那個,好比我這裏訓練的數據庫叫KakouTrain,必需要在data/cache/目錄下把數據庫的緩存文件.pkl給刪除掉,不然其不會從新讀取相應的數據庫,而是直接從以前讀入而後緩存的pkl文件中讀取進來,這樣修改的數據庫並無進入網絡,而是加載了老版本的數據。
def selective_search_roidb(self):#已經修改 """ Return the database of selective search regions of interest. Ground-truth ROIs are also included. This function loads/saves from/to a cache file to speed up future calls. """ cache_file = os.path.join(self.cache_path,self.name + '_selective_search_roidb.pkl') if os.path.exists(cache_file): #若存在cache_file則讀取相對應的.pkl文件 with open(cache_file, 'rb') as fid: roidb = cPickle.load(fid) print '{} ss roidb loaded from {}'.format(self.name, cache_file) return roidb if self._image_set !='KakouTest': gt_roidb = self.gt_roidb() ss_roidb = self._load_selective_search_roidb(gt_roidb) roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb) else: roidb = self._load_selective_search_roidb(None) with open(cache_file, 'wb') as fid: cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL) print 'wrote ss roidb to {}'.format(cache_file) return roidb
原始的_load_selective_search_roidb(self, gt_roidb):
def _load_selective_search_roidb(self, gt_roidb): filename = os.path.abspath(os.path.join(self.cache_path, '..', 'selective_search_data', self.name + '.mat')) assert os.path.exists(filename), \ 'Selective search data not found at: {}'.format(filename) raw_data = sio.loadmat(filename)['boxes'].ravel() box_list = [] for i in xrange(raw_data.shape[0]): box_list.append(raw_data[i][:, (1, 0, 3, 2)] - 1) return self.create_roidb_from_box_list(box_list, gt_roidb)
修改後的_load_selective_search_roidb(self, gt_roidb):
這裏原做者用的是Selective_search,可是我用的是EdgeBox的方法來提取Mat,我沒有修改函數名,只是把輸入的Mat文件給替換了,Edgebox實際的效果比selective_search要好,速度也要更快,具體的EdgeBox代碼你們能夠在Ross的tutorial中看到地址。
注意,這裏很是關鍵!!!!!,因爲Selective_Search中的OP返回的座標順序須要調整,並非左上右下的順序,能夠看到在下面box_list.append()中有一個(1,0,3,2)的操做,無論你用哪一種OP方法,輸入的座標都應該是x1 y1 x2 y2,不要弄成w h 那種格式,也不要調換順序。座標-1,默認座標從0開始,樓主提醒各位,必定要很是注意座標順序,大小,邊界,格式問題,不然你會被錯誤折騰死的!!!
def _load_selective_search_roidb(self, gt_roidb):#已經修改 #filename = os.path.abspath(os.path.join(self.cache_path, '..','selective_search_data',self.name + '.mat')) filename = os.path.join(self._data_path, 'EdgeBox_Version_S_AddData.mat')#這裏輸入相對應的預選框文件路徑 assert os.path.exists(filename), \ 'Selective search data not found at: {}'.format(filename) raw_data = sio.loadmat(filename)['boxes'].ravel() box_list = [] for i in xrange(raw_data.shape[0]): #box_list.append(raw_data[i][:,(1, 0, 3, 2)] - 1)#原來的Psacalvoc調換了列,我這裏box的順序是x1 ,y1,x2,y2 由EdgeBox格式爲x1,y1,w,h通過修改 box_list.append(raw_data[i][:,:] -1) return self.create_roidb_from_box_list(box_list, gt_roidb)
原始的_load_selective_search_IJCV_roidb,我沒用這個數據集,所以不修改這個函數
原始的_load_pascal_annotation(self, index):
def _load_pascal_annotation(self, index): """ Load image and bounding boxes info from XML file in the PASCAL VOC format. """ filename = os.path.join(self._data_path, 'Annotations', index + '.xml') # print 'Loading: {}'.format(filename) def get_data_from_tag(node, tag): return node.getElementsByTagName(tag)[0].childNodes[0].data with open(filename) as f: data = minidom.parseString(f.read()) objs = data.getElementsByTagName('object') num_objs = len(objs) boxes = np.zeros((num_objs, 4), dtype=np.uint16) gt_classes = np.zeros((num_objs), dtype=np.int32) overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32) # Load object bounding boxes into a data frame. for ix, obj in enumerate(objs): # Make pixel indexes 0-based x1 = float(get_data_from_tag(obj, 'xmin')) - 1 y1 = float(get_data_from_tag(obj, 'ymin')) - 1 x2 = float(get_data_from_tag(obj, 'xmax')) - 1 y2 = float(get_data_from_tag(obj, 'ymax')) - 1 cls = self._class_to_ind[ str(get_data_from_tag(obj, "name")).lower().strip()] boxes[ix, :] = [x1, y1, x2, y2] gt_classes[ix] = cls overlaps[ix, cls] = 1.0 overlaps = scipy.sparse.csr_matrix(overlaps) return {'boxes' : boxes, 'gt_classes': gt_classes, 'gt_overlaps' : overlaps, 'flipped' : False}
修改後的_load_pascal_annotation(self, index):
def _load_annotation(self): """ Load image and bounding boxes info from annotation format. """ #,此函數做用讀入GT文件,個人文件的格式 CarTrainingDataForFRCNN_1\Images\2015011100035366101A000131.jpg 1 147 65 443 361 gt_roidb = [] annotationfile = os.path.join(self._data_path, 'ImageList_Version_S_GT_AddData.txt') f = open(annotationfile) split_line = f.readline().strip().split() num = 1 while(split_line): num_objs = int(split_line[1]) boxes = np.zeros((num_objs, 4), dtype=np.uint16) gt_classes = np.zeros((num_objs), dtype=np.int32) overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32) for i in range(num_objs): x1 = float( split_line[2 + i * 4]) y1 = float (split_line[3 + i * 4]) x2 = float (split_line[4 + i * 4]) y2 = float (split_line[5 + i * 4]) cls = self._class_to_ind['car'] boxes[i,:] = [x1, y1, x2, y2] gt_classes[i] = cls overlaps[i,cls] = 1.0 overlaps = scipy.sparse.csr_matrix(overlaps) gt_roidb.append({'boxes' : boxes, 'gt_classes': gt_classes, 'gt_overlaps' : overlaps, 'flipped' : False}) split_line = f.readline().strip().split() f.close() return gt_roidb
以後的這幾個函數我都沒有修改,檢測結果,我是修改了demo.py這個文件,直接生成txt文件,而後用python opencv直接可視化,沒有用着裏面的接口,感受太麻煩了,先怎麼方便怎麼來
記得在最後的__main__下面也修改相應的路徑
d = datasets.pascal_voc('trainval', '2007')
改爲
d = datasets.kakou('KakouTrain', '/home/chenjie/KakouTrainForFRCNN_1')
而且同時在文件的開頭import 裏面也作修改
import datasets.pascal_voc
改爲
import datasets.kakou
OK,在這裏咱們已經完成了整個的讀取接口的改寫,主要是將GT和預選框Mat文件讀取並返回
當網絡訓練時會調用factory裏面的get方法得到相應的imdb,
首先在文件頭import 把pascal_voc改爲kakou
在這個文件做者生成了多個數據庫的路徑,咱們本身數據庫只要給定根路徑便可,修改主要有如下4個
原始的factory.py:
__sets = {} import datasets.pascal_voc import numpy as np def _selective_search_IJCV_top_k(split, year, top_k): """Return an imdb that uses the top k proposals from the selective search IJCV code. """ imdb = datasets.pascal_voc(split, year) imdb.roidb_handler = imdb.selective_search_IJCV_roidb imdb.config['top_k'] = top_k return imdb # Set up voc_<year>_<split> using selective search "fast" mode for year in ['2007', '2012']: for split in ['train', 'val', 'trainval', 'test']: name = 'voc_{}_{}'.format(year, split) __sets[name] = (lambda split=split, year=year: datasets.pascal_voc(split, year)) # Set up voc_<year>_<split>_top_<k> using selective search "quality" mode # but only returning the first k boxes for top_k in np.arange(1000, 11000, 1000): for year in ['2007', '2012']: for split in ['train', 'val', 'trainval', 'test']: name = 'voc_{}_{}_top_{:d}'.format(year, split, top_k) __sets[name] = (lambda split=split, year=year, top_k=top_k: _selective_search_IJCV_top_k(split, year, top_k)) def get_imdb(name): """Get an imdb (image database) by name.""" if not __sets.has_key(name): raise KeyError('Unknown dataset: {}'.format(name)) return __sets[name]() def list_imdbs(): """List all registered imdbs.""" return __sets.keys()
修改後的factory.py
#import datasets.pascal_voc import datasets.kakou import numpy as np __sets = {} imageset = 'KakouTrain' devkit = '/home/chenjie/DataSet/CarTrainingDataForFRCNN_1/Images_Version_S_AddData' #def _selective_search_IJCV_top_k(split, year, top_k): # """Return an imdb that uses the top k proposals from the selective search # IJCV code. # """ # imdb = datasets.pascal_voc(split, year) # imdb.roidb_handler = imdb.selective_search_IJCV_roidb # imdb.config['top_k'] = top_k # return imdb ### Set up voc_<year>_<split> using selective search "fast" mode ##for year in ['2007', '2012']: ## for split in ['train', 'val', 'trainval', 'test']: ## name = 'voc_{}_{}'.format(year, split) ## __sets[name] = (lambda split=split, year=year: ## datasets.pascal_voc(split, year)) # Set up voc_<year>_<split>_top_<k> using selective search "quality" mode # but only returning the first k boxes ##for top_k in np.arange(1000, 11000, 1000): ## for year in ['2007', '2012']: ## for split in ['train', 'val', 'trainval', 'test']: ## name = 'voc_{}_{}_top_{:d}'.format(year, split, top_k) ## __sets[name] = (lambda split=split, year=year, top_k=top_k: ## _selective_search_IJCV_top_k(split, year, top_k)) def get_imdb(name): """Get an imdb (image database) by name.""" __sets['KakouTrain'] = (lambda imageset = imageset, devkit = devkit: datasets.kakou(imageset,devkit)) if not __sets.has_key(name): raise KeyError('Unknown dataset: {}'.format(name)) return __sets[name]() def list_imdbs(): """List all registered imdbs.""" return __sets.keys()
在這裏終於改完了讀取接口的全部內容,主要步驟是
下面列出一些須要注意的地方
關於下部訓練和檢測網絡,我將在下一篇文章中說明