CenterNet之loss計算代碼解析

[GiantPandaCV導語] 本文主要講解CenterNet的loss,由偏置部分(reg loss)、熱圖部分(heatmap loss)、寬高(wh loss)部分三部分loss組成,附代碼實現。python

1. 網絡輸出

論文中提供了三個用於目標檢測的網絡,都是基於編碼解碼的結構構建的。網絡

  1. ResNet18 + upsample + deformable convolution : COCO AP 28%/142FPS
  2. DLA34 + upsample + deformable convolution : COCO AP 37.4%/52FPS
  3. Hourglass104: COCO AP 45.1%/1.4FPS

這三個網絡中輸出內容都是同樣的,80個類別,2個預測中心對應的長和寬,2箇中心點的誤差。ide

# heatmap 輸出的tensor的通道個數是80,每一個通道表明對應類別的heatmap
(hm): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1))
)
# wh 輸出是中心對應的長和寬,通道數爲2
(wh): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
# reg 輸出的tensor通道個數爲2,分別是w,h方向上的偏移量
(reg): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)

2. 損失函數

2.1 heatmap loss

輸入圖像\(I\in R^{W\times H\times 3}\), W爲圖像寬度,H爲圖像高度。網絡輸出的關鍵點熱圖heatmap爲\(\hat{Y}\in [0,1]^{\frac{W}{R}\times \frac{H}{R}\times C}\)其中,R表明獲得輸出相對於原圖的步長stride。C表明類別個數。函數

下面是CenterNet中核心loss公式:編碼

\[L_k=\frac{-1}{N}\sum_{xyc}\begin{cases} (1-\hat{Y}_{xyc})^\alpha log(\hat{Y}_{xyc})& Y_{xyc}=1\\ (1-Y_{xyc})^\beta(\hat{Y}_{xyc})^\alpha log(1-\hat{Y}_{xyc})& otherwise \end{cases} \]

這個和Focal loss形式很類似,\(\alpha\)\(\beta\)是超參數,N表明的是圖像關鍵點個數。spa

  • \(Y_{xyc}=1\)的時候,

對於易分樣原本說,預測值\(\hat{Y}_{xyc}\)接近於1,\((1-\hat{Y}_{xyc})^\alpha\)就是一個很小的值,這樣loss就很小,起到了矯正做用。code

對於難分樣原本說,預測值\(\hat{Y}_{xyc}\)接近於0,$ (1-\hat{Y}_{xyc})^\alpha$就比較大,至關於加大了其訓練的比重。orm

  • otherwise的狀況下:

otherwise分爲兩個狀況A和B

上圖是一個簡單的示意,縱座標是\({Y}_{xyc}\),分爲A區(距離中心點較近,可是值在0-1之間)和B區(距離中心點很遠接近於0)。blog

對於A區來講,因爲其周圍是一個高斯核生成的中心,\(Y_{xyc}\)的值是從1慢慢變到0。ip

舉個例子(CenterNet中默認\(\alpha=2,\beta=4\)):

\(Y_{xyc}=0.8\)的狀況下,

  • 若是\(\hat{Y}_{xyc}=0.99\),那麼loss=\((1-0.8)^4(0.99)^2log(1-0.99)\),這就是一個很大的loss值。

  • 若是\(\hat{Y}_{xyc}=0.8\), 那麼loss=\((1-0.8)^4(0.8)^2log(1-0.8)\), 這個loss就比較小。

  • 若是\(\hat{Y}_{xyc}=0.5\), 那麼loss=\((1-0.8)^4(0.5)^2log(1-0.5)\),

  • 若是\(\hat{Y}_{xyc}=0.99\),那麼loss=\((1-0.5)^4(0.99)^2log(1-0.99)\),這就是一個很大的loss值。

  • 若是\(\hat{Y}_{xyc}=0.8\), 那麼loss=\((1-0.5)^4(0.8)^2log(1-0.8)\), 這個loss就比較小。

  • 若是\(\hat{Y}_{xyc}=0.5\), 那麼loss=\((1-0.5)^4(0.5)^2log(1-0.5)\),

總結一下:爲了防止預測值\(\hat{Y}_{xyc}\)太高接近於1,因此用\((\hat{Y}_{xyc})^\alpha\)來懲罰Loss。而\((1-Y_{xyc})^\beta\)這個參數距離中心越近,其值越小,這個權重是用來減輕懲罰力度。

對於B區來講\(\hat{Y}_{xyc}\)的預測值理應是0,若是該值比較大好比爲1,那麼\((\hat{Y}_{xyc})^\alpha\)做爲權重會變大,懲罰力度也加大了。若是預測值接近於0,那麼\((\hat{Y}_{xyc})^\alpha\)會很小,讓其損失比重減少。對於\((1-Y_{xyc})^\beta\)來講,B區的值比較大,弱化了中心點周圍其餘負樣本的損失比重。

2.2 offset loss

因爲三個骨幹網絡輸出的feature map的空間分辨率變爲原來輸入圖像的四分之一。至關於輸出feature map上一個像素點對應原始圖像的4x4的區域,這會帶來較大的偏差,所以引入了偏置值和偏置的損失值。設骨幹網絡輸出的偏置值爲\(\hat{O}\in R^{\frac{W}{R}\times \frac{H}{R}\times 2}\), 這個偏置值用L1 loss來訓練:

