【轉載】深度學習中softmax交叉熵損失函數的理解

深度學習中softmax交叉熵損失函數的理解

版權聲明:本文爲博主原創文章,遵循 CC 4.0 BY-SA 版權協議,轉載請附上原文出處連接和本聲明。
本文連接: https://blog.csdn.net/lilong117194/article/details/81542667

1. softmax層的做用

經過神經網絡解決多分類問題時,最經常使用的一種方式就是在最後一層設置n個輸出節點,不管在淺層神經網絡仍是在CNN中都是如此,好比,在AlexNet中最後的輸出層有1000個節點,即使是ResNet取消了全鏈接層,但1000個節點的輸出層還在。git

通常狀況下,最後一個輸出層的節點個數與分類任務的目標數相等。 
假設最後的節點數爲N,那麼對於每個樣例,神經網絡能夠獲得一個N維的數組做爲輸出結果,數組中每個維度會對應一個類別。在最理想的狀況下,若是一個樣本屬於k,那麼這個類別所對應的的輸出節點的輸出值應該爲1,而其餘節點的輸出都爲0,即 [0,0,1,0,.0,0][0,0,1,0,….0,0],這個數組也就是樣本的Label,是神經網絡最指望的輸出結果,但實際是這樣的輸出[0.01,0.01,0.6,....0.02,0.01][0.01,0.01,0.6,....0.02,0.01],這實際上是在原始輸出的基礎上加入了softmax的結果,原始的輸出是輸入的數值作了複雜的加權和與非線性處理以後的一個值而已,這個值能夠是任意的值,可是通過softmax層後就成了一個機率值,並且機率和爲1。 
假設神經網絡的原始輸出爲y_1,y_2,….,y_n,那麼通過Softmax迴歸處理以後的輸出爲 : 
數組

 
y=softmax(yi)=eyinj=1eyjy′=softmax(yi)=eyi∑j=1neyj

以上能夠看出: y=1∑y′=1 
這也是爲何softmax層的每一個節點的輸出值成爲了機率和爲1的機率分佈。

 

2. 交叉熵損失函數的數學原理

上面說過實際的指望輸出,也就是標籤是[0,0,1,0,.0,0][0,0,1,0,….0,0]這種形式,而實際的輸出是[0.01,0.01,0.6,....0.02,0.01][0.01,0.01,0.6,....0.02,0.01]這種形式,這時按照常理就須要有一個損失函數來斷定實際輸出和指望輸出的差距,交叉熵就是用來斷定實際的輸出與指望的輸出的接近程度!下面就簡單介紹下交叉熵的原理。markdown

交叉熵刻畫的是實際輸出(機率)與指望輸出(機率)的距離,也就是交叉熵的值越小,兩個機率分佈就越接近。假設機率分佈p爲指望輸出(標籤),機率分佈q爲實際輸出,H(p,q)爲交叉熵。網絡

  • 第一種交叉熵損失函數的形式: 
     
    H(p,q)=xp(x)logq(x)H(p,q)=−∑xp(x)logq(x)

