關於 RNN 循環神經網絡的反向傳播求導

關於 RNN 循環神經網絡的反向傳播求導

本文是對 RNN 循環神經網絡中的每個神經元進行反向傳播求導的數學推導過程,下面還使用 PyTorch 對導數公式進行編程求證。php

RNN 神經網絡架構

一個普通的 RNN 神經網絡以下圖所示:html

圖片1

其中 \(x^{\langle t \rangle}\) 表示某一個輸入數據在 \(t\) 時刻的輸入;\(a^{\langle t \rangle}\) 表示神經網絡在 \(t\) 時刻時的hidden state,也就是要傳送到 \(t+1\) 時刻的值;\(y^{\langle t \rangle}\) 則表示在第 \(t\) 時刻輸入數據傳入之後產生的預測值,在進行預測或 sampling\(y^{\langle t \rangle}\) 一般做爲下一時刻即 \(t+1\) 時刻的輸入,也就是說 \(x^{\langle t \rangle}=\hat{y}^{\langle t \rangle}\) ;下面對數據的維度進行說明。python

  • 輸入: \(x\in\mathbb{R}^{n_x\times m\times T_x}\) 其中 \(n_x\) 表示每個時刻輸入向量的長度;\(m\) 表示數據批量數(batch);\(T_x\) 表示共有多少個輸入的時刻(time step)。
  • hidden state:\(a\in\mathbb{R}^{n_a\times m\times T_x}\) 其中 \(n_a\) 表示每個 hidden state 的長度。
  • 預測:\(y\in\mathbb{R}^{n_y\times m\times T_y}\) 其中 \(n_y\) 表示預測輸出的長度;\(T_y\) 表示共有多少個輸出的時刻(time step)。

RNN 神經元

下圖所示的是一個特定的 RNN 神經元:web

圖片2

上圖說明了在第 \(t\) 時刻的神經元中,數據的輸入 \(x^{\langle t \rangle}\) 和上一層的 hidden state \(a^{\langle t \rangle}\) 是如何通過計算獲得下一層的 hidden state 和預測輸出 \(\hat{y}^{\langle t \rangle}\)編程

下面是對五個參數的維度說明:網絡

  • \(W_{aa}\in\mathbb{R}^{n_a\times n_a}\)
  • \(W_{ax}\in\mathbb{R}^{n_a\times n_x}\)
  • \(b_a\in\mathbb{R}^{n_a\times 1}\)
  • \(W_{ya}\in\mathbb{R}^{n_y\times n_a}\)
  • \(b_y\in\mathbb{R}^{n_y\times 1}\)

計算 \(t\) 時刻的 hidden state \(a^{\langle t \rangle}\)架構

\[\begin{split} z1^{\langle t \rangle} &= W_{aa} a^{\langle t-1 \rangle} + W_{ax} x^{\langle t \rangle} + b_a\\ a^{\langle t \rangle} &= \tanh(z1^{\langle t \rangle}) \end{split} \]

預測 \(t\) 時刻的輸出 \(\hat{y}^{\langle t \rangle}\)框架

\[\begin{split} z2^{\langle t \rangle} &= W_{ya} a^{\langle t \rangle} + b_y\\ \hat{y}^{\langle t \rangle} &= softmax(z2^{\langle t \rangle}) = \frac{e^{z2^{\langle t \rangle}}}{\sum_{i=1}^{n_y}e^{z2_i^{\langle t \rangle}}} \end{split} \]

RNN 循環神經網絡反向傳播

在當今流行的深度學習編程框架中,咱們只須要編寫一個神經網絡的結構和負責神經網絡的前向傳播,至於反向傳播的求導和參數更新,徹底由框架搞定;即使如此,咱們在學習階段也要本身動手證實一下反向傳播的有效性。ide

RNN 神經元的反向傳播

下圖是 RNN 神經網絡中的一個基本的神經元,圖中標註了反向傳播所需傳來的參數和輸出等。函數

圖片3

就如一個全鏈接的神經網絡同樣,損失函數 \(J\) 的導數經過微積分的鏈式法則(chain rule)反向傳播到每個時間軸上。

爲了方便,咱們將損失函數關於神經元中參數的偏導符號簡記爲 \(\mathrm{d}\mathit{parameters}\) ;例如將 \(\frac{\partial J}{\partial W_{ax}}\) 記爲 \(\mathrm{d}W_{ax}\)

圖片4

上圖的反向傳播的實現並無包括全鏈接層和 Softmax 層。

反向傳播求導

計算損失函數關於各個參數的偏導數以前,咱們先引入一個計算圖(computation graph),其演示了一個 RNN 神經元的前向傳播和如何利用計算圖進行鏈式法則的反向求導。

image

