Focal loss論文解析

Focal loss是目標檢測領域的一篇十分經典的論文,它經過改造損失函數提高了一階段目標檢測的性能,背後關於類別不平衡的學習的思想值得咱們深刻地去探索和學習。正負樣本失衡不單單在目標檢測算法中會出現,在別的機器學習任務中一樣會出現,這篇論文爲咱們解決相似問題提供了一個很好的啓發,因此我認爲不管是否從事目標檢測領域相關工做,均可以來看一看這篇好論文。git

論文的關鍵性改進在於對損失函數的改造以及對參數初始化的設置。github

首先是對損失函數的改造。論文中指出,限制目標檢測網絡性能的一個關鍵因素是類別不平衡。二階段目標檢測算法相比於一階段目標檢測算法的優勢在於,二階段的目標檢測算法經過候選框篩選算法(proposal stage)過濾了大部分背景樣本(負樣本),使得正負樣本比例適中;而一階段的目標檢測算法中,須要處理大量的負樣本,使得包含目標的正樣本信息被淹沒。這使得一階段目標檢測算法的識別準確度上比不上二階段的目標檢測算法。算法

爲了解決這個問題,Focal loss使用了動態加權的思想,對於置信度高的樣本,損失函數進行降權;對於置信度低的樣本,損失函數進行加權,使得網絡在反向傳播時,置信度低的樣本可以提供更大的梯度佔比,即從未學習好的樣本中獲取更多的信息(就像高中時期的錯題本同樣,對於容易錯的題目,包含了更多的信息量,須要更加關注這種題目;而對於屢屢正確的題目,能夠少點關注,說明已經掌握了這類型的題目)
其巧妙之處就在於,經過了網絡自己輸出的機率值(置信度)去構建權重,實現了自適應調整權重的目的。網絡

公式的講解

Focal loss是基於交叉熵損失構建的,二元交叉熵的公式爲app

\[\mathrm{CE}(p, y)=\left\{\begin{array}{ll} -\log (p) & \text { if } y = +1 \\ -\log (1-p) & \text { y = -1 } \end{array}\right. \]

爲了方便表示,定義\(p_t\)爲分類正確的機率機器學習

\[p_{t}=\left\{\begin{array}{ll} p & \text { if } y = +1 \\ 1-p & \text { y = -1 } \end{array}\right. \]

則交叉熵損失表示爲\(CE(p,y)=CE(p_t)=-log(p_t)\)。如前文所述,經過置信度對損失進行縮放獲得Focal loss。函數

\[FL(p_t)=-\alpha_t(1-p_t)^\gamma log(p_t)= \alpha_t(1-p_t)^\gamma\times CE(p_t) \]

其中,\(\alpha_{1}=\left\{\begin{array}{ll} \alpha & \text { if } y = +1 \\ 1-\alpha & \text { y = -1 } \end{array}\right.\)爲縮放乘數(直接調整正負樣本的權重),\(\gamma\)爲縮放因子,\((1-p_t)\)能夠理解爲分類錯誤的機率。公式中起到關鍵做用的部分是\((1-p_t)^\gamma\)。爲了給易分樣本降權,一般設置\(\gamma>1\)
對於正確分類的樣本,\(p_t \to 1 \Rightarrow(1-p_t) \to 0\),受到\(\gamma\)的影響很大,\((1-p_t)^\gamma \approx 0\)
對於錯誤分類的樣本,\(p_t \to 0 \Rightarrow(1-p_t) \to 1\),受到\(\gamma\)的影響較小,\((1-p_t)^\gamma \approx (1-p_t)\),對於難分樣本的降權較小。
Focal loss本質上是經過置信度給易分樣本進行更多的降權,對難分樣本進行更少的降權,實現對難分樣本的關注。性能

參數初始化

論文中還有一個比較重要的點是對於子網絡最後一層權重的初始化方式,關係到網絡初期訓練的性能。這裏結合論文和我看過的一篇博文進行詳細的展開。常規的深度學習網絡初始化算法,使用的分佈是高斯分佈,根據機率論知識,兩個高斯分佈的變量的乘積仍然服從高斯分佈。假設權重\(w\sim N(\mu_w,\sigma_w^2)\),最後一層的特徵\(x\sim N(\mu_x,\sigma_x^2)\),則\(wx \sim N(\mu_{wx},\sigma_{wx}^2)\)學習

\[\mu_{wx}=\frac{\mu_w \sigma_x^2+\mu_x \sigma_w^2}{\sigma_x^2+\sigma_w^2}\\ \sigma_{wx}=\frac{\sigma_x^2\sigma_w^2}{\sigma_x^2+\sigma_w^2} \]

其中\(x\)的分佈取決於網絡的結果,\(w\)的分佈參數爲\(\mu_w=0,\sigma_w^2=10^{-4}\),只需\(x\)的分佈參數知足\(\sigma_x^2\gg 10^{-4},\sigma_x^2\gg10^{-4}\mu_x\)成立,有以下的不等式。(通常狀況下,這兩個條件是成立的。)spa

\[\mu_{wx}=\frac{\mu_w \sigma_x^2+\mu_x \sigma_w^2}{\sigma_x^2+\sigma_w^2}=\frac{10^{-4}\mu_x}{\sigma_x^2+10^{-4}}\ll\frac{10^{-4}\mu_x}{10^{-4}\mu_x+10^{-4}}=\frac{1}{1+\frac{1}{\mu_x}}\approx0 \text{因爲}\mu_x\text{通常爲分數(網絡的輸入通過歸一化到0至1,隨着網絡加深的連乘,分數會愈來愈小)}\\ \sigma_{wx}=\frac{\sigma_x^2\sigma_w^2}{\sigma_x^2+\sigma_w^2}=\frac{10^{-4}}{1+\frac{10^{-4}}{\sigma_x^2}}\approx10^{-4} \text{因爲}\sigma_x^2\gg10^{-4} \]

根據上述推導,\(wx\)服從一個均值爲0,方差很小的高斯分佈,能夠在很大機率上認爲它就等於0,因此網絡最後一層的輸出爲

\[p=sigmoid(wx+b)=sigmoid(b)=\frac{1}{1+e^{-b}}=\pi \]

\(\pi\)爲網絡初始化時輸出爲正類的機率,設置爲一個很小的值(0.01),則網絡在訓練初期,將樣本都劃分爲負類,對於正類\(p_t=0.01\),負類\(p_t=0.99\),則訓練初期,正類都被大機率錯分,負類都被大機率正確分類,因此在訓練初期更加關注正類,避免初期的正類信息被淹沒在負類信息中。

總結

總的來講,Focal loss經過對損失函數的簡單改進,實現了一種自適應的困難樣本挖掘策略,使得網絡在學習過程當中關注更難學習的樣本,在必定程度上解決了正負樣本分佈不均衡的問題(因爲正負樣本分佈不均衡,對於稀少的正樣本學習不足,致使正樣本廣泛表現爲難分樣本)。

參考資料

論文原文
一篇不錯的解析博客

相關文章
相關標籤/搜索