\[L_{offset}=\frac{1}{N}\sum_{p}|\hat{O}_{\tilde{p}}-(\frac{p}{R}-\tilde{p})| \]

p表明目標框中心點,R表明下采樣倍數4,\(\tilde{p}=\lfloor \frac{p}{R} \rfloor\), \(\frac{p}{R}-\tilde{p}\)表明誤差值。

2.3 size loss/wh loss

假設第k個目標,類別爲\(c_k\)的目標框的表示爲\((x_1^{(k)},y_1^{(k)},x_2^{(k)},y_2^{(k)})\),那麼其中心點座標位置爲\((\frac{x_1^{(k)}+x_2^{(k)}}{2}, \frac{y_1^{(k)}+y_2^{(k)}}{2})\), 目標的長和寬大小爲\(s_k=(x_2^{(k)}-x_1^{(k)},y_2^{(k)}-y_1^{(k)})\)。對長和寬進行訓練的是L1 Loss函數:

\[L_{size}=\frac{1}{N}\sum^{N}_{k=1}|\hat{S}_{pk}-s_k| \]

其中\(\hat{S}\in R^{\frac{W}{R}\times \frac{H}{R}\times 2}\)是網絡輸出的結果。

2.4 CenterNet Loss

總體的損失函數是以上三者的綜合,而且分配了不一樣的權重。

\[L_{det}=L_k+\lambda_{size}L_{size}+\lambda_{offset}L_{offset} \]

其中\(\lambda_{size}=0.1, \lambda_{offsize}=1\)

3. 代碼解析

來自train.py中第173行開始進行loss計算:

# 獲得heat map, reg, wh 三個變量
hmap, regs, w_h_ = zip(*outputs)

regs = [
_tranpose_and_gather_feature(r, batch['inds']) for r in regs
]
w_h_ = [
_tranpose_and_gather_feature(r, batch['inds']) for r in w_h_
]

# 分別計算loss
hmap_loss = _neg_loss(hmap, batch['hmap'])
reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])
w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])

# 進行loss加權,獲得最終loss
loss = hmap_loss + 1 * reg_loss + 0.1 * w_h_loss

上述transpose_and_gather_feature函數具體實現以下,主要功能是將ground truth中計算獲得的對應中心點的值獲取。

def _tranpose_and_gather_feature(feat, ind):
  # ind表明的是ground truth中設置的存在目標點的下角標
  feat = feat.permute(0, 2, 3, 1).contiguous()# from [bs c h w] to [bs, h, w, c] 
  feat = feat.view(feat.size(0), -1, feat.size(3)) # to [bs, wxh, c]
  feat = _gather_feature(feat, ind)
  return feat

def _gather_feature(feat, ind, mask=None):
  # feat : [bs, wxh, c]
  dim = feat.size(2)
  # ind : [bs, index, c]
  ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
  feat = feat.gather(1, ind) # 按照dim=1獲取ind
  if mask is not None:
    mask = mask.unsqueeze(2).expand_as(feat)
    feat = feat[mask]
    feat = feat.view(-1, dim)
  return feat

3.1 hmap loss代碼

調用:hmap_loss = _neg_loss(hmap, batch['hmap'])

def _neg_loss(preds, targets):
    ''' Modified focal loss. Exactly the same as CornerNet.
        Runs faster and costs a little bit more memory
        Arguments:
        preds (B x c x h x w)
        gt_regr (B x c x h x w)
    '''
    pos_inds = targets.eq(1).float()# heatmap爲1的部分是正樣本
    neg_inds = targets.lt(1).float()# 其餘部分爲負樣本

    neg_weights = torch.pow(1 - targets, 4)# 對應(1-Yxyc)^4

    loss = 0
    for pred in preds: # 預測值
        # 約束在0-1之間
        pred = torch.clamp(torch.sigmoid(pred), min=1e-4, max=1 - 1e-4)
        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred,
                                                   2) * neg_weights * neg_inds
        num_pos = pos_inds.float().sum()
        pos_loss = pos_loss.sum()
        neg_loss = neg_loss.sum()

        if num_pos == 0:
            loss = loss - neg_loss # 只有負樣本
        else:
            loss = loss - (pos_loss + neg_loss) / num_pos
    return loss / len(preds)

\[L_k=\frac{-1}{N}\sum_{xyc}\begin{cases} (1-\hat{Y}_{xyc})^\alpha log(\hat{Y}_{xyc})& Y_{xyc}=1\\ (1-Y_{xyc})^\beta(\hat{Y}_{xyc})^\alpha log(1-\hat{Y}_{xyc})& otherwise \end{cases} \]

代碼和以上公式一一對應,pos表明正樣本,neg表明負樣本。

3.2 reg & wh loss代碼

調用:reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])

調用:w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])

def _reg_loss(regs, gt_regs, mask):
    mask = mask[:, :, None].expand_as(gt_regs).float()
    loss = sum(F.l1_loss(r * mask, gt_regs * mask, reduction='sum') /
               (mask.sum() + 1e-4) for r in regs)
    return loss / len(regs)

4. 參考

https://zhuanlan.zhihu.com/p/66048276

http://xxx.itp.ac.cn/pdf/1904.07850

相關文章
相關標籤/搜索