由於當進行反向傳播求導時,咱們須要將整個時間軸的輸入所有輸入以後,才能夠從最後一個時刻開始往前傳進行反向傳播,因此咱們假設 \(t\) 時刻就爲最後一個時刻 \(T_x\)

若是咱們想要先計算 \(\frac{\partial\ell}{\partial W_{ax}}\) 因此咱們能夠從計算圖中看到,反向傳播的路徑:

image

咱們須要循序漸進的分別對從 \(W_{ax}\) 計算到 \(\ell\) 一路相關的變量進行求偏導,利用鏈式法則,將紅色路線上一路的偏導數相乘到一塊兒,就能夠求出偏導數 \(\frac{\partial\ell}{\partial W_{ax}}\) ;因此咱們獲得:

\[\begin{split} \frac{\partial\ell}{\partial W_{ax}} &= \frac{\partial\ell}{\partial\ell^{\langle t\rangle}} {\color{Red}{ \frac{\partial\ell^{\langle t\rangle}}{\partial\hat{y}^{\langle t\rangle}} \frac{\partial\hat{y}^{\langle t\rangle}}{\partial z2^{\langle t\rangle}} }} \frac{\partial z2^{\langle t\rangle}}{\partial a^{\langle t\rangle}} \frac{\partial a^{\langle t\rangle}}{\partial z1^{\langle t\rangle}} \frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}} \end{split} \]

在上面的公式中,咱們僅須要分別求出每個偏導便可,其中紅色的部分就是關於 \(\mathrm{Softmax}\) 的求導,關於 \(\mathrm{Softmax}\) 求導的推導過程,能夠看本人的另外一篇博客: 關於 Softmax 迴歸的反向傳播求導數過程

關於 \(\mathrm{tanh}\) 的求導公式以下:

\[\frac{\partial \tanh(x)} {\partial x} = 1 - \tanh^2(x) \]

因此上面的式子就獲得:

\[\begin{split} \frac{\partial\ell}{\partial W_{ax}} &= \frac{\partial\ell}{\partial\ell^{\langle t\rangle}} {\color{Red}{ \frac{\partial\ell^{\langle t\rangle}}{\partial\hat{y}^{\langle t\rangle}} \frac{\partial\hat{y}^{\langle t\rangle}}{\partial z2^{\langle t\rangle}} }} \frac{\partial z2^{\langle t\rangle}}{\partial a^{\langle t\rangle}} \frac{\partial a^{\langle t\rangle}}{\partial z1^{\langle t\rangle}} \frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}}\\ &= {\color{Red}{ (\hat{y}^{\langle t\rangle}-y^{\langle t\rangle}) }} W_{ya} (1-\tanh^2(z1^{\langle t\rangle})) x^{\langle t\rangle} \end{split} \]

咱們就能夠獲得在最後時刻 \(t\) 參數 \(W_{ax}\) 的偏導數。

關於上面式子中的偏導數的計算,除了標量對矩陣的求導,在後面還包括了兩個一個矩陣或向量對另外一個矩陣或向量中的求導,實際上這是很是麻煩的一件事。

好比在計算 \(\frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}}\) 偏導數的時候,咱們發現 \(z1^{\langle t\rangle}\) 是一個 \(\mathbb{R}^{n_a\times m}\) 的矩陣,而 \(W_{ax}\) 則是一個 \(\mathbb{R}^{n_a\times n_x}\) 的矩陣,這一項就是一個矩陣對另外一個矩陣求偏導,若是直接對其求導咱們將會獲得一個四維的矩陣 \(\mathbb{R}^{n_a\times n_x\times n_a\times m}\)雅可比矩陣 Jacobian matrix);只不過這個高維矩陣中偏導數的值有不少 \(0\)

在神經網絡中,若是直接將這個高維矩陣直接生搬硬套進梯度降低裏更新參數是不可行,由於咱們須要獲得的梯度是關於自變量同型的向量或矩陣並且咱們還要處理更高維度的矩陣的乘法;因此咱們須要將結果進行必定的處理獲得咱們僅僅須要的信息。

通常在深度學習框架中都會有自動求梯度的功能包,這些包(好比 PyTorch )中就只容許一個標量對向量或矩陣求導,其餘狀況是不容許的,除非在反向傳播的函數裏傳入一個同型的權重向量或矩陣才能夠獲得導數。

咱們先簡單求出一個偏導數 \(\frac{\partial\ell}{\partial W_{ax}}\) 咱們下面使用 PyTorch 中的自動求梯度的包進行驗證咱們的公式是否正確。

