【本期推薦專題】物聯網從業人員必讀:華爲雲專家爲你詳細解讀LiteOS各模塊開發及其實現原理。
html
摘要:Focal Loss的兩個性質算是核心,其實就是用一個合適的函數去度量難分類和易分類樣本對總的損失的貢獻。git
本文分享自華爲雲社區《技術乾貨 | 基於MindSpore更好的理解Focal Loss》,原文做者:chengxiaoli。算法
今天更新一下愷明大神的Focal Loss,它是 Kaiming 大神團隊在他們的論文Focal Loss for Dense Object Detection提出來的損失函數,利用它改善了圖像物體檢測的效果。ICCV2017RBG和Kaiming大神的新做(https://arxiv.org/pdf/1708.02002.pdf)。網絡
使用場景
最近一直在作人臉表情相關的方向,這個領域的 DataSet 數量不大,並且每每存在正負樣本不均衡的問題。通常來講,解決正負樣本數量不均衡問題有兩個途徑:app
1. 設計採樣策略,通常都是對數量少的樣本進行重採樣函數
2. 設計 Loss,通常都是對不一樣類別樣本進行權重賦值學習
本文講的是第二種策略中的 Focal Loss。優化
理論分析
論文分析
咱們知道object detection按其流程來講,通常分爲兩大類。一類是two stage detector(如很是經典的Faster R-CNN,RFCN這樣須要region proposal的檢測算法),第二類則是one stage detector(如SSD、YOLO系列這樣不須要region proposal,直接回歸的檢測算法)。url
對於第一類算法能夠達到很高的準確率,可是速度較慢。雖然能夠經過減小proposal的數量或下降輸入圖像的分辨率等方式達到提速,可是速度並無質的提高。spa
對於第二類算法速度很快,可是準確率不如第一類。
因此目標就是:focal loss的出發點是但願one-stage detector能夠達到two-stage detector的準確率,同時不影響原有的速度。
So,Why?and result?
這是什麼緣由形成的呢?the Reason is:Class Imbalance(正負樣本不平衡),樣本的類別不均衡致使的。
咱們知道在object detection領域,一張圖像可能生成成千上萬的candidate locations,可是其中只有不多一部分是包含object的,這就帶來了類別不均衡。那麼類別不均衡會帶來什麼後果呢?引用原文講的兩個後果:
(1) training is inefficient as most locations are easy negatives that contribute no useful learning signal;
(2) en masse, the easy negatives can overwhelm training and lead to degenerate models.
意思就是負樣本數量太大(屬於背景的樣本),佔總的loss的大部分,並且可能是容易分類的,所以使得模型的優化方向並非咱們所但願的那樣。這樣,網絡學不到有用的信息,沒法對object進行準確分類。其實先前也有一些算法來處理類別不均衡的問題,好比OHEM(online hard example mining),OHEM的主要思想能夠用原文的一句話歸納:In OHEM each example is scored by its loss, non-maximum suppression (nms) is then applied, and a minibatch is constructed with the highest-loss examples。OHEM算法雖然增長了錯分類樣本的權重,可是OHEM算法忽略了容易分類的樣本。
所以針對類別不均衡問題,做者提出一種新的損失函數:Focal Loss,這個損失函數是在標準交叉熵損失基礎上修改獲得的。這個函數能夠經過減小易分類樣本的權重,使得模型在訓練時更專一於難分類的樣本。爲了證實Focal Loss的有效性,做者設計了一個dense detector:RetinaNet,而且在訓練時採用Focal Loss訓練。實驗證實RetinaNet不只能夠達到one-stage detector的速度,也能有two-stage detector的準確率。
公式說明
介紹focal loss,在介紹focal loss以前,先來看看交叉熵損失,這裏以二分類爲例,原來的分類loss是各個訓練樣本交叉熵的直接求和,也就是各個樣本的權重是同樣的。公式以下: 由於是二分類,p表示預測樣本屬於1的機率(範圍爲0-1),y表示label,y的取值爲{+1,-1}。當真實label是1,也就是y=1時,假如某個樣本x預測爲1這個類的機率p=0.6,那麼損失就是-log(0.6),注意這個損失是大於等於0的。若是p=0.9,那麼損失就是-log(0.9),因此p=0.6的損失要大於p=0.9的損失,這很容易理解。這裏僅僅以二分類爲例,多分類分類以此類推爲了方便,用pt代替p,以下公式2:。這裏的pt就是前面Figure1中的橫座標。
爲了表示簡便,咱們用p_t表示樣本屬於true class的機率。因此(1)式能夠寫成:
接下來介紹一個最基本的對交叉熵的改進,也將做爲本文實驗的baseline,既然one-stage detector在訓練的時候正負樣本的數量差距很大,那麼一種常見的作法就是給正負樣本加上權重,負樣本出現的頻次多,那麼就下降負樣本的權重,正樣本數量少,就相對提升正樣本的權重。所以能夠經過設定
的值來控制正負樣本對總的loss的共享權重。
取比較小的值來下降負樣本(多的那類樣本)的權重。
顯然前面的公式3雖然能夠控制正負樣本的權重,可是無法控制容易分類和難分類樣本的權重,因而就有了Focal Loss,這裏的γ稱做focusing parameter,γ>=0,稱爲調製係數:
爲何要加上這個調製係數呢?目的是經過減小易分類樣本的權重,從而使得模型在訓練時更專一於難分類的樣本。
經過實驗發現,繪製圖看以下Figure1,橫座標是pt,縱座標是loss。CE(pt)表示標準的交叉熵公式,FL(pt)表示focal loss中用到的改進的交叉熵。Figure1中γ=0的藍色曲線就是標準的交叉熵損失(loss)。
這樣就既作到了解決正負樣本不平衡,也作到了解決easy與hard樣本不平衡的問題。
結論
做者將類別不平衡做爲阻礙one-stage方法超過top-performing的two-stage方法的主要緣由。爲了解決這個問題,做者提出了focal loss,在交叉熵裏面用一個調整項,爲了將學習專一於hard examples上面,而且下降大量的easy negatives的權值。是同時解決了正負樣本不平衡以及區分簡單與複雜樣本的問題。
MindSpore代碼實現
咱們來看一下,基於MindSpore實現Focal Loss的代碼:
import mindspore import mindspore.common.dtype as mstype from mindspore.common.tensor import Tensor from mindspore.common.parameter import Parameter from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore import nn class FocalLoss(_Loss): def __init__(self, weight=None, gamma=2.0, reduction='mean'): super(FocalLoss, self).__init__(reduction=reduction) # 校驗gamma,這裏的γ稱做focusing parameter,γ>=0,稱爲調製係數 self.gamma = validator.check_value_type("gamma", gamma, [float]) if weight is not None and not isinstance(weight, Tensor): raise TypeError("The type of weight should be Tensor, but got {}.".format(type(weight))) self.weight = weight # 用到的mindspore算子 self.expand_dims = P.ExpandDims() self.gather_d = P.GatherD() self.squeeze = P.Squeeze(axis=1) self.tile = P.Tile() self.cast = P.Cast() def construct(self, predict, target): targets = target # 對輸入進行校驗 _check_ndim(predict.ndim, targets.ndim) _check_channel_and_shape(targets.shape[1], predict.shape[1]) _check_predict_channel(predict.shape[1]) # 將logits和target的形狀更改成num_batch * num_class * num_voxels. if predict.ndim > 2: predict = predict.view(predict.shape[0], predict.shape[1], -1) # N,C,H,W => N,C,H*W targets = targets.view(targets.shape[0], targets.shape[1], -1) # N,1,H,W => N,1,H*W or N,C,H*W else: predict = self.expand_dims(predict, 2) # N,C => N,C,1 targets = self.expand_dims(targets, 2) # N,1 => N,1,1 or N,C,1 # 計算對數機率 log_probability = nn.LogSoftmax(1)(predict) # 只保留每一個voxel的地面真值類的對數機率值。 if target.shape[1] == 1: log_probability = self.gather_d(log_probability, 1, self.cast(targets, mindspore.int32)) log_probability = self.squeeze(log_probability) # 獲得機率 probability = F.exp(log_probability) if self.weight is not None: convert_weight = self.weight[None, :, None] # C => 1,C,1 convert_weight = self.tile(convert_weight, (targets.shape[0], 1, targets.shape[2])) # 1,C,1 => N,C,H*W if target.shape[1] == 1: convert_weight = self.gather_d(convert_weight, 1, self.cast(targets, mindspore.int32)) # selection of the weights => N,1,H*W convert_weight = self.squeeze(convert_weight) # N,1,H*W => N,H*W # 將對數機率乘以它們的權重 probability = log_probability * convert_weight # 計算損失小批量 weight = F.pows(-probability + 1.0, self.gamma) if target.shape[1] == 1: loss = (-weight * log_probability).mean(axis=1) # N else: loss = (-weight * targets * log_probability).mean(axis=-1) # N,C return self.get_loss(loss)
使用方法以下:
from mindspore.common import dtype as mstype from mindspore import nn from mindspore import Tensor predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32) target = Tensor([[1], [1], [0]], mstype.int32) focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean') output = focalloss(predict, target) print(output) 0.33365273
Focal Loss的兩個重要性質
1. 當一個樣本被分錯的時候,pt是很小的,那麼調製因子(1-Pt)接近1,損失不被影響;當Pt→1,因子(1-Pt)接近0,那麼分的比較好的(well-classified)樣本的權值就被調低了。所以調製係數就趨於1,也就是說相比原來的loss是沒有什麼大的改變的。當pt趨於1的時候(此時分類正確並且是易分類樣本),調製係數趨於0,也就是對於總的loss的貢獻很小。
2. 當γ=0的時候,focal loss就是傳統的交叉熵損失,當γ增長的時候,調製係數也會增長。 專一參數γ平滑地調節了易分樣本調低權值的比例。γ增大能加強調製因子的影響,實驗發現γ取2最好。直覺上來講,調製因子減小了易分樣本的損失貢獻,拓寬了樣例接收到低損失的範圍。當γ必定的時候,好比等於2,同樣easy example(pt=0.9)的loss要比標準的交叉熵loss小100+倍,當pt=0.968時,要小1000+倍,可是對於hard example(pt < 0.5),loss最多小了4倍。這樣的話hard example的權重相對就提高了不少。
這樣就增長了那些誤分類的重要性Focal Loss的兩個性質算是核心,其實就是用一個合適的函數去度量難分類和易分類樣本對總的損失的貢獻。