SSD源碼解讀——損失函數的構建

以前,對SSD的論文進行了解讀,能夠回顧以前的博客:http://www.javashuo.com/article/p-rrxantwu-cm.htmlhtml

爲了加深對SSD的理解,所以對SSD的源碼進行了復現,主要參考的github項目是ssd.pytorch。同時,我本身對該項目增長了大量註釋:https://github.com/Dengshunge/mySSD_pytorchgit

搭建SSD的項目,能夠分紅如下三個部分:github

  1. 數據讀取
  2. 網絡搭建
  3. 損失函數的構建;
  4. 網絡測試

接下來,本篇博客重點分析損失函數的構建網絡


 檢測任務的損失函數,與分類任務的損失函數具備很大不一樣。在檢測的損失函數中,不只須要計類別置信度的差別,座標的差別,還須要使用到各類tricks,例如hard negative mining等。函數

在train.py中,首先須要對損失函數MultiBoxLoss()進行初始化,須要傳入的參數爲num_classes類別數,正例的IOU閾值和hard negative mining的正負樣本比例。在論文中,VOC的類別總數是21(20個類別加上1個背景);當預測框與GT框的IOU大於0.5時,認爲該預測框是正例;hard negative mining的正樣本和負樣本的比例是1:3。學習

    # 損失函數
    criterion = MultiBoxLoss(num_classes=voc['num_classes'],
                             overlap_thresh=0.5,
                             neg_pos=3)

在models/multibox_loss中,定義了損失函數MultiBoxLoss()。在函數forward()中,須要傳進來兩個參數,分別是predictions和targets,其中,predictions是SSD網絡獲得的結果,分別是預測框座標,類別置信度和先驗錨點框;而targets是則是數據讀取中的值,是GT框的座標和類別label。首先,須要建立座標loc_t和類別置信度conf_t的tensor,其shape分別是[batch_size,8732,4]和[batch_size,8732]。而後,使用一個for循環,將GT框與先驗錨點框的座標與label進行match,獲得每一個錨點框的label和座標誤差,並將結果保存與loc_t和conf_t中。因爲制定了某些錨點框用於預測目標,所以,接下來,須要使用這部分錨點框信息來計算損失。取出含目標的錨點框,獲得其index,其中,pos的shape爲[batch_size,8732],每一個元素是true或者false。再從網絡預測的8732個預測框中,取出一樣index的預測框的座標誤差loc_p,而loc_t則是一樣index的先驗錨點框的座標誤差。因爲錨點框對應上了,則使用smooth_l1來計算預測框迴歸的算是loss_l,以下圖所示的$L_{loc}$,圖片來源測試

 接下來,則是使用hard negative mining和計算置信度損失。首先爲模型預測出來的置信度conf_data進行維度變換,由[batch_size,8732,21]變成[batch_size*8732,21]的batch_conf,應該是爲了方便下面進行計算。接下來,計算全部預測框的置信度損失loss_c,將含目標的錨點框(正例)的損失置0,並對損失進行排名,從而選出損失最大的前num_neg個損失的index。將正例的pos_index和損失最大的負例neg_idx提取出來成conf_p,用於參與訓練中,與相同index的先驗錨點框進行計算交叉熵損失計算。最後將置信度損失和位置損失返回。ui