import torch
# 這是神經網絡中的一些架構的參數
n_x = 6
n_y = 6
m = 1
T_x = 5
T_y = 5
n_a = 3
# 定義全部參數矩陣
# requires_grad 爲 True 代表在涉及這個變量的運算時創建計算圖
# 爲了以後反向傳播求導
W_ax = torch.randn((n_a, n_x), requires_grad=True)
W_aa = torch.randn((n_a, n_a), requires_grad=True)
ba = torch.randn((n_a, 1), requires_grad=True)
W_ya = torch.randn((n_y, n_a), requires_grad=True)
by = torch.randn((n_y, 1), requires_grad=True)
# t 時刻的輸入和上一時刻的 hidden state
x_t = torch.randn((n_x, m), requires_grad=True)
a_prev = torch.randn((n_a, m), requires_grad=True)
y_t = torch.randn((n_y, m), requires_grad=True)
# 開始模擬一個神經元 t 時刻的前向傳播
# 從輸入一直到計算出 loss
z1_t = torch.matmul(W_ax, x_t) + torch.matmul(W_aa, a_prev) + ba
z1_t.retain_grad()
a_t = torch.tanh(z1_t)
a_t.retain_grad()
z2_t = torch.matmul(W_ya, a_t) + by
z2_t.retain_grad()
y_hat = torch.exp(z2_t) / torch.sum(torch.exp(z2_t), dim=0)
y_hat.retain_grad()
loss_t = -torch.sum(y_t * torch.log(y_hat), dim=0)
loss_t.retain_grad()
# 對最後的 loss 標量開始進行反向傳播求導
loss_t.backward()
# 咱們就能夠獲得 W_ax 的導數
# 存儲在後綴 _autograd 變量中,代表是由框架自動求導獲得的
W_ax_autograd = W_ax.grad
# 查看框架計算獲得的導數
W_ax_autograd
tensor([[ 0.5252,  1.1938, -0.2352,  1.1571, -1.0168,  0.3195],
        [-1.0536, -2.3949,  0.4718, -2.3213,  2.0398, -0.6410],
        [-0.0316, -0.0717,  0.0141, -0.0695,  0.0611, -0.0192]])
# 咱們對本身推演出的公式進行手動計算導數
# 存儲在後綴 _manugrad 變量中,代表是手動由公式計算獲得的
W_ax_manugrad = torch.matmul(torch.matmul((y_hat - y_t).T, W_ya).T * (1 - torch.square(torch.tanh(z1_t))), x_t.T)
#torch.matmul(torch.matmul(W_ya.T, y_hat - y_t) * (1 - torch.square(torch.tanh(z1_t))), x_t.T)
# 輸出手動計算的導數
W_ax_manugrad
tensor([[ 0.5195,  1.1809, -0.2327,  1.1447, -1.0058,  0.3161],
        [-1.0195, -2.3172,  0.4565, -2.2461,  1.9737, -0.6202],
        [-0.0309, -0.0703,  0.0138, -0.0681,  0.0599, -0.0188]],
       grad_fn=<MmBackward>)
# 查看兩種求導結果的之差的 L2 範數
torch.norm(W_ax_manugrad - W_ax_autograd)
tensor(0.1356, grad_fn=<CopyBackwards>)

經過上面的編程輸出能夠看到,咱們手動計算的導數和框架本身求出的導數雖然有必定的偏差,可是一一對照能夠大致看到咱們手動求出來的導數大致是對的,並無說錯的很是離譜。

但上面只是當 \(t=T_x\)\(t\) 時刻是最後一個輸入單元的時候,也就是說所求的關於 \(_W{ax}\) 的導數只是所有導數的一部分,由於參數共享,因此每一時刻的神經元都有對 \(W_{ax}\) 的導數,因此須要將全部時刻的神經元關於 \(W_{ax}\) 的導數所有加起來。

\(t\) 不是最後一時刻,多是神經網絡裏的中間的某一時刻的神經元;也就是說,在進行反向傳播的時候,想要求 \(t\) 時刻的導數,就得等到 \(t+1\) 時刻的導數值傳進來,而後根據鏈式法則才能夠計算當前時刻參數的導數。

下面是一個簡易的計算圖,只繪製出了 \(W_ax\)\(\ell\) 的計算中,共涉及到哪些變量(在整個神經網絡中的 \(W_{ax}\) 的權重參數是共享的):

image

下面使用一個視頻展現整個神經網絡中從 \(W_{ax}\) 到一個數據批量的損失值 \(\ell\) 的大致流向:

計算完 \(\ell\) 以後就能夠計算 \(\frac{\partial\ell}{\partial W_{ax}}\) 的導數值,可是 RNN 神經網絡的反向傳播區別於全鏈接神經網絡的。

image

而後,咱們演示一下如何進行反向傳播的,注意看每個時刻的 \(a^{\langle t\rangle}\) 的計算都是等 \(a^{\langle t+1\rangle}\) 的導數值傳進來才進行計算的;一樣地,\(W_{ax}\) 導數的計算也不是一步到位的,也是須要等到全部時刻的 \(a\) 的值所有傳到才計算完。

