SSD源碼解讀——網絡測試

以前,對SSD的論文進行了解讀,能夠回顧以前的博客:http://www.javashuo.com/article/p-rrxantwu-cm.htmlhtml

爲了加深對SSD的理解,所以對SSD的源碼進行了復現,主要參考的github項目是ssd.pytorch。同時,我本身對該項目增長了大量註釋:https://github.com/Dengshunge/mySSD_pytorchpython

搭建SSD的項目,能夠分紅如下四個部分:git

  1. 數據讀取
  2. 網絡搭建
  3. 損失函數的構建
  4. 網絡測試

接下來,本篇博客重點分析網絡測試github


 

在eval.py文件中,首先須要搭建測試用的網絡。此時,須要將傳入的第一個參數換成"test"字符串,這是由於訓練和測試階段,網絡的輸出會有不一樣。在測試階段,會對預測框進行nms等操做。而後是常規的加載訓練模型,將網絡設置成eval模式,不更新梯度。緩存

    num_classes = len(labelmap) + 1  # +1 for background
    net = build_ssd('test', 300, num_classes)  # initialize SSD
    net.load_state_dict(torch.load(args.trained_model))
    net.eval()

咱們再來看看,測試階段中SSD網絡的不一樣。在ssd.py中,若是是test階段,在類ssd()中,會初始化函數Detect()函數。而且在類SSD()的forward函數中,將座標預測結果,通過softmax的置信度預測結果和先驗錨點框傳遞進去,進行運算。最終輸出一個tensor,shape爲[batch,num_classes,top_k,5]。其中,num_classes是類別總數,對於VOC而言,爲21;top_k表示最多取top_k個錨點框進行輸出,論文中值爲200;5表示[confidence,xmin,ymin,xmax,ymax]。網絡

        if phase == 'test':
            self.softmax = nn.Softmax(dim=-1)
            self.detect = Detect(num_classes=self.num_classes, top_k=200,
                                 conf_thresh=0.01, nms_thresh=0.45)
        if self.phase == 'train':
            output = (loc.view(loc.size(0), -1, 4),  # [batch_size,num_priors,4]
                      conf.view(conf.size(0), -1, self.num_classes),  # [batch_size,num_priors,21]
                      self.priors)  # [num_priors,4]
        else:  # Test
            output = self.detect(
                loc.view(loc.size(0), -1, 4),  # 位置預測
                self.softmax(conf.view((conf.size(0), -1, self.num_classes))),  # 置信度預測
                self.priors.cuda()  # 先驗錨點框
            )

在models/detection.py中,定義了類Detect()。首先,建立output來保存最終結果,其shape爲[batch,num_classes,top_k,5],其具體含義能夠看上面。而後對對置信度結果進行transpose轉置,這樣作的目的是方便後續計算,conf_preds的shape爲[batch,num_classes,num_priors]。接着,因爲網絡預測出來的位置預測結果,並非真正的座標,須要對其結果進行解碼,獲得真正的座標(其範圍是[0,1]之間);而後,對每一個類別進行單獨計算(不包含背景), c_mask = conf_scores[cl].gt(self.conf_thresh) 表示對某一類(如bike)獲得shape爲[1,8732]的tensor,每一個值表示預測框對該類別的置信度,經過函數gt(),獲得大於置信度閾值的掩碼(即爲c_mask,元素組成是true或者false)。經過這個mask,就能夠得到大於置信度要求的預測框(包含置信度和座標),並經過nms操做,獲得最終輸出錨點框的Index,將對應的結果(置信度和座標)保存在output中。app

