kaldi中CD-DNN-HMM網絡參數更新公式手寫推導

在基於DNN-HMM的語音識別中,DNN的做用跟GMM是同樣的,即它是取代GMM的,具體做用是算特徵值對每一個三音素狀態的機率,算出來哪一個最大這個特徵值就對應哪一個狀態。只不過之前是用GMM算的,如今用DNN算了。這是典型的多分類問題,因此輸出層用的激活函數是softmax,損失函數用的是cross entropy(交叉熵)。不用均方差作損失函數的緣由是在分類問題上它是非凸函數,不能保證全局最優解(只有凸函數才能保證全局最優解)。Kaldi中也支持DNN-HMM,它還依賴於上下文(context dependent, CD),因此叫CD-DNN-HMM。在kaldi的nnet1中,特徵提取用filterbank,每幀40維數據,默認取當前幀先後5幀加上當前幀共11幀做爲輸入,因此輸入層維數是440(440 = 40*11)。同時默認有4個隱藏層,每層1024個網元,激活函數是sigmoid。今天咱們看看網絡的各類參數是怎麼獲得的(手寫推導)。因爲真正的網絡比較複雜,爲了推導方便這裏對其進行了簡化,只有一個隱藏層,每層的網元均爲3,同時只有weight沒有bias。這樣網絡以下圖:網絡

上圖中輸入層3個網元爲i1/i2/i3(i表示input),隱藏層3個網元爲h1/h2/h3(h表示hidden),輸出層3個網元爲o1/o2/o3(o表示output)。隱藏層h1的輸入爲 (q11等表示輸入層和隱藏層之間的權值),輸出爲。輸出層o1的輸入爲(w11等表示隱藏層和輸出層之間的權值),輸出爲。其餘可相似推出。損失函數用交叉熵。今天咱們看看網絡參數(以隱藏層和輸出層之間的w11以及輸入層和隱藏層之間的q11爲例)在每次迭代訓練後是怎麼更新的。先看隱藏層和輸出層之間的w11。函數

 

1,隱藏層和輸出層之間的w11的更新blog

 

 先分別求三個導數的值:input

 

 因此最終的w11更新公式以下圖:bfc

 

2,輸入層和隱藏層之間的q11的更新im

 

先分別求三個導數的值:d3

 

因此最終的q11更新公式以下圖:數據

 

以上的公式推導中若有錯誤,煩請指出,很是感謝!filter

相關文章
相關標籤/搜索