舉個例子: 
假設N=3,指望輸出爲p=(1,0,0),實際輸出q1=(0.5,0.2,0.3)q2=(0.8,0.1,0.1)q1=(0.5,0.2,0.3),q2=(0.8,0.1,0.1),這裏的q1,q2兩個輸出分別表明在不一樣的神經網絡參數下的實際輸出,經過計算其對應的交叉熵來優化神經網絡參數,計算過程: 
H(p,q1)=1(1×log0.5+0×log0.2+0×log0.3)H(p,q1)=−1(1×log0.5+0×log0.2+0×log0.3) 
假設結果:H(p,q1)=0.3H(p,q1)=0.3 
H(p,q2)=1(1×log0.8+0×log0.1+0×log0.1)H(p,q2)=−1(1×log0.8+0×log0.1+0×log0.1) 
假設結果:H(p,q2)=0.1H(p,q2)=0.1 
這時獲得了q2q2是相對正確的分類結果。
session

  • 第二種交叉熵損失函數形式: 
     
    H(p,q)=x(p(x)logq(x)+(1p(x))log(1q(x)))H(p,q)=−∑x(p(x)logq(x)+(1−p(x))log(1−q(x)))

    下面簡單推到其過程: 
    咱們知道,在二分類問題模型:例如邏輯迴歸「Logistic Regression」、神經網絡「Neural Network」等,真實樣本的標籤爲 [0,1],分別表示負類和正類。模型的最後一般會通過一個 Sigmoid 函數,輸出一個機率值,這個機率值反映了預測爲正類的可能性:機率越大,可能性越大。 
    Sigmoid 函數的表達式和圖形以下所示:g(s)=11+esg(s)=11+e−s 
    其中 s 是模型上一層的輸出,Sigmoid 函數有這樣的特色:s = 0 時,g(s) = 0.5;s >> 0 時, g ≈ 1,s << 0 時,g ≈ 0。顯然,g(s) 將前一級的線性輸出映射到 [0,1] 之間的數值機率上。 
    其中預測輸出即 Sigmoid 函數的輸出g(s)表徵了當前樣本標籤爲 1 的機率: 
    P(y=1|x)=y^P(y=1|x)=y^ 
    p(y=0|x)=1y^p(y=0|x)=1−y^ 
    這個時候從極大似然性的角度出發,把上面兩種狀況整合到一塊兒: 
    p(y|x)=y^y(1y^)(1y)p(y|x)=y^y(1−y^)(1−y) 
    這個函數式表徵的是: 
    當真實樣本標籤 y = 1 時,上面式子第二項就爲 1,機率等式轉化爲: 
    P(y=1|x)=y^P(y=1|x)=y^ 
    當真實樣本標籤 y = 0 時,上面式子第一項就爲 1,機率等式轉化爲: 
    P(y=0|x)=1y^P(y=0|x)=1−y^ 
    兩種狀況下機率表達式跟以前的徹底一致,只不過咱們把兩種狀況整合在一塊兒了。那這個時候應用極大似然估計應該獲得的是全部的機率值乘積應該最大,即: 
    L=Ni=1y^yii(1y^i)(1yi)L=∑i=1Ny^iyi(1−y^i)(1−yi) 
    引入log函數後獲得: 
    L=log(L)=Ni=1yilogy^i+(1yi)log(1y^i)L′=log(L)=∑i=1Nyilogy^i+(1−yi)log(1−y^i) 
    這時令loss=-log(L)=-L',也就是損失函數越小越好,而此時也就是 L'越大越好。

而在實際的使用訓練過程當中,數據每每是組合成爲一個batch來使用,因此對用的神經網絡的輸出應該是一個m*n的二維矩陣,其中m爲batch的個數,n爲分類數目,而對應的Label也是一個二維矩陣,仍是拿上面的數據,組合成一個batch=2的矩陣 函數

 
q=[0.50.80.20.10.30.1]q=[0.50.20.30.80.10.1]

 
p=[110000]p=[100100]

根據第一種交叉熵的形式獲得: 
 
H(p,q)=[0.30.1]H(p,q)=[0.30.1]

而對於一個batch,最後取平均爲0.2。

 

3. 在TensorFlow中實現交叉熵

在TensorFlow能夠採用這種形式:學習

cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) 
  • 1

其中y_表示指望的輸出,y表示實際的輸出(機率值),*爲矩陣元素間相乘,而不是矩陣乘。 
而且經過tf.clip_by_value函數能夠將一個張量中的數值限制在一個範圍以內,這樣能夠避免一些運算錯誤(好比log0是無效的),tf.clip_by_value函數是爲了限制輸出的大小,爲了不log0爲負無窮的狀況,將輸出的值限定在(1e-10, 1.0)之間,其實1.0的限制是沒有意義的,由於機率怎麼會超過1呢。好比:優化

import tensorflow as tf

v=tf.constant([[1.0,2.0,3.0],[4.0,5.0,6.0]]) with tf.Session() as sess: print(tf.clip_by_value(v,2.5,4.5).eval(session=sess))
  • 1
  • 2
  • 3
  • 4
  • 5

結果:ui

[[2.5 2.5 3. ] [4. 4.5 4.5]]
  • 1
  • 2

上述代碼實現了第一種形式的交叉熵計算,須要說明的是,計算的過程其實和上面提到的公式有些區別,按照上面的步驟,平均交叉熵應該是先計算batch中每個樣本的交叉熵後取平均計算獲得的,而利用tf.reduce_mean函數其實計算的是整個矩陣的平均值,這樣作的結果會有差別,可是並不改變實際意義。atom

import tensorflow as tf

v=tf.constant([[1.0,2.0,3.0],[4.0,5.0,6.0]]) with tf.Session() as sess: # 輸出3.5 print(tf.reduce_mean(v).eval())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

因爲在神經網絡中,交叉熵經常與Sorfmax函數組合使用,因此TensorFlow對其進行了封裝,即:

cross_entropy = tf.nn.sorfmax_cross_entropy_with_logits(y_ ,y)
  • 1

與第一個代碼的區別在於,這裏的y用神經網絡最後一層的原始輸出就行了,而不是通過softmax層的機率值。

參考:http://www.javashuo.com/article/p-qrfavtho-ev.html 
https://blog.csdn.net/chaipp0607/article/details/73392175

相關文章
相關標籤/搜索