Focal Loss Notes

Paper: "Focal Loss for Dense Object Detection"

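Focal Loss is a loss function proposed by Tsung-Yi Lin, Kaiming He, et al. to deal with the extreme imbalance between foreground and background classes (e.g., 1:1000) that one-stage object detectors face during training. It is derived from binary cross-entropy.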

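Standard cross-entropy

$$
\mathrm{CE}(p, y) = \begin{cases} -\log(p) & \text{if } y = 1 \\ -\log(1 - p) & \text{otherwise} \end{cases}
$$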

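Here p ∈ [0, 1] is the model's estimated probability for the class y = 1. For notational convenience, define:

$$
p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases}
$$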

The cross-entropy can then be rewritten as:

$$
\mathrm{CE}(p, y) = \mathrm{CE}(p_t) = -\log(p_t)
$$
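A minimal sketch in plain Python (standard library only, no PyTorch) that checks the piecewise CE against the compact p_t form for a few values. The helper names `p_t`, `ce` and `ce_pt` are hypothetical and exist only for this illustration.

```python
import math

def p_t(p: float, y: int) -> float:
    """Probability the model assigns to the ground-truth class."""
    return p if y == 1 else 1.0 - p

def ce(p: float, y: int) -> float:
    """Piecewise binary cross-entropy: -log(p) if y == 1, else -log(1 - p)."""
    return -math.log(p) if y == 1 else -math.log(1.0 - p)

def ce_pt(p: float, y: int) -> float:
    """The same loss written as CE(p_t) = -log(p_t)."""
    return -math.log(p_t(p, y))

# Both forms agree for positives and negatives alike.
for p in (0.1, 0.5, 0.9):
    for y in (0, 1):
        assert math.isclose(ce(p, y), ce_pt(p, y))

print(round(ce(0.9, 1), 4))  # confident and correct: 0.1054
print(round(ce(0.9, 0), 4))  # confident and wrong:   2.3026
```

Even a correctly classified example (p_t = 0.9) still produces a non-trivial loss, which is what lets a huge number of easy background examples add up.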

 

α-balanced cross-entropy

A common way to address class imbalance is to introduce a weighting factor α ∈ [0, 1]: the loss is weighted by α when y = 1 and by 1 - α when y = 0.

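With this weighting, a background sample (y = 0) is scaled by 1 - α, so increasing α down-weights the background loss and reduces the influence of the overwhelming number of background samples on training.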

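Analogous to p_t, define α_t = α when y = 1 and α_t = 1 - α otherwise. The α-balanced CE can then be written as:

$$
\mathrm{CE}(p_t) = -\alpha_t \log(p_t)
$$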
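A small worked example in plain Python, assuming a hypothetical image with one hard positive (p = 0.1) and 1000 easy negatives the model already scores low (p = 0.1, i.e. p_t = 0.9). α = 0.25 is the value used in the PyTorch code later in this post; `alpha_ce` is a hypothetical helper name.

```python
import math

def alpha_ce(p: float, y: int, alpha: float = 0.25) -> float:
    """alpha-balanced cross-entropy: -alpha_t * log(p_t)."""
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * math.log(p_t)

# Hypothetical 1:1000 scene: one hard positive and 1000 easy negatives.
pos = alpha_ce(0.1, 1)               # hard positive, p_t = 0.1
neg_total = 1000 * alpha_ce(0.1, 0)  # easy negatives, p_t = 0.9 each
print(round(pos, 3), round(neg_total, 1))  # 0.576 79.0
```

The weighting changes the balance between positives and negatives, but the easy negatives still dominate the total loss because α_t does not depend on how easy an example is.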

 

Focal Loss definition

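Although α-CE balances the contributions of positive and negative samples to the loss, it cannot distinguish easy samples from hard ones. Focal Loss addresses this and is defined as follows:

$$
\mathrm{FL} = \begin{cases} -\alpha \, (1 - y')^{\gamma} \log(y') & \text{if } y = 1 \\ -(1 - \alpha) \, (y')^{\gamma} \log(1 - y') & \text{if } y = 0 \end{cases}
$$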

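where α and γ are tunable hyperparameters and y' ∈ (0, 1) is the model's prediction (the p used above). In the p_t notation this is $\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$.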
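A minimal plain-Python sketch of the definition for a single prediction, checking that the piecewise y' form and the compact p_t form give the same value. The defaults α = 0.25 and γ = 2 are taken from the PyTorch code later in this post; the function names are hypothetical.

```python
import math

def focal_loss(y_pred: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Piecewise binary focal loss; y_pred in (0, 1) is the predicted probability of y = 1."""
    if y == 1:
        return -alpha * (1.0 - y_pred) ** gamma * math.log(y_pred)
    return -(1.0 - alpha) * y_pred ** gamma * math.log(1.0 - y_pred)

def focal_loss_pt(y_pred: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """The same loss in the compact form -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    p_t = y_pred if y == 1 else 1.0 - y_pred
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# The two forms agree for both classes.
for y_pred in (0.05, 0.5, 0.95):
    for y in (0, 1):
        assert math.isclose(focal_loss(y_pred, y), focal_loss_pt(y_pred, y))
```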

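When y = 1 and y' → 1 (an easy positive), the factor (1 - y')^γ → 0, so its contribution to the loss → 0;

When y = 0 and y' → 0 (an easy negative), the factor (y')^γ → 0, so its contribution to the loss → 0.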

Focal Loss therefore down-weights not only the background class but also easy positives and easy negatives.
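A quick numeric check in plain Python, with α left out (set to 1) so only the modulating factor (1 - p_t)^γ is visible. The ratio FL / CE equals (1 - p_t)^γ, so with γ = 2 an example at p_t = 0.9 gets about 100× less loss than under CE. The second part revisits the hypothetical 1 hard positive vs. 1000 easy negatives scene; the helper names are hypothetical.

```python
import math

def ce(p_t: float) -> float:
    """Cross-entropy in p_t form."""
    return -math.log(p_t)

def fl(p_t: float, gamma: float = 2.0) -> float:
    """Focal loss without alpha, isolating the modulating factor (1 - p_t)^gamma."""
    return -(1.0 - p_t) ** gamma * math.log(p_t)

# FL / CE == (1 - p_t)^gamma: easy examples (large p_t) are suppressed the most.
for p_t in (0.1, 0.5, 0.9, 0.968):
    print(p_t, round(fl(p_t) / ce(p_t), 4))

# One hard positive (p_t = 0.1) vs. 1000 easy negatives (p_t = 0.9).
print("CE:", round(ce(0.1), 3), "vs", round(1000 * ce(0.9), 3))
print("FL:", round(fl(0.1), 3), "vs", round(1000 * fl(0.9), 3))
```

Under CE the easy negatives outweigh the hard positive by a wide margin; under FL the single hard example contributes more than all 1000 easy ones combined.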

γ controls how strongly easy examples are down-weighted; when γ = 0, Focal Loss is equivalent to α-CE. Effect of γ on Focal Loss:

[Figure: effect of γ on Focal Loss]
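To see the effect of γ numerically, a minimal plain-Python sketch that tabulates the loss over a grid of p_t values. The γ list and the p_t grid are arbitrary choices for illustration, and α_t is fixed to 1 so only γ varies; the sketch also checks that γ = 0 reduces to α-CE.

```python
import math

def focal_loss_pt(p_t: float, alpha_t: float = 1.0, gamma: float = 2.0) -> float:
    """FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# With gamma = 0 the modulating factor is 1, leaving alpha-balanced CE.
assert math.isclose(focal_loss_pt(0.3, alpha_t=0.25, gamma=0.0), -0.25 * math.log(0.3))

gammas = (0.0, 0.5, 1.0, 2.0, 5.0)
print("p_t   " + "  ".join(f"g={g:<4}" for g in gammas))
for p_t in (0.1, 0.3, 0.5, 0.7, 0.9):
    row = "  ".join(f"{focal_loss_pt(p_t, gamma=g):.4f}" for g in gammas)
    print(f"{p_t:<5} {row}")
```

Larger γ shrinks the loss of well-classified examples (large p_t) much faster than that of poorly classified ones.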

 

 

PyTorch implementation of Focal Loss

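The FocalLoss module below returns two values: the classification loss (Focal Loss) and the bounding-box regression loss (smooth L1), each averaged over the batch.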

Code from: https://github.com/yhenon/pytorch-retinanet

import torch
import torch.nn as nn

def calc_iou(a, b):
    # a: (N, 4) anchors, b: (K, 4) boxes, both (x1, y1, x2, y2); returns an (N, K) IoU matrix
    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1])

    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)

    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih

    ua = torch.clamp(ua, min=1e-8)

    intersection = iw * ih

    IoU = intersection / ua

    return IoU

class FocalLoss(nn.Module):

    def forward(self, classifications, regressions, anchors, annotations):
        # classifications: (B, A, C) sigmoid probabilities; regressions: (B, A, 4)
        # anchors: (1, A, 4); annotations: (B, M, 5) as (x1, y1, x2, y2, class), padding rows have class -1
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        device = classifications.device  # follow the input's device instead of hard-coding .cuda()
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]

        anchor_widths  = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x   = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y   = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):

            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            if bbox_annotation.shape[0] == 0:
                regression_losses.append(torch.tensor(0.0, device=device))
                classification_losses.append(torch.tensor(0.0, device=device))
                continue

            # keep probabilities away from 0 and 1 so log() stays finite
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            IoU = calc_iou(anchor, bbox_annotation[:, :4])  # num_anchors x num_annotations

            IoU_max, IoU_argmax = torch.max(IoU, dim=1)  # best-matching annotation per anchor

            # compute the loss for classification
            # targets: -1 = ignored, 0 = negative, 1 = positive (one-hot over classes)
            targets = torch.full(classification.shape, -1.0, device=device)

            targets[torch.lt(IoU_max, 0.4), :] = 0  # IoU < 0.4: background

            positive_indices = torch.ge(IoU_max, 0.5)  # IoU >= 0.5: foreground

            num_positive_anchors = positive_indices.sum()

            assigned_annotations = bbox_annotation[IoU_argmax, :]

            targets[positive_indices, :] = 0
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            # alpha_t: alpha for positive entries, 1 - alpha for negative entries
            alpha_factor = torch.ones_like(targets) * alpha
            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
            # (1 - p_t)^gamma: 1 - p for positive entries, p for negative entries
            focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))

            cls_loss = focal_weight * bce

            # ignored anchors (0.4 <= IoU < 0.5) contribute nothing
            cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros_like(cls_loss))

            # normalize by the number of positive anchors (at least 1)
            classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0))

            # compute the loss for regression

            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                gt_widths  = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x   = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y   = assigned_annotations[:, 1] + 0.5 * gt_heights

                # clip widths to 1
                gt_widths  = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh))
                targets = targets.t()

                # scale the regression targets by fixed standard deviations
                targets = targets / torch.tensor([[0.1, 0.1, 0.2, 0.2]], device=device)

                regression_diff = torch.abs(targets - regression[positive_indices, :])

                # smooth L1 loss with beta = 1/9
                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0
                )
                regression_losses.append(regression_loss.mean())
            else:
                regression_losses.append(torch.tensor(0.0, device=device))

        return torch.stack(classification_losses).mean(dim=0, keepdim=True), torch.stack(regression_losses).mean(dim=0, keepdim=True)
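The per-anchor bookkeeping in the code above can be sketched without PyTorch, one anchor at a time instead of whole tensors. This is a simplified illustration, not the repository's code: the helper names `iou`, `assign`, `encode` and `smooth_l1` are hypothetical, while the IoU thresholds (0.4 / 0.5), the target scaling (0.1, 0.1, 0.2, 0.2) and the smooth L1 beta of 1/9 are copied from the code.

```python
import math

def iou(a, b):
    """IoU of two boxes in (x1, y1, x2, y2) format."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / max(union, 1e-8)

def assign(anchor, gt_boxes, neg_thr=0.4, pos_thr=0.5):
    """Label an anchor by its best IoU: 1 = positive, 0 = negative, -1 = ignored."""
    best = max(iou(anchor, g) for g in gt_boxes)
    if best >= pos_thr:
        return 1
    if best < neg_thr:
        return 0
    return -1

def encode(anchor, gt, std=(0.1, 0.1, 0.2, 0.2)):
    """Regression targets (dx, dy, dw, dh), each divided by its std."""
    aw, ah = anchor[2] - anchor[0], anchor[3] - anchor[1]
    ax, ay = anchor[0] + 0.5 * aw, anchor[1] + 0.5 * ah
    gw, gh = gt[2] - gt[0], gt[3] - gt[1]
    gx, gy = gt[0] + 0.5 * gw, gt[1] + 0.5 * gh
    gw, gh = max(gw, 1.0), max(gh, 1.0)  # clamp widths/heights to at least 1
    t = ((gx - ax) / aw, (gy - ay) / ah, math.log(gw / aw), math.log(gh / ah))
    return tuple(v / s for v, s in zip(t, std))

def smooth_l1(diff, beta=1.0 / 9.0):
    """Quadratic below beta, linear above; continuous at |diff| == beta."""
    diff = abs(diff)
    return 0.5 * diff ** 2 / beta if diff <= beta else diff - 0.5 * beta

gt = [(10, 10, 50, 50)]
print(assign((10, 10, 50, 50), gt))  # IoU 1.0  -> 1 (positive)
print(assign((25, 10, 65, 50), gt))  # IoU 0.45 -> -1 (ignored)
print(assign((30, 10, 70, 50), gt))  # IoU 0.33 -> 0 (negative)
print(encode((0, 0, 40, 40), (4, 0, 44, 40)))
```

Anchors in the ignored band (0.4 <= IoU < 0.5) are neither pushed toward foreground nor background, which is why the code masks their classification loss to zero.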