class MultiBoxLoss(nn.Module):
    '''
    SSD損失函數的計算
    '''

    def __init__(self, num_classes, overlap_thresh, neg_pos):
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes  # 類別數
        self.threshold = overlap_thresh  # GT框與先驗錨點框的閾值
        self.negpos_ratio = neg_pos  # 負例的比例

    def forward(self, predictions, targets):
        '''
        對損失函數進行計算:
            1.進行GT框與先驗錨點框的匹配,獲得loc_t和conf_t,分別表示錨點框須要匹配的座標和錨點框須要匹配的label
            2.對包含目標的先驗錨點框loc_t(即正例)與預測的loc_data計算位置損失函數
            3.對負例(即背景)進行損失計算,選擇損失最大的num_neg個負例和正例共同組成訓練樣本,取出這些訓練樣本的錨點框targets_weighted
                與置信度預測值conf_p,計算置信度損失:
                a)爲Hard Negative Mining計算最大置信度loss_c
                b)將loss_c中正例對應的值置0,即保留了全部負例
                c)對此loss_c進行排序,獲得損失最大的idx_rank
                d)計算用於訓練的負例的個數num_neg,約爲正例的3倍
                e)選擇idx_rank中前num_neg個用做訓練
                f)將正例的index和負例的index共同組成用於計算損失的index,並從預測置信度conf_data和真實置信度conf_t提出這些樣本,造成
                    conf_p和targets_weighted,計算二者的置信度損失.
        :param predictions: 一個元祖,包含位置預測,置信度預測,先驗錨點框
                    位置預測:(batch_size,num_priors,4),即[batch_size,8732,4]
                    置信度預測:(batch_size,num_priors,num_classes),即[batch_size, 8732, 21]
                    先驗錨點框:(num_priors,4),即[8732, 4]
        :param targets: 真實框的座標與label,[batch_size,num_objs,5]
                    其中,5表明[xmin,ymin,xmia,ymax,label]
        '''
        loc_data, conf_data, priors = predictions
        num = loc_data.shape[0]  # 即batch_size大小
        priors = priors[:loc_data.shape[1], :]  # 取出8732個錨點框,與位置預測的錨點框數量相同
        num_priors = priors.shape[0]  # 8732

        loc_t = torch.Tensor(num, num_priors, 4)  # [batch_size,8732,4],生成隨機tensor,後續用於填充
        conf_t = torch.Tensor(num, num_priors)  # [batch_size,8732]
        # 取消梯度更新,貌似默認是False
        loc_t.requires_grad = False
        conf_t.requires_grad = False

        for idx in range(num):
            truths = targets[idx][:, :-1]  # 座標值,[xmin,ymin,xmia,ymax]
            labels = targets[idx][:, -1]  # label
            defaults = priors.cuda()
            match(self.threshold, truths, defaults, labels, loc_t, conf_t, idx)
        if torch.cuda.is_available():
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()  # shape:[batch_size,8732],其元素組成是類別標籤號和背景

        pos = conf_t > 0  # 排除label=0,即排除背景,shape[batch_size,8732],其元素組成是true或者false
        # Localization Loss (Smooth L1),定位損失函數
        # Shape: [batch,num_priors,4]
        # pos.dim()表示pos有多少維,應該是一個定值(2)
        # pos由[batch_size,8732]變成[batch_size,8732,1],而後展開成[batch_size,8732,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)  # [num_pos,4],取出帶目標的這些框
        loc_t = loc_t[pos_idx].view(-1, 4)  # [num_pos,4]
        # 位置損失函數
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')  # 這裏對損失值是相加,有公式可知,還沒到相除的地步

        # 爲Hard Negative Mining計算max conf across batch
        batch_conf = conf_data.view(-1, self.num_classes)  # shape[batch_size*8732,21]
        # gather函數的做用是沿着定軸dim(1),按照Index(conf_t.view(-1, 1))取出元素
        # batch_conf.gather(1, conf_t.view(-1, 1))的shape[8732,1],做用是獲得每一個錨點框在匹配GT框後的label
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1).long())  # 這個不是最終的置信度損失函數

        # Hard Negative Mining
        # 因爲正例與負例的數據不均衡,所以不是全部負例都用於訓練
        loss_c[pos.view(-1, 1)] = 0  # pos與loss_c維度不同,因此須要轉換一下,選出負例
        loss_c = loss_c.view(num, -1)  # [batch_size,8732]
        _, loss_idx = loss_c.sort(1, descending=True)  # 獲得降序排列的index
        _, idx_rank = loss_idx.sort(1)

        num_pos = pos.sum(1, keepdim=True)  # pos裏面是true或者false,所以sum後的結果應該是包含的目標數量
        num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)  # 生成一個隨機數用於表示負例的數量,正例和負例的比例約3:1
        neg = idx_rank < num_neg.expand_as(idx_rank)  # [batch_size,8732] 選擇num_neg個負例,其元素組成是true或者false

        # 置信度損失,包括正例和負例
        # [batch_size, 8732, 21],元素組成是true或者false,但true表明着存在目標,其對應的index爲label
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        # pos_idx由true和false組成,表示選擇出來的正例,neg_idx同理
        # (pos_idx + neg_idx)表示選擇出來用於訓練的樣例,包含正例和反例
        # torch.gt(other)函數的做用是逐個元素與other進行大小比較,大於則爲true,不然爲false
        # 所以conf_data[(pos_idx + neg_idx).gt(0)]獲得了全部用於訓練的樣例
        conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
        targets_weighted = conf_t[(pos + neg).gt(0)]
        loss_c = F.cross_entropy(conf_p, targets_weighted.long(), reduction='sum')

        # L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        N = num_pos.sum()  # 一個batch裏面全部正例的數量
        loss_l /= N
        loss_c /= N
        return loss_l, loss_c

