直觀理解爲何分類問題用交叉熵損失而不用均方偏差損失?

博客:blog.shinelee.me | 博客園 | CSDN網絡

交叉熵損失與均方偏差損失

常規分類網絡最後的softmax層以下圖所示,傳統機器學習方法以此類比,機器學習

https://stats.stackexchange.com/questions/273465/neural-network-softmax-activation

一共有\(K\)類,令網絡的輸出爲\([\hat{y}_1,\dots, \hat{y}_K]\),對應每一個類別的機率,令label爲 \([y_1, \dots, y_K]\)。對某個屬於\(p\)類的樣本,其label中\(y_p=1\)\(y_1, \dots, y_{p-1}, y_{p+1}, \dots, y_K\)均爲0。wordpress

對這個樣本,交叉熵(cross entropy)損失
\[ \begin{aligned}L &= - (y_1 \log \hat{y}_1 + \dots + y_K \log \hat{y}_K) \\&= -y_p \log \hat{y}_p \\ &= - \log \hat{y}_p\end{aligned} \]
均方偏差損失(mean squared error,MSE)
\[ \begin{aligned}L &= (y_1 - \hat{y}_1)^2 + \dots + (y_K - \hat{y}_K)^2 \\&= (1 - \hat{y}_p)^2 + (\hat{y}_1^2 + \dots + \hat{y}_{p-1}^2 + \hat{y}_{p+1}^2 + \dots + \hat{y}_K^2)\end{aligned} \]
\(m\)個樣本的損失爲
\[ \ell = \frac{1}{m} \sum_{i=1}^m L_i \]
對比交叉熵損失與均方偏差損失,只看單個樣本的損失便可,下面從兩個角度進行分析。函數

損失函數角度

損失函數是網絡學習的指揮棒,它引導着網絡學習的方向——能讓損失函數變小的參數就是好參數。學習

因此,損失函數的選擇和設計要能表達你但願模型具備的性質與傾向。spa

對比交叉熵和均方偏差損失,能夠發現,二者均在\(\hat{y} = y = 1\)時取得最小值0,但在實踐中\(\hat{y}_p\)只會趨近於1而不是剛好等於1,在\(\hat{y}_p < 1\)的狀況下,.net

  • 交叉熵只與label類別有關,\(\hat{y}_p\)越趨近於1越好
  • 均方偏差不只與\(\hat{y}_p\)有關,還與其餘項有關,它但願\(\hat{y}_1, \dots, \hat{y}_{p-1}, \hat{y}_{p+1}, \dots, \hat{y}_K\)越平均越好,即在\(\frac{1-\hat{y}_p}{K-1}\)時取得最小值

分類問題中,對於類別之間的相關性,咱們缺少先驗。設計

雖然咱們知道,與「狗」相比,「貓」和「老虎」之間的類似度更高,可是這種關係在樣本標記之初是難以量化的,因此label都是one hot。3d

在這個前提下,均方偏差損失可能會給出錯誤的指示,好比貓、老虎、狗的3分類問題,label爲\([1, 0, 0]\),在均方偏差看來,預測爲\([0.8, 0.1, 0.1]\)要比\([0.8, 0.15, 0.05]\)要好,即認爲平均總比有傾向性要好,但這有悖咱們的常識

對交叉熵損失,既然類別間複雜的類似度矩陣是難以量化的,索性只能關注樣本所屬的類別,只要\(\hat{y}_p\)越接近於1就好,這顯示是更合理的。

softmax反向傳播角度

softmax的做用是將\((-\infty, +\infty)\)的幾個實數映射到\((0,1)\)之間且之和爲1,以得到某種機率解釋。

令softmax函數的輸入爲\(z\),輸出爲\(\hat{y}\),對結點\(p\)有,
\[ \hat{y}_p = \frac{e^{z_p}}{\sum_{k=1}^K e^{z_k}} \]
\(\hat{y}_p\)不只與\(z_p\)有關,還與\(\{z_k | k\neq p\}\)有關,這裏僅看$z_p $,則有
\[ \frac{\partial \hat{y}_p}{\partial z_p} = \hat{y}_p(1-\hat{y}_p) \]
\(\hat{y}_p\)爲正確分類的機率,爲0時表示分類徹底錯誤,越接近於1表示越正確。根據鏈式法則,按理來說,對與\(z_p\)相連的權重,損失函數的偏導會含有\(\hat{y}_p(1-\hat{y}_p)\)這一因子項,\(\hat{y}_p = 0\)分類錯誤,但偏導爲0,權重不會更新,這顯然不對——分類越錯誤越須要對權重進行更新

交叉熵損失
\[ \frac{\partial L}{\partial \hat{y}_p} = -\frac{1}{\hat{y}_p} \]
則有
\[ \frac{\partial L}{\partial \hat{z}_p} = \frac{\partial L}{\partial \hat{y}_p} \cdot \frac{\partial \hat{y}_p}{\partial z_p} = \hat{y}_p - 1 \]
剛好將\(\hat{y}_p(1-\hat{y}_p)\)中的\(\hat{y}_p\)消掉,避免了上述情形的發生,且\(\hat{y}_p\)越接近於1,偏導越接近於0,即分類越正確越不須要更新權重,這與咱們的指望相符。

而對均方偏差損失
\[ \frac{\partial L}{\partial \hat{y}_p} = -2(1-\hat{y}_p)=2(\hat{y}_p - 1) \]
則有,
\[ \frac{\partial L}{\partial \hat{z}_p} = \frac{\partial L}{\partial \hat{y}_p} \cdot \frac{\partial \hat{y}_p}{\partial z_p} = -2 \hat{y}_p (1 - \hat{y}_p)^2 \]
顯然,仍會發生上面所說的狀況——\(\hat{y}_p = 0\)分類錯誤,但不更新權重

綜上,對分類問題而言,不管從損失函數角度仍是softmax反向傳播角度,交叉熵都比均方偏差要好。

參考

相關文章
相關標籤/搜索