class Detect(Function):
    def __init__(self, num_classes, top_k, conf_thresh, nms_thresh):
        self.num_classes = num_classes
        self.top_k = top_k
        # Parameters used in nms.
        self.nms_thresh = nms_thresh  # 非極大值抑制閾值
        self.conf_thresh = conf_thresh  # 置信度閾值

    def forward(self, loc_data, conf_data, prior_data):
        '''
        :param loc_data: 模型預測的錨點框位置誤差信息,shape[batch,num_priors,4]
        :param conf_data: 模型預測的錨點框置信度,[batch,num_priors,num_classes]
        :param prior_data: 先驗錨點框,[num_priors,4]
        :return:最終預測結果,shape[batch,num_classes,top_k,5],其中5表示[置信度,xmin,ymin,xmax,ymax],
            top_k中前面不爲0的是預測結果,後面爲0是爲了填充
        '''
        num = loc_data.shape[0]  # batch size
        num_priors = prior_data.shape[0]  # 8732
        output = torch.zeros(num, self.num_classes, self.top_k, 5)  # 保存結果
        conf_preds = conf_data.view(num, num_priors, self.num_classes).transpose(2, 1)  # 置信度預測,transpose是爲了後續操做方便

        for i in range(num):
            decoded_boxes = decode(loc_data[i], prior_data, voc['variance'])  # shape:[num_priors,4],對預測錨點框進行解碼
            # 對每一個類別,執行nms
            conf_scores = conf_preds[i].clone()  # shape:[num_classes,num_priors]
            for cl in range(1, self.num_classes):
                c_mask = conf_scores[cl].gt(self.conf_thresh)  # 和置信度閾值進行比較,大於爲true,不然爲false
                scores = conf_scores[cl][c_mask]  # 獲得置信度大於閾值的那些錨點框置信度
                if scores.shape[0] == 0:
                    # 說明錨點框與這一類的GT框不匹配,簡介說明,不存在這一類的目標
                    continue
                l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
                boxes = decoded_boxes[l_mask].view(-1, 4)  # 獲得置信度大於閾值的那些錨點框
                ids = nms(boxes, scores, self.nms_thresh, self.top_k)  # 對置信度大於閾值的那些錨點框進行nms,獲得最終預測結果的index
                output[i, cl, :len(ids)] = torch.cat((scores[ids].unsqueeze(1),
                                                      boxes[ids]), 1)  # [置信度,xmin,ymin,xmax,ymax]
        return output

解碼函數decode()在models.box_utils.py中。在訓練階段,咱們對座標進行了帶方差的編碼,所以,須要對座標進行一樣方式的解碼。ide

$$b^{cx}=d^w(var[0]*l^{cx})+d^{cx},  b^{cy}=d^h(var[1]*l^{cy})+d^{cy}$$函數

$$b^w=d^wexp(var[2]*l^w),  b^h=d^hexp(var[3]*l^h)$$測試

