包括卷積神經網絡(CNN)在內的各類前饋神經網絡模型, 其一次前饋過程的輸出只與當前輸入有關與歷史輸入無關.python
遞歸神經網絡(Recurrent Neural Network, RNN)充分挖掘了序列數據中的信息, 在時間序列和天然語言處理方面有着重要的應用.git
遞歸神經網絡能夠展開爲普通的前饋神經網絡:算法
長短時間記憶模型(Long-Short Term Memory)是RNN的經常使用實現. 與通常神經網絡的神經元相比, LSTM神經元多了一個遺忘門.網絡
LSTM神經元的輸出除了與當前輸入有關外, 還與自身記憶有關. RNN的訓練算法也是基於傳統BP算法增長了時間考量, 稱爲BPTT(Back-propagation Through Time)算法.session
tensorflow內置了遞歸神經網絡的實現:app
from tensorflow.python.ops import rnn, rnn_cell
tensorflow目前正在快速迭代中, 上述路徑可能會發生變化.在0.6.0版本中上述路徑是有效的.函數
官方教程中已經加入了循環神經網絡的部分, API可能不會發生太大變化.優化
Tensorflow有多種rnn神經元可供選擇:.net
rnn_cell.BasicLSTMCell
code
rnn_cell.LSTMCell
rnn_cell.GRUCell
這裏咱們選用最簡單的BasicLSTMCell, 須要設置神經元個數和forget_bias
參數:
self.lstm_cell = rnn_cell.BasicLSTMCell(hidden_n, forget_bias=1.0)
能夠直接調用cell對象得到輸出和狀態:
output, state = cell(inputs, state)
使用dropout避免過擬合問題:
from tensorflow.python.ops.rnn_cell import Dropoutwrapper cells = DropoutWrapper(lstm_cell, input_keep_prob=0.5, output_keep_prob=0.5)
使用MultiRNNCell來建立多層神經網絡:
from tensorflow.python.ops.rnn_cell import MultiRNNCell cells = MultiRNNCell([lstm_cell_1, lstm_cell_2])
不過rnn.rnn
能夠替咱們完成神經網絡的構建工做:
outputs, states = rnn.rnn(self.lstm_cell, self.input_layer, dtype=tf.float32)
再加一個輸出層進行輸出:
self.prediction = tf.matmul(outputs[-1], self.weights) + self.biases
定義損失函數:
self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.prediction, self.label_layer))
使用Adam優化器進行訓練:
self.trainer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss)
由於神經網絡須要處理序列數據, 因此輸入層略複雜:
self.input_layer = [tf.placeholder("float", [step_n, input_n]) for i in range(batch_size)]
tensorflow要求RNNCell的輸入爲一個列表, 列表中的每一項做爲一個批次進行訓練.
列表中的每個元素表明一個序列, 每一行爲序列中的一項. 這樣每一項爲一個形狀爲(序列長, 輸入維數)的矩陣.
標籤仍是和原來同樣爲形如(序列長, 輸出維度)的矩陣:
self.label_layer = tf.placeholder("float", [step_n, output_n])
執行訓練:
self.session.run(initer) for i in range(limit): self.session.run(self.trainer, feed_dict={self.input_layer[0]: train_x[0], self.label_layer: train_y})
由於input_layer
爲列表, 而列表不能做爲字典的鍵.因此咱們只能採用{self.input_layer[0]: train_x[0]}
這樣的方式輸入數據.
能夠看到lable_layer
也是二維的, 並無輸入多個批次的數據. 考慮到這兩點, 目前這個實現並不具有多批次處理的能力.
序列的長度一般是不一樣的, 而目前的實現採用的是定長輸入. 這是須要解決的另外一個難題.
完整源代碼能夠在demo.py中查看.