在hard negative mining中,須要先計算loss_c。從代碼能夠看到  loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1).long()) ,這句代碼就是置信度損失的計算,能夠參考公式進行理解。這裏能夠說起一下,對loss_c的兩次排序,參考這篇博客,首先對值進行降序排序,獲得排名1,而後對排名又進行降序排序,獲得排名2,以下圖所示,即能取出idx_rank的前N個,可得到損失最大那些值,即變量neg的做用。編碼

在計算損失函數時,說起了函數match(),這個函數位於models/box_utils.py中,是一個很是關鍵的函數,對應論文的匹配策略那一章節,其做用是爲每一個錨點框指定GT框和爲每一個GT框指定錨點框。須要傳進來幾個參數,truths是GT框的座標,priors是先驗錨點框的座標[中心點x,中心點y,W,H],labels是GT框對應的類別(不包含背景),loc_t和conf_t是用來保存結果的,idx是第i張圖片。spa

爲了方便表述,num_objects表示一張圖中,GT框的數量;num_priors表示先驗錨點框的數量,即8732。

第一步,因爲先驗錨點框priors的座標形式是[中心點x,中心點y,W,H],須要使用函數point_from()來將其轉化成[x_min,y_min,x_max,y_max]。而後計算每一個GT框與全部先驗錨點框的jaccard值,即IOU的值,使用了numpy風格的計算方式,返回的變量overlaps的shape爲[GT框數量,8732]。

第二步,根據論文,爲每一個GT框匹配一個最大IOU的先驗錨點框,確保每一個GT框至少有一個錨點框進行預測。

第三步,爲每一個錨點框匹配上一個最大IOU的GT框來進行預測。

第四步,變量best_truth_overlap保存着每一個框與GT框的最大IOU值(第三步的結果),使用index_fill()函數,將第二步的結果同步到這個變量中。在index_fill()函數中,使用數值2來進行填充,是爲了確保第二步中獲得的錨點框確定會被選到。對變量best_truth_idx也進行一樣的處理。

第五步,因爲傳入進來的labels的類別是從0開始的,SSD中認爲0應該是背景,因此,須要對labels進行加一。這裏須要注意一下,best_truth_idx的shape是[8732],每一個元素的範圍爲[0,num_objects],因此conf的shape爲[num_priors],每一個元素表示先驗錨點框的label(0是背景)。同時,須要將變量best_truth_overlap中IOU小於閾值(0.5)的錨點框的label設置爲0。並將結果保存與conf_t,返回給外面的函數用於計算。