def decode(loc, priors, variances):
    '''
    對編碼的座標進行解碼,返回預測框的座標
    :param loc: 網絡預測的錨點框誤差信息,shape[num_priors,4]
    :param priors: 先驗錨點框,[num_priors,4]
    :return: 預測框的座標[num_priors,4],4表明[xmin,ymin,xmax,ymax]
    '''
    boxes = torch.cat((
        priors[:, :2] + loc[:, :2] * priors[:, 2:] * variances[0],
        priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)  # [中心點x,中心點y,寬,高]
    boxes[:, :2] -= boxes[:, 2:] / 2  # xmin,ymin
    boxes[:, 2:] += boxes[:, :2]  # xmax,ymax
    return boxes

在models.box_utils.py中,還存在着nms()函數。首先對置信度進行降序排序,取出置信度最大的前top_k個用於判斷,其他的錨點框,則不加入判斷中。而後,對這些預測框進行nms操做,即判斷一個錨點框與其他錨點框的IOU,只保留IOU小於閾值的錨點框,排除大於閾值的錨點框,將剩餘的錨點框再次循環,直至idx中不存在元素,即keep中保留的錨點框編號爲最終輸出的錨點框。

def nms(boxes, scores, overlap=0.5, top_k=200):
    '''
    進行nms操做
    :param boxes: 模型預測的錨點框的座標
    :param scores: 模型預測的錨點框對應某一類的置信度
    :param overlap:nms閾值
    :param top_k:選取前top_k個預測框進行操做
    :return:預測框的index
    '''
    keep = torch.zeros(scores.shape[0])
    if boxes.numel() == 0:  # numel()返回tensor裏面全部元素的個數
        return keep
    _, idx = scores.sort(0)  # 升序排序
    idx = idx[-top_k:]  # 取得最大的top_k個置信度對應的index

    keep = []  # 記錄最大最終錨點框的index
    while idx.numel() > 0:
        i = idx[-1]  # 取出置信度最大的錨點框的index
        keep.append(i)
        idx = idx[:-1]
        if idx.numel() == 0:
            break
        IOU = jaccard(boxes[i].unsqueeze(0), boxes[idx])  # 計算這個錨點框與其他錨點框的iou
        mask = IOU.le(overlap).squeeze(0)
        idx = idx[mask]  # 排除大於閾值的錨點框
    return torch.tensor(keep)

接下來,仍是回到eval.py函數中,繼續理解網絡測試代碼。接下來加載測試數據,方式與以前介紹的相似。但這裏的圖片處理方式函數BaseTransform()並不須要進行數據加強,值須要將數據進行resize和減去均值。其他的內容,幾乎一致。

    # 記載數據
    dataset = VOCDetection(args.voc_root, [('2007', 'test')], BaseTransform(300, (104, 117, 123)),
                           VOCAnnotationTransform())

接下來,循環每張測試圖片。根據每類類別,獲得第i張圖片的第j個類別的信息,包含檢測框的置信度和座標,其中,座標是真實座標(不是[0,1]之間),並將其放入變量all_boxes中。變量all_boxes是一個相似二維矩陣的變量,  all_boxes = [[[] for _ in range(num_images)] for _ in range(len(labelmap) + 1)] ,其中,列表示每一個類別,行表示每張圖片的檢測信息。這樣,就能獲得全部圖片全部類別的檢測信息了,就能夠用於下面的準確率、召回率和mAP計算了。

    for i in range(num_images):
        img, gt, h, w = dataset.pull_item(i)
        img = img.unsqueeze(0)
        if torch.cuda.is_available():
            img = img.cuda()

        detections = net(img)  # 獲得結果,shape[1,21,200,5]
        for j in range(1, detections.shape[1]):  # 循環計算每一個類別
            dets = detections[0, j, :]  # shape[200,5],表示每一個類別中最多200個錨點框,每一個錨點框有5個值[conf,xmin,ymin,xmax,ymax]
            mask = dets[:, 0].gt(0.).expand(5, dets.shape[0]).t()  # 取出置信度大於0的狀況.由於可能會出現實際有值的錨點框少於200個
            dets = torch.masked_select(dets, mask).view(-1, 5)  # 取出這些錨點框
            if dets.shape[0] == 0:
                # 說明該圖片不存在該類別
                continue
            boxes = dets[:, 1:]  # 取出錨點框座標
            # 計算出真實座標
            boxes[:, 0] *= w
            boxes[:, 1] *= h
            boxes[:, 2] *= w
            boxes[:, 3] *= h

            scores = dets[:, 0].numpy()
            # np.newaxis增長一個新軸
            # 注意[xmin,ymin,xmax,ymax,conf]
            cls_dets = np.hstack((boxes.numpy(), scores[:, np.newaxis])).astype(np.float32, copy=False)

            all_boxes[j][i] = cls_dets

利用上面獲得的all_boxes信息,進入到測試函數的關鍵部分,函數evaluate_detections()。首先,將全部檢測結果已文本的形式保存下來,方便讀取和調用。而後再執行計算評價指標的函數。

def evaluate_detections(all_boxes, dataset):
    write_voc_results_file(all_boxes, dataset)  # 將全部檢測結果寫成文本,保存下來
    do_python_eval(use_07=False)

在函數write_voc_results_file()中,實現的功能就是根據某一類別和測試圖片的index,讀取變量all_boxes中的檢測信息,將其按照[圖片名,置信度,xmin,ymin,xmax,ymax]的形式,寫入文本中。如VOC,咱們會獲得20個文本文件,不一樣文本表示不一樣的類別;同一文本下,包含了全部測試圖片對該類別的檢測結果。

def write_voc_results_file(all_boxes, dataset):
    # 將檢測結果按照每類寫成文本,方便後面讀取結果
    for cls_ind, cls in enumerate(labelmap):
        print('Writing {:s} VOC results file'.format(cls))
        if not os.path.exists(args.save_det_result):
            os.mkdir(args.save_det_result)
        filename = os.path.join(args.save_det_result, 'det_%s.txt' % (cls))
        with open(filename, 'w') as f:
            for im_ind, index in enumerate(dataset.ids):  # dataset.ids:[path,圖片名]
                dets = all_boxes[cls_ind + 1][im_ind]  # 測試的時候,圖片是按這個順序讀取的
                if dets == []:
                    continue
                for k in range(dets.shape[0]):
                    f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                            format(index[1], dets[k, -1],
                                   dets[k, 0] + 1, dets[k, 1] + 1,
                                   dets[k, 2] + 1, dets[k, 3] + 1))

當將信息保存完後,就進入到do_python_eval()函數中。參數use_07表示(true)使用2007的11點計算mAP方式仍是(false)2010年的mAP計算方式。首先按照每一個類別,讀取上述保存檢測結果的文件。進入到關鍵函數voc_eval()中,獲得召回率、準確率和AP值。

def do_python_eval(use_07=True):
    aps = []  # 保存全部類別的AP
    for i, cls in enumerate(labelmap):
        filename = os.path.join(args.save_det_result, 'det_%s.txt' % (cls))  # 讀取這一類別的檢測結果,對應上剛剛保存的結果
        rec, prec, ap = voc_eval(filename,
                                 os.path.join(args.voc_root, 'VOC2007', 'ImageSets', 'Main', 'test.txt'),
                                 cls,
                                 args.cachedir,
                                 ovthresh=0.5,
                                 use_07_metric=use_07)
        aps += [ap]
        print('AP for {} = {:.4f}'.format(cls, ap))

