[GiantPandaCV導語] 本文主要講解CenterNet的loss,由偏置部分(reg loss)、熱圖部分(heatmap loss)、寬高(wh loss)部分三部分loss組成,附代碼實現。python
論文中提供了三個用於目標檢測的網絡,都是基於編碼解碼的結構構建的。網絡
這三個網絡中輸出內容都是同樣的,80個類別,2個預測中心對應的長和寬,2箇中心點的誤差。ide
# heatmap 輸出的tensor的通道個數是80,每一個通道表明對應類別的heatmap (hm): Sequential( (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) (2): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1)) ) # wh 輸出是中心對應的長和寬,通道數爲2 (wh): Sequential( (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1)) ) # reg 輸出的tensor通道個數爲2,分別是w,h方向上的偏移量 (reg): Sequential( (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1)) )
輸入圖像\(I\in R^{W\times H\times 3}\), W爲圖像寬度,H爲圖像高度。網絡輸出的關鍵點熱圖heatmap爲\(\hat{Y}\in [0,1]^{\frac{W}{R}\times \frac{H}{R}\times C}\)其中,R表明獲得輸出相對於原圖的步長stride。C表明類別個數。函數
下面是CenterNet中核心loss公式:編碼
這個和Focal loss形式很類似,\(\alpha\)和\(\beta\)是超參數,N表明的是圖像關鍵點個數。spa
對於易分樣原本說,預測值\(\hat{Y}_{xyc}\)接近於1,\((1-\hat{Y}_{xyc})^\alpha\)就是一個很小的值,這樣loss就很小,起到了矯正做用。code
對於難分樣原本說,預測值\(\hat{Y}_{xyc}\)接近於0,$ (1-\hat{Y}_{xyc})^\alpha$就比較大,至關於加大了其訓練的比重。orm
上圖是一個簡單的示意,縱座標是\({Y}_{xyc}\),分爲A區(距離中心點較近,可是值在0-1之間)和B區(距離中心點很遠接近於0)。blog
對於A區來講,因爲其周圍是一個高斯核生成的中心,\(Y_{xyc}\)的值是從1慢慢變到0。ip
舉個例子(CenterNet中默認\(\alpha=2,\beta=4\)):
\(Y_{xyc}=0.8\)的狀況下,
若是\(\hat{Y}_{xyc}=0.99\),那麼loss=\((1-0.8)^4(0.99)^2log(1-0.99)\),這就是一個很大的loss值。
若是\(\hat{Y}_{xyc}=0.8\), 那麼loss=\((1-0.8)^4(0.8)^2log(1-0.8)\), 這個loss就比較小。
若是\(\hat{Y}_{xyc}=0.5\), 那麼loss=\((1-0.8)^4(0.5)^2log(1-0.5)\),
若是\(\hat{Y}_{xyc}=0.99\),那麼loss=\((1-0.5)^4(0.99)^2log(1-0.99)\),這就是一個很大的loss值。
若是\(\hat{Y}_{xyc}=0.8\), 那麼loss=\((1-0.5)^4(0.8)^2log(1-0.8)\), 這個loss就比較小。
若是\(\hat{Y}_{xyc}=0.5\), 那麼loss=\((1-0.5)^4(0.5)^2log(1-0.5)\),
總結一下:爲了防止預測值\(\hat{Y}_{xyc}\)太高接近於1,因此用\((\hat{Y}_{xyc})^\alpha\)來懲罰Loss。而\((1-Y_{xyc})^\beta\)這個參數距離中心越近,其值越小,這個權重是用來減輕懲罰力度。
對於B區來講,\(\hat{Y}_{xyc}\)的預測值理應是0,若是該值比較大好比爲1,那麼\((\hat{Y}_{xyc})^\alpha\)做爲權重會變大,懲罰力度也加大了。若是預測值接近於0,那麼\((\hat{Y}_{xyc})^\alpha\)會很小,讓其損失比重減少。對於\((1-Y_{xyc})^\beta\)來講,B區的值比較大,弱化了中心點周圍其餘負樣本的損失比重。
因爲三個骨幹網絡輸出的feature map的空間分辨率變爲原來輸入圖像的四分之一。至關於輸出feature map上一個像素點對應原始圖像的4x4的區域,這會帶來較大的偏差,所以引入了偏置值和偏置的損失值。設骨幹網絡輸出的偏置值爲\(\hat{O}\in R^{\frac{W}{R}\times \frac{H}{R}\times 2}\), 這個偏置值用L1 loss來訓練:
p表明目標框中心點,R表明下采樣倍數4,\(\tilde{p}=\lfloor \frac{p}{R} \rfloor\), \(\frac{p}{R}-\tilde{p}\)表明誤差值。
假設第k個目標,類別爲\(c_k\)的目標框的表示爲\((x_1^{(k)},y_1^{(k)},x_2^{(k)},y_2^{(k)})\),那麼其中心點座標位置爲\((\frac{x_1^{(k)}+x_2^{(k)}}{2}, \frac{y_1^{(k)}+y_2^{(k)}}{2})\), 目標的長和寬大小爲\(s_k=(x_2^{(k)}-x_1^{(k)},y_2^{(k)}-y_1^{(k)})\)。對長和寬進行訓練的是L1 Loss函數:
其中\(\hat{S}\in R^{\frac{W}{R}\times \frac{H}{R}\times 2}\)是網絡輸出的結果。
總體的損失函數是以上三者的綜合,而且分配了不一樣的權重。
其中\(\lambda_{size}=0.1, \lambda_{offsize}=1\)
來自train.py中第173行開始進行loss計算:
# 獲得heat map, reg, wh 三個變量 hmap, regs, w_h_ = zip(*outputs) regs = [ _tranpose_and_gather_feature(r, batch['inds']) for r in regs ] w_h_ = [ _tranpose_and_gather_feature(r, batch['inds']) for r in w_h_ ] # 分別計算loss hmap_loss = _neg_loss(hmap, batch['hmap']) reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks']) w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks']) # 進行loss加權,獲得最終loss loss = hmap_loss + 1 * reg_loss + 0.1 * w_h_loss
上述transpose_and_gather_feature
函數具體實現以下,主要功能是將ground truth中計算獲得的對應中心點的值獲取。
def _tranpose_and_gather_feature(feat, ind): # ind表明的是ground truth中設置的存在目標點的下角標 feat = feat.permute(0, 2, 3, 1).contiguous()# from [bs c h w] to [bs, h, w, c] feat = feat.view(feat.size(0), -1, feat.size(3)) # to [bs, wxh, c] feat = _gather_feature(feat, ind) return feat def _gather_feature(feat, ind, mask=None): # feat : [bs, wxh, c] dim = feat.size(2) # ind : [bs, index, c] ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) feat = feat.gather(1, ind) # 按照dim=1獲取ind if mask is not None: mask = mask.unsqueeze(2).expand_as(feat) feat = feat[mask] feat = feat.view(-1, dim) return feat
調用:hmap_loss = _neg_loss(hmap, batch['hmap'])
def _neg_loss(preds, targets): ''' Modified focal loss. Exactly the same as CornerNet. Runs faster and costs a little bit more memory Arguments: preds (B x c x h x w) gt_regr (B x c x h x w) ''' pos_inds = targets.eq(1).float()# heatmap爲1的部分是正樣本 neg_inds = targets.lt(1).float()# 其餘部分爲負樣本 neg_weights = torch.pow(1 - targets, 4)# 對應(1-Yxyc)^4 loss = 0 for pred in preds: # 預測值 # 約束在0-1之間 pred = torch.clamp(torch.sigmoid(pred), min=1e-4, max=1 - 1e-4) pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds num_pos = pos_inds.float().sum() pos_loss = pos_loss.sum() neg_loss = neg_loss.sum() if num_pos == 0: loss = loss - neg_loss # 只有負樣本 else: loss = loss - (pos_loss + neg_loss) / num_pos return loss / len(preds)
代碼和以上公式一一對應,pos表明正樣本,neg表明負樣本。
調用:reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])
調用:w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])
def _reg_loss(regs, gt_regs, mask): mask = mask[:, :, None].expand_as(gt_regs).float() loss = sum(F.l1_loss(r * mask, gt_regs * mask, reduction='sum') / (mask.sum() + 1e-4) for r in regs) return loss / len(regs)