第六步,一樣須要將GT框的座標進行擴展,造成shape爲[num_priors,4]的matches,這樣每一個錨點框都有對應的座標進行預測,但最終並非每一個錨點框都用於訓練中。

第七步,使用GT框與錨點框進行編碼,對應論文中的公式2,獲得shape爲[num_priors,4]的值,即誤差,將此結果返回出去。

注意,這裏使用的是GT框的信息和先驗錨點框的信息,並無涉及到網絡預測出來的結果。獲得每一個錨點框的類別conf_t和座標loc_t。因爲沒有用到網絡預測的結果,能夠認爲這部分一直都是定值。

def match(threshold, truths, priors, labels, loc_t, conf_t, idx):
    '''
    這個函數對應論文中的matching strategy匹配策略.SSD須要爲每個先驗錨點框都指定一個label,
    這個label或者指向背景,或者指向每一個類別.
    論文中的匹配策略是:
        1.首先,每一個GT框選擇與其IOU最大的一個錨點框,並令這個錨點框的label等於這個GT框的label
        2.而後,當錨點框與GT框的IOU大於閾值(0.5)時,一樣令這個錨點框的label等於這個GT框的label
    所以,代碼上的邏輯爲:
        1.計算每一個GT框與每一個錨點框的IOU,獲得一個shape爲[num_object,num_priors]的矩陣overlaps
        2.選擇與GT框的IOU最大的錨點框,錨點框的index爲best_prior_idx,對應的IOU值爲best_prior_overlap
        3.爲每個錨點框選擇一個IOU最大的GT框,可能會出現多個錨點框匹配一個GT框的狀況,此時,每一個錨點框對應GT框的index爲best_truth_idx,
            對應的IOU爲best_truth_overlap.注意,此時IOU值可能會存在小於閾值的狀況.
        4.第3步可能到致使存在GT框沒有與錨點框匹配上的狀況,因此要和第2步進行結合.在第3步的基礎上,對best_truth_overlap進行選擇,選擇出
            best_prior_idx這些錨點框,讓其對其的IOU等於一個大於1的定值;而且讓best_truth_idx中index爲best_prior_idx的錨點框的label
            與GT框對應上.最終,best_truth_overlap表示每一個錨點框與GT框的最大IOU值,而best_truth_idx表示每一個錨點框用於與相應的GT框進行
            匹配.
        5.第4步中,會存在IOU小於閾值的狀況,要將這些小於IOU閾值的錨點框的label指向背景,完成第二條匹配策略.
            labels表示GT框對應的標籤號,"conf=labels[best_truth_idx]+1"獲得每一個錨點框對應的標籤號,其中label=0是背景.
            "conf[best_truth_overlap < threshold] = 0"則將小於IOU閾值的錨點框的label指向背景
        6.獲得的conf表示每一個錨點框對應的label,還須要一個矩陣,來表示每一個錨點框須要匹配GT框的座標.
            truths表示GT框的座標,"matches = truths[best_truth_idx]"獲得每一個錨點框須要匹配GT框的座標.
    :param threshold:IOU的閾值
    :param truths:GT框的座標,shape:[num_obj,4]
    :param priors:先驗錨點框的座標,shape:[num_priors,4],num_priors=8732
    :param labels:這些GT框對應的label,shape:[num_obj],此時label=0還不是背景
    :param loc_t:座標結果會保存在這個tensor
    :param conf_t:置信度結果會保存在這個tensor
    :param idx:結果保存的idx
    '''
    # 第1步,計算IOU
    overlaps = jaccard(truths, point_from(priors))  # shape:[num_object,num_priors]

    # 第2步,爲每一個真實框匹配一個IOU最大的錨點框,GT框->錨點框
    # best_prior_overlap爲每一個真實框的最大IOU值,shape[num_objects,1]
    # best_prior_idx爲對應的最大IOU的先驗錨點框的Index,其元素值的範圍爲[0,num_priors]
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)

    # 第3步,若先驗錨點框與GT框的IOU>閾值,也將這些錨點框匹配上,錨點框->GT框
    # best_truth_overlap爲每一個先驗錨點框對應其中一個真實框的最大IOU,shape[1,num_priors]
    # best_truth_idx爲每一個先驗錨點框對應的真實框的index,其元素值的範圍爲[0,num_objects]
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)

    best_prior_idx.squeeze_(1)  # [num_objects]
    best_prior_overlap.squeeze_(1)  # [num_objects]
    best_truth_idx.squeeze_(0)  # [num_priors],8732
    best_truth_overlap.squeeze_(0)  # [num_priors],8732

    # 第4步
    # index_fill_(self, dim: _int, index: Tensor, value: Number)對第dim行的index使用value進行填充
    # best_truth_overlap爲第一步匹配的結果,須要使用到,使用best_prior_idx是第二步的結果,也是須要使用上的
    # 因此在best_truth_overlap上進行填充,代表選出來的正例
    # 使用2進行填充,是由於,IOU值的範圍是[0,1],只要使用大於1的值填充,就代表確定能被選出來
    best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # 肯定最佳先驗錨點框
    # 確保每一個GT框都能匹配上最大IOU的先驗錨點框
    # 獲得每一個先驗錨點框都能有一個匹配上的數字
    # best_prior_idx的元素值的範圍是[0,num_priors],長度爲num_objects
    for j in range(best_prior_idx.size(0)):
        best_truth_idx[best_prior_idx[j]] = j

    # 第5步
    conf = labels[best_truth_idx] + 1  # Shape: [num_priors],0爲背景,因此其他編號+1
    conf[best_truth_overlap < threshold] = 0  # 置信度小於閾值的label設置爲0

    # 第6步
    matches = truths[best_truth_idx]  # 取出最佳匹配的GT框,Shape: [num_priors,4]

    # 進行位置編碼
    loc = encode(matches, priors,voc['variance'])
    loc_t[idx] = loc  # [num_priors,4],應該學習的編碼誤差
    conf_t[idx] = conf  # [num_priors],每一個錨點框的label

 在函數match()中,使用到了函數encode()來對位置進行編碼。參考博客和R-CNN中的公式,假設先驗錨點框的座標爲$(d^{cx},d^{cy},d^w,d^h)$,預測框的座標爲$(b^{cx},b^{cy},b^w,b^h)$,則預測框的轉換值l爲:

$$l^{cx}=(b^{cx}-d^{cx})/d^w,  l^{cy}=(b^{cy}-d^{cy})/d^h$$

$$b^w=d^wexp(l^x),  b^h=d^hexp(l^h)$$

 

 而代碼中,咱們利用了方差的信息,所以進行了相應的調整,總體上是一致的。

def encode(matched, priors, variances):
    '''
    對座標進行編碼,對應論文中的公式2
    利用GT框和先驗錨點框,計算誤差,用於迴歸
    :param matched: 每一個先驗錨點框對應最佳的GT框,Shape: [num_priors, 4],
                    其中4表明[xmin,ymin,xmax,ymax]
    :param priors: 先驗錨點框,Shape: [num_priors,4],
                    其中4表明[中心點x,中心點y,寬,高]
    :return: shape:[num_priors, 4]
    '''
    g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]  # 計算GT框與錨點框中心點的距離
    g_cxcy /= (variances[0] * priors[:, 2:])

    g_wh = (matched[:, 2:] - matched[:, :2])  # xmax-xmin,ymax-ymin
    g_wh /= priors[:, 2:]
    g_wh = torch.log(g_wh) / variances[1]

    return torch.cat([g_cxcy, g_wh], 1)

 


 

至此,SSD的損失函數構建以介紹完成。相比於分類任務,目標檢測的損失函數構建須要更多的代碼,包含了各類tricks。 

相關文章
相關標籤/搜索