在函數voc_eval()中,首先,會讀取全部測試圖片相關的xml文件,讀取的方式與一開始介紹的數據讀取相似,將全部信息保存在字典recs中,其中,key爲圖片名字,value是xml文件信息。並將該resc信息保存下來,方便之後繼續讀取。由於這裏面都是真實信息,變更相對較少。而後根據某一類別,在字典recs中讀取每一個圖片,將該類別的信息提取出來,構成字典class_recs,其中key爲某一類下的圖片名稱,value爲GT框座標、是否難例和是否已經檢測過。上面是處理真實信息,接下來,處理預測信息。讀取某一類的預測結果文件,該文件在函數write_voc_results_file()中造成的。而後對文件內容進行分割,獲得文件名,置信度,預測框等集合,並根據置信度,對3個集合進行降序排列。按順序讀取每一個預測框,計算該預測框與這張圖全部GT框的IOU。當IOU大於閾值且該GT框沒有匹配過期,tp的相應位置置1,不然fp的相應位置置0。由此能夠,tp和fp互斥。上述的tp和fp並非true positive和false positive,須要進行行累加,併除以預測框總數或者GT框總數,才能獲得召回率和準確率。以後經過計算,獲得給類別的AP值。

def voc_eval(detpath,  # 某一類別下檢測結果,每一行由文件名,置信度和檢測座標組成
             imagesetfile,  # 包含全部測試圖片的文件
             classname,  # 須要檢測的類別
             cachedir,  # 緩存GT框的pickle文件
             ovthresh=0.5,  # IOU閾值
             use_07_metric=True):
    '''
    假設檢測結果在detpath.format(classname)下
    假設GT框座標在annopath.format(imagename)
    假設imagesetfile每行僅包含一個文件名
    緩存全部GT框
    '''
    if not os.path.isdir(cachedir):
        os.mkdir(cachedir)
    cachefile = os.path.join(cachedir, 'annots.pkl')
    # 讀取全部檢測圖片
    with open(imagesetfile, 'r') as f:
        lines = f.readlines()
    imagenames = [x.strip() for x in lines]  # 每張測試圖片的名字

    # 下面代碼是建立緩存文件,方便讀取
    if not os.path.isfile(cachefile):
        # 不存在GT框緩存文件,則建立
        recs = {}  # key爲圖片名字,value爲該圖片下全部檢測信息
        for i, imagename in enumerate(imagenames):
            recs[imagename] = parse_rec(
                os.path.join(args.voc_root, 'VOC2007', 'Annotations',
                             '%s.xml') % (imagename))  # 返回該圖片下全部xml信息,包含全部目標
            if i % 100 == 0:
                print('Reading annotation for {:d}/{:d}'.format(
                    i + 1, len(imagenames)))
        # 保存下來,方便下次讀取
        print('Saving cached annotations to {:s}'.format(cachefile))
        with open(cachefile, 'wb') as f:
            pickle.dump(recs, f)
    else:
        # 若是已經存在該文件,則加載回來便可
        with open(cachefile, 'rb') as f:
            recs = pickle.load(f)

    # 爲這一類提取GT框
    class_recs = {}
    npos = 0  # 這一類別的gt框總數
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj['name'] == classname]  # 提取某張測試圖片下該類別的信息
        bbox = np.array([x['bbox'] for x in R])  # GT框座標
        difficult = np.array([x['difficult'] for x in R]).astype(np.bool)  # 元素爲true或者false,true表示難例
        det = [False] * len(R)  # 長度爲len(R)的list,用於表示該GT框是否已經匹配過,len(R)能夠理解爲該測試圖片下,該類別的數量
        npos = npos + sum(~difficult)  # 只選取非難例,計算非難例的個數,能夠理解爲GT框的個數
        class_recs[imagename] = {'bbox': bbox,
                                 'difficult': difficult,
                                 'det': det}

    # 讀取這一類的檢測結果
    with open(detpath, 'r') as f:
        lines = f.readlines()

    if any(lines) == 1:  # 不爲空
        splitlines = [x.strip().split(' ') for x in lines]
        image_ids = [x[0] for x in splitlines]  # 圖片名稱集合,包含重複的
        confidence = np.array([float(x[1]) for x in splitlines])  # 置信度集合
        BB = np.array([[float(z) for z in x[2:]] for x in splitlines])  # 檢測框集合

        # 根據置信度,降序排列
        sorted_ind = np.argsort(-confidence)  # 降序排名
        sorted_scores = np.sort(-confidence)  # 降序排列
        BB = BB[sorted_ind, :]  # 檢測框根據置信度進行降序排列
        image_ids = [image_ids[x] for x in sorted_ind]

        nd = len(image_ids)  # 檢測的目標總數
        tp = np.zeros(nd)  # 記錄tp
        fp = np.zeros(nd)  # 記錄fp,與tp互斥

        for d in range(nd):  # 循環每一個預測框
            R = class_recs[image_ids[d]]  # 該圖片下的真實信息
            bb = BB[d, :].astype(float)  # 預測框的座標
            ovmax = -np.inf  # 預測框與GT框的IOU
            BBGT = R['bbox'].astype(float)  # GT框的座標

            if BBGT.size > 0:
                # 計算多個GT框與一個預測框的IOU,選擇最大IOU
                # 下面是計算IOU的流程
                ixmin = np.maximum(BBGT[:, 0], bb[0])
                iymin = np.maximum(BBGT[:, 1], bb[1])
                ixmax = np.minimum(BBGT[:, 2], bb[2])
                iymax = np.minimum(BBGT[:, 3], bb[3])
                iw = np.maximum(ixmax - ixmin, 0.)
                ih = np.maximum(iymax - iymin, 0.)
                inters = iw * ih
                uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) +
                       (BBGT[:, 2] - BBGT[:, 0]) *
                       (BBGT[:, 3] - BBGT[:, 1]) - inters)
                overlaps = inters / uni
                ovmax = np.max(overlaps)  # 獲得該預測框與GT框最大的IOU值
                jmax = np.argmax(overlaps)  # 獲得該預測框對應最大IOU的GT框的index

            if ovmax > ovthresh:
                # 當IOU大於閾值,纔有機會判斷爲正例
                # 判斷爲fp有兩種狀況:
                # 1.該GT框被置信度高的預測框匹配過
                # 2.IOU小於閾值
                if not R['difficult'][jmax]:
                    # 該GT框要求以前沒有匹配過
                    # 因爲置信度是降序排序的,GT框只匹配置信度最高的,其他認爲是FP
                    tp[d] = 1.
                    R['det'][jmax] = 1
                else:
                    fp[d] = 1.
            else:
                fp[d] = 1

        # 計算recall,precision
        fp = np.cumsum(fp)  # shape:[1,nd]
        tp = np.cumsum(tp)  # shape:[1,nd]
        rec = tp / float(npos)  # 召回率
        prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)  # 準確率,防止除0
        ap = voc_ap(rec, prec, use_07_metric)
    else:
        rec = -1.
        prec = -1.
        ap = -1.

    return rec, prec, ap

