做者 | 張皓git
RNN是深度學習中用於處理時序數據的關鍵技術, 目前已在天然語言處理, 語音識別, 視頻識別等領域取得重要突破, 然而梯度消失現象制約着RNN的實際應用。LSTM和GRU是兩種目前廣爲使用的RNN變體,它們經過門控機制很大程度上緩解了RNN的梯度消失問題,可是它們的內部結構看上去十分複雜,使得初學者很難理解其中的原理所在。本文介紹」三次簡化一張圖」的方法,對LSTM和GRU的內部結構進行分析。該方法很是通用,適用於全部門控機制的原理分析。
github
RNN (recurrent neural networks, 注意不是recursiveneural networks)提供了一種處理時序數據的方案。和n-gram只能根據前n-1個詞來預測當前詞不一樣, RNN理論上能夠根據以前全部的詞預測當前詞。在每一個時刻, 隱層的輸出ht依賴於當前詞輸入xt和前一時刻的隱層狀態ht-1:
web
其中:=表示"定義爲", sigm表明sigmoid函數sigm(z):=1/(1+exp(-z)), Wxh和Whh是可學習的參數。結構見下圖:
微信
圖中左邊是輸入,右邊是輸出。xt是當前詞,ht-1記錄了上文的信息。xt和ht-1在分別乘以Wxh和Whh以後相加,再通過tanh非線性變換,最終獲得ht。
網絡
在反向傳播時,咱們須要將RNN沿時間維度展開,隱層梯度在沿時間維度反向傳播時須要反覆乘以參數。所以, 儘管理論上RNN能夠捕獲長距離依賴, 但實際應用中,根據譜半徑(spectralradius)的不一樣,RNN將會面臨兩個挑戰:梯度爆炸(gradient explosion)和梯度消失(vanishing gradient)。梯度爆炸會影響訓練的收斂,甚至致使網絡不收斂;而梯度消失會使網絡學習長距離依賴的難度增長。這二者相比, 梯度爆炸相對比較好處理,能夠用梯度裁剪(gradientclipping)來解決,而如何緩解梯度消失是RNN及幾乎其餘全部深度學習方法研究的關鍵所在。app
LSTM經過設計精巧的網絡結構來緩解梯度消失問題,其數學上的形式化表示以下:ide
其中表明逐元素相乘。這個公式看起來彷佛十分複雜,爲了更好的理解LSTM的機制, 許多人用圖來描述LSTM的計算過程, 好比下面的幾張圖:
函數
彷佛看完了這些圖以後,你對LSTM的理解仍是一頭霧水? 這是由於這些圖想把LSTM的全部細節一次性都展現出來,可是忽然暴露這麼多的細節會使你眼花繚亂,從而無處下手。
學習
所以,本文提出的方法旨在簡化門控機制中不重要的部分,從而更關注在LSTM的核心思想。整個過程是「三次簡化一張圖」,具體流程以下:優化
第一次簡化: 忽略門控單元i,f,o的來源。3個門控單元的計算方法徹底相同, 都是由輸入通過線性映射獲得的, 區別只是計算的參數不一樣。這樣作的目的是爲了梯度反向傳導時能對門控單元進行更新。這不是LSTM的核心思想, 在進行理解時,咱們能夠假定各門控單元是給定的。
第二次簡化: 考慮一維狀況。LSTM中對各維是獨立進行門控的,因此爲了理解方便,咱們只須要考慮一維狀況。
第三次簡化: 各門控單元0/1輸出。 門控單元輸出是[0,1]實數區間的緣由是階躍激活函數沒法反向傳播進行優化, 因此各門控單元使用sigmoid激活函數去近似階躍函數。 所以, 爲了理解方便, 咱們只須要考慮理想狀況, 即各門控單元是{0,1}二值輸出的,即門控單元扮演了電路中」開關」的角色, 用於控制信息傳輸的通斷。
一張圖: 將三次簡化的結果用」電路圖」表述出來,左邊是輸入,右邊是輸出。另外須要特別注意的是LSTM中的c實質上起到了RNN中h的做用, 這點在其餘文獻資料中不常被提到。最終結果以下:
和RNN相同的是,網絡接受兩個輸入,獲得一個輸出。不一樣之處在於, LSTM中經過3個門控單元來對記憶單元c的信息進行交互。
根據這張圖,咱們能夠對LSTM中各單元做用進行分析:
輸入門it: it控制當前詞xt的信息融入記憶單元ct。在理解一句話時,當前詞xt可能對整句話的意思很重要,也可能並不重要。輸入門的目的就是判斷當前詞xt對全局的重要性。當it開關打開的時候,網絡將不考慮當前輸入xt。
遺忘門ft: ft控制上一時刻記憶單元ct-1的信息融入記憶單元ct。在理解一句話時,當前詞xt可能繼續延續上文的意思繼續描述,也可能從當前詞xt開始描述新的內容,與上文無關。和輸入門it相反, ft不對當前詞xt的重要性做判斷, 而判斷的是上一時刻的記憶單元ct-1對計算當前記憶單元ct的重要性。當ft開關打開的時候,網絡將不考慮上一時刻的記憶單元ct-1。
輸出門ot: 輸出門的目的是從記憶單元ct產生隱層單元ht。並非ct中的所有信息都和隱層單元ht有關,ct可能包含了不少對ht無用的信息,所以, ot的做用就是判斷ct中哪些部分是對ht有用的,哪些部分是無用的。
記憶單元ct:ct綜合了當前詞xt和前一時刻記憶單元ct-1的信息。這和ResNet中的殘差逼近思想十分類似,經過從ct-1到ct的」短路鏈接」, 梯度得已有效地反向傳播。 當ft處於閉合狀態時, ct的梯度能夠直接沿着最下面這條短路線傳遞到ct-1,不受參數W的影響,這是LSTM能有效地緩解梯度消失現象的關鍵所在。
GRU是另外一種十分主流的RNN衍生物。RNN和LSTM都是在設計網絡結構用於緩解梯度消失問題, 只不過是網絡結構有所不一樣。GRU在數學上的形式化表示以下:
爲了理解GRU的設計思想,咱們再一次運用「三次簡化一張圖」的方法來進行分析:
第一次簡化: 忽略門控單元z, r的來源。
第二次簡化: 考慮一維狀況。
第三次簡化: 各門控單元0/1輸出。這裏和LSTM略有不一樣的地方在於,GRU須要引入一個」單刀雙擲開關」。
一張圖: 把三次簡化的結果用」電路圖」表述出來,左輸入,右輸出:
與LSTM相比,GRU將輸入門it和遺忘門ft融合成單一的更新門zt,而且融合了記憶單元ct和隱層單元ht,因此結構上比LSTM更簡單一些。
根據這張圖,咱們能夠對GRU的各單元做用進行分析:
重置門rt:rt用於控制前一時刻隱層單元ht-1對當前詞xt的影響。若是ht-1對xt不重要,即從當前詞xt開始表述了新的意思,與上文無關, 那麼rt開關能夠打開, 使得ht-1對xt不產生影響。
更新門zt:zt用於決定是否忽略當前詞xt。相似於LSTM中的輸入門it, zt能夠判斷當前詞xt對總體意思的表達是否重要。當zt開關接通下面的支路時,咱們將忽略當前詞xt,同時構成了從ht-1到ht的」短路鏈接」,這梯度得已有效地反向傳播。和LSTM相同,這種短路機制有效地緩解了梯度消失現象, 這個機制於highwaynetworks十分類似。
儘管RNN, LSTM,和GRU的網絡結構差異很大,可是他們的基本計算單元是一致的,都是對xt和ht-1作一個線性映射加tanh激活函數,見三個圖的紅色框部分。他們的區別在於如何設計額外的門控機制控制梯度信息傳播用以緩解梯度消失現象。LSTM用了3個門,GRU用了2個,那能不能再少呢? MGU (minimal gate unit)嘗試對這個問題作出回答, 它只有一個門控單元。
最後留個小練習, 參考LSTM和GRU的例子,你能不能用「三次簡化一張圖」的方法來分析一下MGU呢?
參考文獻
1. Bengio, Yoshua, PatriceSimard, and Paolo Frasconi。 "Learning long-term dependencies with gradient descent isdifficult。" IEEE transactions on neural networks 5。2 (1994):157-166。
2. Cho, Kyunghyun, et al。"Learning phrase representations using RNN encoder-decoder for statisticalmachine translation。" arXiv preprint arXiv:1406。1078 (2014)。
3. Chung, Junyoung, et al。"Empirical evaluation of gated recurrent neural networks on sequencemodeling。" arXiv preprint arXiv:1412。3555 (2014)。
4. Gers, Felix。 "Longshort-term memory in recurrent neural networks。" UnpublishedPhD dissertation, Ecole Polytechnique Fédérale de Lausanne, Lausanne, Switzerland(2001)。
5. Goodfellow, Ian, YoshuaBengio, and Aaron Courville。 Deep learning。 MIT press, 2016。
6. Graves, Alex。 Supervisedsequence labelling with recurrent neural networks。 Vol。 385。 Heidelberg:Springer, 2012。
7. Greff, Klaus, et al。 "LSTM:A search space odyssey。" IEEE transactions on neural networks and learning systems(2016)。
8. He, Kaiming, et al。 "Deepresidual learning for image recognition。" Proceedingsof the IEEE conference on computer vision and pattern recognition。 2016。
9. He, Kaiming, et al。"Identity mappings in deep residual networks。" EuropeanConference on Computer Vision。 Springer International Publishing, 2016。
10. Hochreiter, Sepp, and JürgenSchmidhuber。 "Long short-term memory。" Neuralcomputation 9。8 (1997): 1735-1780。
11. Jozefowicz, Rafal, WojciechZaremba, and Ilya Sutskever。 "An empirical exploration of recurrent network architectures。" Proceedingsof the 32nd International Conference on Machine Learning (ICML-15)。 2015。
12. Li, Fei-Fei, JustinJohnson, and Serena Yeung。 CS231n: Convolutional Neural Networks for Visual Recognition。 Stanford。 2017。
13. Lipton, Zachary C。, JohnBerkowitz, and Charles Elkan。 "A critical review of recurrent neural networks for sequencelearning。" arXiv preprint arXiv:1506。00019 (2015)。
14. Manning, Chris andRichard Socher。 CS224n: Natural Language Processing with Deep Learning。 Stanford。 2017。
15. Pascanu, Razvan, Tomas Mikolov, and YoshuaBengio。 "On the difficulty of training recurrent neural networks。"International Conference on Machine Learning。 2013。
16. Srivastava, RupeshKumar, Klaus Greff, and Jürgen Schmidhuber。 "Highwaynetworks。" arXiv preprint arXiv:1505。00387 (2015)。
17. Williams, D。 R。 G。 H。 R。, andGeoffrey Hinton。 "Learning representations by back-propagating errors。"Nature 323。6088 (1986): 533-538。
18. Zhou, Guo-Bing, et al。"Minimal gated unit for recurrent neural networks。"International Journal of Automation and Computing 13。3 (2016):226-234。
本文是投稿文章,做者:張皓
github地址:https://github.com/HaoMood/
注:AI科技大本營現已開通投稿通道,投稿請加編輯微信1092722531