I've been working through CS224d recently, so this post walks through the derivation of the LSTM (Long Short-Term Memory) and a simple Python implementation. An LSTM is a recurrent neural network, a variant of the RNN, that is particularly well suited to processing and predicting events separated by long gaps in a time series. Suppose we try to predict the last word of "I grew up in France ... (long gap) ... I speak fluent French". The recent context suggests the next word is probably the name of a language (because of "speak"), but to predict "French" correctly we need "France", which appears much earlier, as context. When that gap gets large a plain RNN struggles, while an LSTM handles it without trouble.
To understand how an LSTM is implemented I downloaded Alex's original paper, but the figures and formulas left me thoroughly confused; in the end I had to collect material from around the web before it all made sense. I will not cover the RNN background that LSTM builds on here; if you are not familiar with it, please read up on it first.
First, a diagram of the internals of an LSTM node:
The image comes from a blog post explaining LSTMs (http://colah.github.io/posts/2015-08-Understanding-LSTMs/).
This is, in my opinion, the best LSTM node diagram on the web (much easier to follow than the figures in the paper). The LSTM forward pass is simply a matter of reading the diagram: the key function nodes are marked in the figure, and here we ignore one of the tanh computations.
\[ \begin{eqnarray} g(t) &=& \phi(W_{gx}x(t) + W_{gh}h(t-1) + b_{g}) \\ i(t) &=& \sigma(W_{ix}x(t) + W_{ih}h(t-1) + b_{i}) \\ f(t) &=& \sigma(W_{fx}x(t) + W_{fh}h(t-1) + b_{f}) \\ o(t) &=& \sigma(W_{ox}x(t) + W_{oh}h(t-1) + b_{o}) \\ s(t) &=& g(t)*i(t) + s(t-1)*f(t) \\ h(t) &=& s(t) * o(t) \end{eqnarray} \]
Here \(\phi(x)=\tanh(x)\), \(\sigma(x)=\frac{1}{1+e^{-x}}\), and \(x(t)\), \(h(t)\) are the input and output sequences respectively. If we concatenate the two vectors \(x(t)\) and \(h(t-1)\):
\[ x_c(t)=[x(t),h(t-1)] \]
then the system of equations above can be rewritten as:
\[ \begin{eqnarray} g(t) &=& \phi(W_{g}x_c(t) + b_{g}) \\ i(t) &=& \sigma(W_{i}x_c(t) + b_{i}) \\ f(t) &=& \sigma(W_{f}x_c(t) + b_{f}) \\ o(t) &=& \sigma(W_{o}x_c(t) + b_{o}) \\ s(t) &=& g(t)*i(t) + s(t-1)*f(t) \\ h(t) &=& s(t) * o(t) \end{eqnarray} \]
Here \(f(t)\) is called the forget gate; it decides what information to discard from the previous state. \(i(t)\) and \(g(t)\) together form the input gate, which decides what new information gets stored in the cell state. \(o(t)\) sits at the output gate, which decides what value to output. This description is a bit loose; interested readers should see http://colah.github.io/posts/2015-08-Understanding-LSTMs/, since the NLP side is not my strong suit either.
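As a quick sanity check on the merged formulation, here is a minimal sketch (the sizes `x_dim = 3` and `mem_cell_ct = 4` are arbitrary, purely illustrative choices): each merged weight matrix has shape `(mem_cell_ct, x_dim + mem_cell_ct)` and acts on the concatenated vector \(x_c(t)\).

```python
import numpy as np

x_dim, mem_cell_ct = 3, 4                  # illustrative sizes only
x_t = np.random.random(x_dim)              # x(t)
h_prev = np.random.random(mem_cell_ct)     # h(t-1)

xc = np.hstack((x_t, h_prev))              # x_c(t), length x_dim + mem_cell_ct
Wg = np.random.random((mem_cell_ct, x_dim + mem_cell_ct))
bg = np.random.random(mem_cell_ct)

g_t = np.tanh(np.dot(Wg, xc) + bg)         # g(t) from the merged equations
print(xc.shape, g_t.shape)                 # (7,) (4,)
```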
The code for the forward pass is as follows:
```python
def bottom_data_is(self, x, s_prev=None, h_prev=None):
    # if this is the first lstm node in the network
    if s_prev is None: s_prev = np.zeros_like(self.state.s)
    if h_prev is None: h_prev = np.zeros_like(self.state.h)
    # save data for use in backprop
    self.s_prev = s_prev
    self.h_prev = h_prev

    # concatenate x(t) and h(t-1)
    xc = np.hstack((x, h_prev))
    self.state.g = np.tanh(np.dot(self.param.wg, xc) + self.param.bg)
    self.state.i = sigmoid(np.dot(self.param.wi, xc) + self.param.bi)
    self.state.f = sigmoid(np.dot(self.param.wf, xc) + self.param.bf)
    self.state.o = sigmoid(np.dot(self.param.wo, xc) + self.param.bo)
    self.state.s = self.state.g * self.state.i + s_prev * self.state.f
    self.state.h = self.state.s * self.state.o

    self.x = x
    self.xc = xc
```
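To make the recurrence concrete, here is a minimal usage sketch (it assumes `import numpy as np` and the `LstmParam`, `LstmState`, `LstmNode` classes from the complete listing later in this post): the second step simply reuses the first step's `s` and `h` as `s_prev` and `h_prev`.

```python
# minimal forward sketch; assumes LstmParam / LstmState / LstmNode from the full listing below
param = LstmParam(mem_cell_ct=4, x_dim=3)
node0 = LstmNode(param, LstmState(4, 3))
node1 = LstmNode(param, LstmState(4, 3))

x0, x1 = np.random.random(3), np.random.random(3)
node0.bottom_data_is(x0)                                  # first step: s_prev = h_prev = 0
node1.bottom_data_is(x1, node0.state.s, node0.state.h)    # second step reuses s(t-1), h(t-1)
print(node1.state.h)
```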
The forward pass of the LSTM is fairly easy; the backward pass is more involved. First define a loss function \(l(t)=f(h(t),y(t))=||h(t)-y(t)||^2\), where \(h(t)\) and \(y(t)\) are the output sequence and the target labels. What we want is to minimize \(l(t)\) over the whole time series, i.e. minimize
\[ L=\sum_{t=1}^{T}l(t) \]
where \(T\) is the length of the sequence. Next we compute gradients from \(L\). Suppose we want \(\frac{dL}{dw}\), where \(w\) is a scalar (for example one entry of the matrix \(W_{gx}\)); the chain rule gives
\[ \frac{dL}{dw} = \sum_{t=1}^{T}\sum_{i=1}^{M}\frac{dL}{dh_i(t)}\frac{dh_i(t)}{dw} \]
where \(h_i(t)\) is the output of the i-th unit and \(M\) is the number of LSTM units. Since the network propagates forward in time, a change in \(h_i(t)\) does not affect the loss before time t, so we can write:
\[ \frac{dL}{dh_i(t)} = \sum_{s=1}^{T}\frac{dl(s)}{dh_i(t)} = \sum_{s=t}^{T}\frac{dl(s)}{dh_i(t)} \]
To simplify notation, let \(L(t)=\sum_{s=t}^{T}l(s)\), so that \(L(1)\) is the loss of the whole sequence. Rewriting the expression above gives:
\[ \frac{dL}{dh_i(t)} = \sum_{s=1}^{T}\frac{dl(s)}{dh_i(t)} = \frac{dL(t)}{dh_i(t)} \]
This lets us rewrite the gradient as:
\[ \frac{dL}{dw} = \sum_{t=1}^{T}\sum_{i=1}^{M}\frac{dL(t)}{dh_i(t)}\frac{dh_i(t)}{dw} \]
Since \(L(t)=l(t)+L(t+1)\), we have \(\frac{dL(t)}{dh_i(t)}=\frac{dl(t)}{dh_i(t)} + \frac{dL(t+1)}{dh_i(t)}\). In other words, once we have the derivative for the next time step we can obtain the one for the current step directly, so we compute the derivative at time \(T\) and work backwards; at time \(T\) we have \(\frac{dL(T)}{dh_i(T)}=\frac{dl(T)}{dh_i(T)}\).
```python
def y_list_is(self, y_list, loss_layer):
    """
    Updates diffs by setting target sequence
    with corresponding loss layer.
    Will *NOT* update parameters.  To update parameters,
    call self.lstm_param.apply_diff()
    """
    assert len(y_list) == len(self.x_list)
    idx = len(self.x_list) - 1
    # first node only gets diffs from label ...
    loss = loss_layer.loss(self.lstm_node_list[idx].state.h, y_list[idx])
    diff_h = loss_layer.bottom_diff(self.lstm_node_list[idx].state.h, y_list[idx])
    # here s is not affecting loss due to h(t+1), hence we set equal to zero
    diff_s = np.zeros(self.lstm_param.mem_cell_ct)
    self.lstm_node_list[idx].top_diff_is(diff_h, diff_s)
    idx -= 1

    ### ... following nodes also get diffs from next nodes, hence we add diffs to diff_h
    ### we also propagate error along constant error carousel using diff_s
    while idx >= 0:
        loss += loss_layer.loss(self.lstm_node_list[idx].state.h, y_list[idx])
        diff_h = loss_layer.bottom_diff(self.lstm_node_list[idx].state.h, y_list[idx])
        diff_h += self.lstm_node_list[idx + 1].state.bottom_diff_h
        diff_s = self.lstm_node_list[idx + 1].state.bottom_diff_s
        self.lstm_node_list[idx].top_diff_is(diff_h, diff_s)
        idx -= 1

    return loss
```
Given the formulas above, the computation of diff_h is easy to follow. The loss_layer.bottom_diff used here is defined as follows:
```python
def bottom_diff(self, pred, label):
    diff = np.zeros_like(pred)
    diff[0] = 2 * (pred[0] - label)
    return diff
```
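Note that in this toy setup the label \(y(t)\) is a scalar and the loss only looks at the first component of \(h(t)\), i.e. \(l(t) = (h_0(t) - y(t))^2\), so

\[ \frac{dl(t)}{dh_0(t)} = 2\,(h_0(t) - y(t)), \qquad \frac{dl(t)}{dh_i(t)} = 0 \quad (i > 0) \]

which is exactly the vector that bottom_diff returns.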
This function follows directly from the loss above. Next we derive \(\frac{dL(t)}{ds(t)}\). From the forward equations it is easy to see that a change in \(s(t)\) directly affects both \(h(t)\) and \(h(t+1)\), and through them \(L(t)\), so:
\[ \frac{dL(t)}{ds_i(t)}=\frac{dL(t)}{dh_i(t)}*\frac{dh_i(t)}{ds_i(t)} + \frac{dL(t)}{dh_i(t+1)}*\frac{dh_i(t+1)}{ds_i(t)} \]
Since \(h(t+1)\) does not affect \(l(t)\), we have \(\frac{dL(t)}{dh_i(t+1)}=\frac{dL(t+1)}{dh_i(t+1)}\), and therefore:
\[ \frac{dL(t)}{ds_i(t)}=\frac{dL(t)}{dh_i(t)}*\frac{dh_i(t)}{ds_i(t)} + \frac{dL(t+1)}{dh_i(t+1)}*\frac{dh_i(t+1)}{ds_i(t)}=\frac{dL(t)}{dh_i(t)}*\frac{dh_i(t)}{ds_i(t)} + \frac{dL(t+1)}{ds_i(t)} \]
In the same way we can work backwards from later derivatives to earlier ones; in the code this is exactly how diff_s is computed.
Next we compute \(\frac{dL(t)}{dh_i(t)}*\frac{dh_i(t)}{ds_i(t)}\). Since \(h(t)=s(t)*o(t)\), we have \(\frac{dL(t)}{dh_i(t)}*\frac{dh_i(t)}{ds_i(t)}=\frac{dL(t)}{dh_i(t)}*o_i(t)=o_i(t)[diff\_h]_i\), so \(\frac{dL(t)}{ds_i(t)}=o_i(t)[diff\_h]_i+[diff\_s]_i\), where \([diff\_h]_i\) and \([diff\_s]_i\) denote the current step's \(\frac{dL(t)}{dh_i(t)}\) and the next step's \(\frac{dL(t+1)}{ds_i(t)}\) respectively. Again, this should be easy to follow alongside the code above.
Now we compute the derivatives one by one, following the forward pass:
\[ \begin{eqnarray} \frac{dL(t)}{do(t)}&=&\frac{dL(t)}{dh(t)}*s(t) \\ \frac{dL(t)}{di(t)}&=&\frac{dL(t)}{ds(t)}*\frac{ds(t)}{di(t)}=\frac{dL(t)}{ds(t)}*g(t) \\ \frac{dL(t)}{dg(t)}&=&\frac{dL(t)}{ds(t)}*\frac{ds(t)}{dg(t)}=\frac{dL(t)}{ds(t)}*i(t) \\ \frac{dL(t)}{df(t)}&=&\frac{dL(t)}{ds(t)}*\frac{ds(t)}{df(t)}=\frac{dL(t)}{ds(t)}*s(t-1) \\ \end{eqnarray} \]
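One more ingredient is needed before turning these into code: when back-propagating through the nonlinearities we use the derivatives of the activations expressed through their own outputs,

\[ \sigma'(z) = \sigma(z)\,(1-\sigma(z)), \qquad \phi'(z) = 1-\phi(z)^2 \]

which is why the code multiplies by `(1. - i) * i`, `(1. - f) * f`, `(1. - o) * o` and `(1. - g ** 2)` respectively.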
This gives the following code:
```python
def top_diff_is(self, top_diff_h, top_diff_s):
    # notice that top_diff_s is carried along the constant error carousel
    ds = self.state.o * top_diff_h + top_diff_s
    do = self.state.s * top_diff_h
    di = self.state.g * ds
    dg = self.state.i * ds
    df = self.s_prev * ds

    # diffs w.r.t. vector inside sigma / tanh function
    di_input = (1. - self.state.i) * self.state.i * di  # sigmoid diff
    df_input = (1. - self.state.f) * self.state.f * df
    do_input = (1. - self.state.o) * self.state.o * do
    dg_input = (1. - self.state.g ** 2) * dg  # tanh diff

    # diffs w.r.t. inputs
    self.param.wi_diff += np.outer(di_input, self.xc)
    self.param.wf_diff += np.outer(df_input, self.xc)
    self.param.wo_diff += np.outer(do_input, self.xc)
    self.param.wg_diff += np.outer(dg_input, self.xc)
    self.param.bi_diff += di_input
    self.param.bf_diff += df_input
    self.param.bo_diff += do_input
    self.param.bg_diff += dg_input

    # compute bottom diff
    dxc = np.zeros_like(self.xc)
    dxc += np.dot(self.param.wi.T, di_input)
    dxc += np.dot(self.param.wf.T, df_input)
    dxc += np.dot(self.param.wo.T, do_input)
    dxc += np.dot(self.param.wg.T, dg_input)

    # save bottom diffs
    self.state.bottom_diff_s = ds * self.state.f
    self.state.bottom_diff_x = dxc[:self.param.x_dim]
    self.state.bottom_diff_h = dxc[self.param.x_dim:]
```
Here top_diff_h and top_diff_s are the diff_h and diff_s discussed above. Let us walk through the computation of wi_diff; the other variables are analogous.
\[ \frac{dL(t)}{dW_i} = \frac{dL(t)}{di(t)}*\frac{di(t)}{d(W_ix_c(t))}*\frac{d(W_ix_c(t))}{dW_i} \]
Simplifying the expression above gives the following code:
```python
wi_diff += np.outer((1. - i) * i * di, xc)
```
The other derivatives can be obtained in the same way, so I will not spell them out here.
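For completeness, here is why the weight gradients take the form of an outer product. Writing \(\delta_i(t) = \frac{dL(t)}{di(t)}\, i(t)(1-i(t))\) for the error at the pre-activation of the input gate (this is `di_input` in the code), each entry of \(W_i\) only touches one pre-activation component, so

\[ \frac{dL(t)}{d(W_i)_{jk}} = \delta_{i,j}(t)\, x_{c,k}(t), \qquad \text{i.e.} \quad \frac{dL(t)}{dW_i} = \delta_i(t)\, x_c(t)^\top \]

which is exactly `np.outer(di_input, xc)` in the code.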
Putting everything together, the complete program is as follows:

```python
# lstm: given a sequence of consecutive primes as input, predict the next prime
import random
import numpy as np
import math

def sigmoid(x):
    return 1. / (1 + np.exp(-x))

# creates uniform random array w/ values in [a,b) and shape args
def rand_arr(a, b, *args):
    np.random.seed(0)
    return np.random.rand(*args) * (b - a) + a

class LstmParam:
    def __init__(self, mem_cell_ct, x_dim):
        self.mem_cell_ct = mem_cell_ct
        self.x_dim = x_dim
        concat_len = x_dim + mem_cell_ct
        # weight matrices
        self.wg = rand_arr(-0.1, 0.1, mem_cell_ct, concat_len)
        self.wi = rand_arr(-0.1, 0.1, mem_cell_ct, concat_len)
        self.wf = rand_arr(-0.1, 0.1, mem_cell_ct, concat_len)
        self.wo = rand_arr(-0.1, 0.1, mem_cell_ct, concat_len)
        # bias terms
        self.bg = rand_arr(-0.1, 0.1, mem_cell_ct)
        self.bi = rand_arr(-0.1, 0.1, mem_cell_ct)
        self.bf = rand_arr(-0.1, 0.1, mem_cell_ct)
        self.bo = rand_arr(-0.1, 0.1, mem_cell_ct)
        # diffs (derivative of loss function w.r.t. all parameters)
        self.wg_diff = np.zeros((mem_cell_ct, concat_len))
        self.wi_diff = np.zeros((mem_cell_ct, concat_len))
        self.wf_diff = np.zeros((mem_cell_ct, concat_len))
        self.wo_diff = np.zeros((mem_cell_ct, concat_len))
        self.bg_diff = np.zeros(mem_cell_ct)
        self.bi_diff = np.zeros(mem_cell_ct)
        self.bf_diff = np.zeros(mem_cell_ct)
        self.bo_diff = np.zeros(mem_cell_ct)

    def apply_diff(self, lr = 1):
        self.wg -= lr * self.wg_diff
        self.wi -= lr * self.wi_diff
        self.wf -= lr * self.wf_diff
        self.wo -= lr * self.wo_diff
        self.bg -= lr * self.bg_diff
        self.bi -= lr * self.bi_diff
        self.bf -= lr * self.bf_diff
        self.bo -= lr * self.bo_diff
        # reset diffs to zero
        self.wg_diff = np.zeros_like(self.wg)
        self.wi_diff = np.zeros_like(self.wi)
        self.wf_diff = np.zeros_like(self.wf)
        self.wo_diff = np.zeros_like(self.wo)
        self.bg_diff = np.zeros_like(self.bg)
        self.bi_diff = np.zeros_like(self.bi)
        self.bf_diff = np.zeros_like(self.bf)
        self.bo_diff = np.zeros_like(self.bo)

class LstmState:
    def __init__(self, mem_cell_ct, x_dim):
        self.g = np.zeros(mem_cell_ct)
        self.i = np.zeros(mem_cell_ct)
        self.f = np.zeros(mem_cell_ct)
        self.o = np.zeros(mem_cell_ct)
        self.s = np.zeros(mem_cell_ct)
        self.h = np.zeros(mem_cell_ct)
        self.bottom_diff_h = np.zeros_like(self.h)
        self.bottom_diff_s = np.zeros_like(self.s)
        self.bottom_diff_x = np.zeros(x_dim)

class LstmNode:
    def __init__(self, lstm_param, lstm_state):
        # store reference to parameters and to activations
        self.state = lstm_state
        self.param = lstm_param
        # non-recurrent input to node
        self.x = None
        # non-recurrent input concatenated with recurrent input
        self.xc = None

    def bottom_data_is(self, x, s_prev = None, h_prev = None):
        # if this is the first lstm node in the network
        if s_prev is None: s_prev = np.zeros_like(self.state.s)
        if h_prev is None: h_prev = np.zeros_like(self.state.h)
        # save data for use in backprop
        self.s_prev = s_prev
        self.h_prev = h_prev

        # concatenate x(t) and h(t-1)
        xc = np.hstack((x, h_prev))
        self.state.g = np.tanh(np.dot(self.param.wg, xc) + self.param.bg)
        self.state.i = sigmoid(np.dot(self.param.wi, xc) + self.param.bi)
        self.state.f = sigmoid(np.dot(self.param.wf, xc) + self.param.bf)
        self.state.o = sigmoid(np.dot(self.param.wo, xc) + self.param.bo)
        self.state.s = self.state.g * self.state.i + s_prev * self.state.f
        self.state.h = self.state.s * self.state.o

        self.x = x
        self.xc = xc

    def top_diff_is(self, top_diff_h, top_diff_s):
        # notice that top_diff_s is carried along the constant error carousel
        ds = self.state.o * top_diff_h + top_diff_s
        do = self.state.s * top_diff_h
        di = self.state.g * ds
        dg = self.state.i * ds
        df = self.s_prev * ds

        # diffs w.r.t. vector inside sigma / tanh function
        di_input = (1. - self.state.i) * self.state.i * di
        df_input = (1. - self.state.f) * self.state.f * df
        do_input = (1. - self.state.o) * self.state.o * do
        dg_input = (1. - self.state.g ** 2) * dg

        # diffs w.r.t. inputs
        self.param.wi_diff += np.outer(di_input, self.xc)
        self.param.wf_diff += np.outer(df_input, self.xc)
        self.param.wo_diff += np.outer(do_input, self.xc)
        self.param.wg_diff += np.outer(dg_input, self.xc)
        self.param.bi_diff += di_input
        self.param.bf_diff += df_input
        self.param.bo_diff += do_input
        self.param.bg_diff += dg_input

        # compute bottom diff
        dxc = np.zeros_like(self.xc)
        dxc += np.dot(self.param.wi.T, di_input)
        dxc += np.dot(self.param.wf.T, df_input)
        dxc += np.dot(self.param.wo.T, do_input)
        dxc += np.dot(self.param.wg.T, dg_input)

        # save bottom diffs
        self.state.bottom_diff_s = ds * self.state.f
        self.state.bottom_diff_x = dxc[:self.param.x_dim]
        self.state.bottom_diff_h = dxc[self.param.x_dim:]

class LstmNetwork():
    def __init__(self, lstm_param):
        self.lstm_param = lstm_param
        self.lstm_node_list = []
        # input sequence
        self.x_list = []

    def y_list_is(self, y_list, loss_layer):
        """
        Updates diffs by setting target sequence
        with corresponding loss layer.
        Will *NOT* update parameters.  To update parameters,
        call self.lstm_param.apply_diff()
        """
        assert len(y_list) == len(self.x_list)
        idx = len(self.x_list) - 1
        # first node only gets diffs from label ...
        loss = loss_layer.loss(self.lstm_node_list[idx].state.h, y_list[idx])
        diff_h = loss_layer.bottom_diff(self.lstm_node_list[idx].state.h, y_list[idx])
        # here s is not affecting loss due to h(t+1), hence we set equal to zero
        diff_s = np.zeros(self.lstm_param.mem_cell_ct)
        self.lstm_node_list[idx].top_diff_is(diff_h, diff_s)
        idx -= 1

        ### ... following nodes also get diffs from next nodes, hence we add diffs to diff_h
        ### we also propagate error along constant error carousel using diff_s
        while idx >= 0:
            loss += loss_layer.loss(self.lstm_node_list[idx].state.h, y_list[idx])
            diff_h = loss_layer.bottom_diff(self.lstm_node_list[idx].state.h, y_list[idx])
            diff_h += self.lstm_node_list[idx + 1].state.bottom_diff_h
            diff_s = self.lstm_node_list[idx + 1].state.bottom_diff_s
            self.lstm_node_list[idx].top_diff_is(diff_h, diff_s)
            idx -= 1

        return loss

    def x_list_clear(self):
        self.x_list = []

    def x_list_add(self, x):
        self.x_list.append(x)
        if len(self.x_list) > len(self.lstm_node_list):
            # need to add new lstm node, create new state mem
            lstm_state = LstmState(self.lstm_param.mem_cell_ct, self.lstm_param.x_dim)
            self.lstm_node_list.append(LstmNode(self.lstm_param, lstm_state))

        # get index of most recent x input
        idx = len(self.x_list) - 1
        if idx == 0:
            # no recurrent inputs yet
            self.lstm_node_list[idx].bottom_data_is(x)
        else:
            s_prev = self.lstm_node_list[idx - 1].state.s
            h_prev = self.lstm_node_list[idx - 1].state.h
            self.lstm_node_list[idx].bottom_data_is(x, s_prev, h_prev)
```
Test code:
```python
import numpy as np
from lstm import LstmParam, LstmNetwork

class ToyLossLayer:
    """
    Computes square loss with first element of hidden layer array.
    """
    @classmethod
    def loss(self, pred, label):
        return (pred[0] - label) ** 2

    @classmethod
    def bottom_diff(self, pred, label):
        diff = np.zeros_like(pred)
        diff[0] = 2 * (pred[0] - label)
        return diff

def example_0():
    # learns to repeat simple sequence from random inputs
    np.random.seed(0)

    # parameters for input data dimension and lstm cell count
    mem_cell_ct = 100
    x_dim = 50
    concat_len = x_dim + mem_cell_ct
    lstm_param = LstmParam(mem_cell_ct, x_dim)
    lstm_net = LstmNetwork(lstm_param)
    y_list = [-0.5, 0.2, 0.1, -0.5]
    input_val_arr = [np.random.random(x_dim) for _ in y_list]

    for cur_iter in range(100):
        print("cur iter: %d" % cur_iter)
        for ind in range(len(y_list)):
            lstm_net.x_list_add(input_val_arr[ind])
            print("y_pred[%d] : %f" % (ind, lstm_net.lstm_node_list[ind].state.h[0]))

        loss = lstm_net.y_list_is(y_list, ToyLossLayer)
        print("loss: %f" % loss)
        lstm_param.apply_diff(lr=0.1)
        lstm_net.x_list_clear()

if __name__ == "__main__":
    example_0()
```
The output is omitted here.
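Finally, a quick finite-difference check of the analytic gradients. This is a standalone sketch rather than part of the original code: it assumes the full listing above is saved as lstm.py and reuses the ToyLossLayer class from the test script. The numerical and analytic derivatives of the sequence loss with respect to a single weight should agree closely.

```python
import numpy as np
from lstm import LstmParam, LstmNetwork

def sequence_loss(net, xs, ys):
    # rerun the forward pass and sum the per-step losses
    net.x_list_clear()
    for x in xs:
        net.x_list_add(x)
    return sum(ToyLossLayer.loss(net.lstm_node_list[i].state.h, ys[i])
               for i in range(len(ys)))

np.random.seed(0)
param = LstmParam(mem_cell_ct=10, x_dim=5)
net = LstmNetwork(param)
xs = [np.random.random(5) for _ in range(4)]
ys = [0.1, -0.2, 0.3, -0.4]

# analytic gradient of the total loss w.r.t. wg[0, 0]
net.x_list_clear()
for x in xs:
    net.x_list_add(x)
net.y_list_is(ys, ToyLossLayer)
analytic = param.wg_diff[0, 0]

# numerical gradient via central differences
eps = 1e-5
param.wg[0, 0] += eps
loss_plus = sequence_loss(net, xs, ys)
param.wg[0, 0] -= 2 * eps
loss_minus = sequence_loss(net, xs, ys)
param.wg[0, 0] += eps  # restore the original weight
numerical = (loss_plus - loss_minus) / (2 * eps)

print(analytic, numerical)  # the two values should match to several decimal places
```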