根據召回率,準確率,就能夠計算該類別的AP值了。計算AP值有兩種方法,第一種是2007年的11點計算,該方法給召回率設定11個閾值,如 np.arange(0., 1.1, 0.1) ,計算大於閾值狀況下的最大準確率,得到11個準確率後,求平均,就獲得了AP值;第二種方法是2010年提出的,首先將PR曲線進行平滑,第i-i個點去第i-1個和第i個點的最大值,將PR曲邊變成了遞減曲線,而後計算該遞減曲線下的面積,獲得AP值。

def voc_ap(rec, prec, use_07_metric=True):
    '''
    根據召回率和準確率,計算AP
    AP計算有兩種方式:1.11點計算;2.最大面積計算
    :param rec: [1,num_all_detect]
    :param prec: [1,num_all_detect]
    '''
    if use_07_metric:
        # 舊版,11點計算
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            # 給召回率設定閾值,統計當召回率大於閾值的狀況下,最大的準確率
            if np.sum(rec >= t) == 0:
                # 說明召回率沒有比t更大
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11
    else:
        # 增長兩個數字,是爲了方便計算
        mrec = np.concatenate(([0.], rec, [1.]))  # shape:[1,num_all_detect+2]
        mpre = np.concatenate(([0.], prec, [0.]))

        # 計算最大面積
        for i in range(mpre.size - 1, 0, -1):
            # 取右邊的最大值
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # 獲得與前面的數值不同的index,能夠理解成,計算面積時的邊長
        i = np.where(mrec[1:] != mrec[:-1])[0]
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # 計算面積,長*寬
    return ap

至此,SSD的網絡檢測代碼已經解讀完成。

相關文章
相關標籤/搜索