因此對於神經網絡中間某一個單元 \(t\) 咱們有:

\[\begin{split} \frac{\partial\ell}{\partial W_{ax}} &= {\color{Red}{ \left( \frac{\partial\ell}{\partial a^{\langle t\rangle}} +\frac{\partial\ell}{\partial z1^{\langle t+1\rangle}} \frac{\partial z1^{\langle t+1\rangle}}{\partial a^{\langle t\rangle}} \right) }} \frac{\partial a^{\langle t\rangle}}{\partial z1^{\langle t\rangle}} \frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}} \end{split} \]

關於紅色的部分的意思是須要等到 \(t+1\) 時刻的導數值傳進來,而後才能夠進行對 \(t+1\) 時刻關於當前時刻 \(t\) 的參數求導,最後獲得參數梯度的一個份量。其實若仔細展開每個偏導項,就像是一個遞歸同樣,每次求某一時刻的導數老是要從最後一時刻往前傳到當前時刻才能夠進行。

多元複合函數的求導法則

若是函數 \(u=\varphi(t)\)\(v=\psi(t)\) 都在點 \(t\) 可導,函數 \(z=f(u,v)\) 在對應點 \((u,v)\) 具備連續偏導數,那麼複合函數 \(z=f[\varphi(t),\psi(t)]\) 在點 \(t\) 可導,且有

\[\frac{\mathrm{d}z}{\mathrm{d}t}=\frac{\partial z}{\partial u}\frac{\mathrm{d}u}{\mathrm{d}t}+\frac{\partial z}{\partial v}\frac{\mathrm{d}v}{\mathrm{d}t} \]

下面使用一張計算圖說明 \(a^{\langle t\rangle}\)\(\ell\) 的計算關係。

image

也就是說第 \(t\) 時刻 \(\ell\) 關於 \(a^{\langle t\rangle}\) 的導數是由兩部分相加組成,也就是說是由兩條路徑反向傳播,這兩條路徑分別是 \(\ell\to\ell^{\langle t\rangle}\to\hat{y}^{\langle t\rangle}\to z2^{\langle t\rangle}\to a^{\langle t\rangle}\)\(\ell\to\ell^{\langle t+1\rangle}\to\hat{y}^{\langle t+1\rangle}\to z2^{\langle t+1\rangle}\to a^{\langle t+1\rangle}\to z1^{\langle t+1\rangle}\to a^{\langle t\rangle}\) ,咱們將這兩條路徑導數之和使用 \(\mathrm{d}a_{\mathrm{next}}\) 表示。

因此咱們能夠獲得在中間某一時刻的神經單元關於 \(W_{ax}\) 的導數爲:

\[\frac{\partial\ell}{\partial W_{ax}}=\left(\mathrm{d}a_{\mathrm{next}} * \left( 1-\tanh^2(z1^{\langle t \rangle}\right)\right) x^{\langle t \rangle T} \]

經過一樣的方法,咱們就能夠獲得其它參數的導數:

\[\begin{align} \frac{\partial\ell}{\partial W_{aa}} &= \left(\mathrm{d}a_{\mathrm{next}} * \left( 1-\tanh^2(z1^{\langle t\rangle}) \right)\right) a^{\langle t-1 \rangle T}\\ \frac{\partial\ell}{\partial b_a} & = \sum_{batch}\left( da_{next} * \left( 1-\tanh^2(z1^{\langle t\rangle}) \right)\right)\\ \end{align} \]

除了傳遞參數的導數,在第 \(t\) 時刻還須要傳送 \(\ell\) 關於 \(z1^{\langle t\rangle}\) 的導數到 \(t-1\) 時刻,將須要傳送到上一時刻的導數記做爲 \(\mathrm{d}a_{\mathrm{prev}}\) 咱們獲得:

\[\begin{split} \mathrm{d}a_{\mathrm{prev}} &= \mathrm{d}a_\mathrm{next}\frac{\partial a^{\langle t\rangle}}{\partial z1^{\langle t\rangle}}\frac{\partial z1^{\langle t\rangle}}{\partial a^{\langle t-1\rangle}}\\ &= { W_{aa}}^T\left(\mathrm{d}a_{\mathrm{next}} * \left( 1-\tanh^2(z1^{\langle t\rangle}) \right)\right) \end{split} \]

能夠看到,一個循環神經網絡的反向傳播其實是很是複雜的,由於每一時刻的神經元都與參數有計算關係,因此反向傳播時的路徑很是雜亂,其中還涉及到了高維的矩陣,因此在計算時須要對高維矩陣進行必定的矩陣代數轉換才方便導數和更新參數的計算。

相關文章
相關標籤/搜索