Tensorflow循環神經網絡

Tensorflow循環神經網絡

  • 循環神經網絡
  • 梯度消失問題
  • LSTM網絡
  • RNN其餘變種
  • 用RNN和Tensorflow實現手寫數字分類

一.循環神經網絡

from IPython.display import Image
Image(filename="./data/rnn_1.png",width=500)

output_3_0.png

RNN背後的思想就是利用順序信息.在傳統的神經網絡中,咱們假設全部輸入(或輸出)彼此獨立.但對於許多任務而言,這是一個很是糟糕的模型.若是你想預測句子中的下一個單詞,你最好知道它前面有哪些單詞.RNN對序列的每一個元素執行相同的任務,輸出取決於先前的計算.下面是典型的RNN樣子python

from IPython.display import Image
Image(filename="./data/rnn_2.png",width=500)

output_5_0.png

其中U是輸入到隱含層的權重矩陣,W是狀態到隱含層的權重矩陣,s爲狀態,V是隱含層到輸出層的權重矩陣.它的共享參數方式是各個時間節點的W,U,V都是不變的,這個機制就像卷積神經網絡的過濾器機制同樣,經過這種方法實現參數共享,同時大大下降參數量git

Image(filename="./data/rnn_3.png",width=500)

output_7_0.png

在Tensorflow中,將圖中間循環體結構叫做cell,可使用tf.nn.rnn_cell.BasicRNNCell或者tf.contrib.rnn.BasicRNNCell表達,這兩個表達僅是同一個對象的不一樣名字,沒有本質的區別.例如,tf.nn.rnn_cell.BasicRNNCell定義的參數以下代碼所示:web

tf.nn.rnn_cell.BasicRNNCell(num_units,activation=None,reuse=None,name=None)

參數說明以下:算法

  • num_units:int類型,必選參數.表示cell由多少個相似於cell的單元構成
  • activation:string類型,激活函數,默認爲tanh
  • reuse:bool類型.表明是否從新使用scope中的參數
  • name:string類型.名稱

二.前向傳播與隨時間方向傳播

1.RNN前向傳播

Image(filename="./data/rnn_example.png",width=500)

output_13_0.png

import numpy as np

X=[1,2]
state=[0.0,0.0]
w_cell_state=np.asarray([[0.1,0.2],[0.3,0.4],[0.5,0.6]])
b_cell=np.asarray([0.1,-0.1])
w_output=np.asarray([[1.0],[2.0]])
b_output=0.1

for i in range(len(X)):
    state=np.append(state,X[i])
    before_activation=np.dot(state,w_cell_state)+b_cell
    state=np.tanh(before_activation)
    final_output=np.dot(state,w_output)+b_output
    print("狀態值_%i"%i,state)
    print("輸出值-%i"%i,final_output)
狀態值_0 [0.53704957 0.46211716]
輸出值-0 [1.56128388]
狀態值_1 [0.85973818 0.88366641]
輸出值-1 [2.72707101]

2.RNN隨時間反向傳播

RNN的反向傳播訓練算法稱爲隨時間反向傳播(Backpropagation Through Time,BPTT)算法,其基本原理和反向傳播算法是同樣的,只不過反向傳播算法是按照層進行反向傳播的,而BPTT是按照時間進行反向傳播的網絡

三.梯度消失或爆炸

在實際應用中,上述介紹的標準循環神經網絡訓練的優化算法面臨一個很大的難題,就是長期依賴問題.因爲網絡結構變深,使得模型喪失了學習先前信息的能力.通俗地講,標準的循環神經網絡雖然有了記憶,但很健忘.循環神經網絡其實是在長時間序列的各個時刻重複應用相同操做來構建很是深的計算圖,而且模型參數共享,這讓問題變得更加凸顯.例如,W是一個在時間步中反覆被用於相乘的矩陣,舉個簡單狀況,比方說W能夠由特徵值分解

session

所以很容易看出:

數據結構

當特徵值r_i不在1附近時,若在量級上大於1則會爆炸;若小於1則會消失.這即是著名的梯度消失或爆炸問題(vanishing and exploding gradient problem).梯度的消失使得咱們難以知道參數朝哪一個方向移動能改進代價函數,而梯度的爆炸會使學習過程變得不穩定app

實際上梯度消失或爆炸問題應該是深度學習中的一個基本問題,在任何深度神經網絡中均可能存在,而不只是循環神經網絡所獨有.在RNN中,相鄰時間步是鏈接在一塊兒的,所以它們的權重偏導數要麼都小於1,要麼都大於1,RNN中每一個權重都會向相同的反方向變化,這樣與前饋神經網絡相比,RNN的梯度消失或爆炸會更加明顯dom

如何避免梯度消失或爆炸問題?目前最流行的一種解決方案稱爲長短時記憶網絡(Long Short-Term Memory,LSTM),還有基於LSTM的幾種變種算法,如GRU(Gated Recurrent Unit,GRU)算法等svg

三.LSTM算法

LSTM可以有效解決信息的長期依賴,避免梯度消失或爆炸.事實上,LSTM的設計就是專門用於解決長期依賴問題的.與傳統RNN相比,它在結構上的獨特之處是它精巧地設計了循環體結構.LSTM用兩個門來控制單元狀態h的內容:一個是遺忘門(forget gate),它決定了上一時刻的單元狀態h_t-1有多少保留到當前時刻c_t;另外一個是輸入門(input gate),它決定了當前時刻網絡的輸入x_t有多少保存到單元狀態c_t.LSTM用輸出門(output gate)來控制單元狀態h_t有多少輸出到LSTM的當前輸出值h

Image(filename="./data/LSTM.png",width=500)

output_25_0.png

Image(filename="./data/LSTM2.png",width=500)

output_26_0.png

LSTM對神經元狀態的修改是經過一種叫"門"的結構完成的,門使得信息能夠有選擇性地經過.LSTM中門是由一個sigmoid函數和一個按位乘積運算元件構成的

Image(filename="./data/LSTM3-gate.png",width=500)

output_28_0.png

sigmoid函數使得其輸出結果在0到1之間,sigmoid的輸出結果爲0時,則不容許任何信息經過;sigmoid爲1時則容許所有信息經過;sigmoid的輸出位於(0,1)之間時,則容許部分信息經過.LSTM有三個這樣門結構,即輸入門,遺忘門和輸出門,用來保護和控制神經元狀態的改變

與標準的RNN同樣,在Tensorflow中,LSTM的循環結構也有較好的封裝類,有tf.nn.rnn_cell.BasicLSTMCelltf.contrib.rnn.BasicLSTMCell,二者功能相同,使用參數也徹底一致

tf.nn.rnn_cell.BasicLSTMCell(num_units,forget_bias=1.0,state_is_tuple=True,activation=None,reuse=None,name=None)

參數說明:

  • num_units:int,表示LSTM cell中基本神經單元的個數
  • forget_bias:float,默認爲1,遺忘門中的bias添加項
  • activation:string,內部狀態的激活函數,默認爲tanh
  • reuse:bool,可選參數,默認爲True,決定是否重用當前變量scope中的變量
  • name:string,可選參數,默認爲None.指layer的名稱,相同名稱的層會共享變量,使用時應注意與reuse配合

如下代碼完成了一個簡單的LSTM網絡結構的構建

import tensorflow as tf

num_units=128
num_layers=2
batch_size=100

# 建立一個BasicLSTMCell,即LSTM循環體
# num_units爲循環體中基本單元的個數,數量越多,網絡的特徵表達能力越強

rnn_cell=tf.contrib.rnn.BasicLSTMCell(num_units)

# 使用多層結構,返回值仍然爲cell結構
if num_layers>=2:
    rnn_cell=tf.nn.rnn_cell.MultiRNNCell([rnn_cell]*num_layers)
    
# 定義初始化狀態
initial_state=rnn_cell.zero_state(batch_size,dtype=tf.float32)

# 定義輸入數據結構以完成循環神經網絡的構建
outputs,state=tf.nn.dynamic_rnn(rnn_cell,input_data,initial_state=initial_state,dtype=tf.float32)

# outputs是一個張量,其形狀爲[batch_size,max_time,cell_state_size]
# state是一個張量,其形狀爲[batch_size,cell_state_size]

五.RNN其餘變種

1.GRU

RNN的改進版LSTM,它有效克服了傳統RNN的一些不足,比較好地解決了梯度消失,長期依賴等問題.不過LSTM也有一些不足,如結構比較複雜,計算複雜度較高.GRU對LSTM作了不少簡化,比LSTM少一個Gate,所以計算效率更高

2.Bi-RNN

RNN能夠處理不固定長度時序數據,沒法利用將來信息.Bi-RNN同時使用時序數據輸入歷史及將來數據,時序相反時兩個循環神經網絡鏈接同一輸出,輸出層能夠同時獲取歷史將來信息

Image(filename="./data/Bi-RNN.png",width=500)

output_40_0.png

雙向循環神經網絡的基本思想是:每個訓練序列向前和向後分別是兩個循環神經網絡(RNN),並且這兩個都鏈接着一個輸出層.這個結果提供給輸出層輸入序列中每個點完整的過去和將來的上下文信息.六個獨特的權值在每個時步被重複利用,六個權值分別對應着輸入到向前和向後隱含層(w1,w3),隱含層到隱含層本身(w2,w5),向前和向後隱含層到輸出層(w4,w6).值得注意的是,向前和向後隱含層之間沒有信息流,這保證了展開圖是非循環的

六.RNN應用場景

RNN網絡適合於處理序列數據,序列長度通常不是固定的

Image(filename="./data/RNN_4.png",width=500)

output_44_0.png

七.用LSTM實現分類

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt

tf.set_random_seed(1)
np.random.seed(1)

# 定義超參數
BATCH_SIZE=64
TIME_STEP=28
INPUT_SIZE=28
LR=0.1

#讀人數據
mnist=input_data.read_data_sets("./data/mnist/",one_hot=True)
test_x=mnist.test.images[:2000]
test_y=mnist.test.labels[:2000]

# 畫出一張圖片觀察一下
print(mnist.train.images.shape)
print(mnist.train.labels.shape)
plt.imshow(mnist.train.images[0].reshape((28,28)),cmap="gray")
plt.title("%i"%np.argmax(mnist.train.labels[0]))
plt.show()

# 定義表示x向量的tensorflow.placeholder
tf_x=tf.placeholder(tf.float32,[None,TIME_STEP*INPUT_SIZE])
image=tf.reshape(tf_x,[-1,TIME_STEP,INPUT_SIZE])

# 定義表示 y 向量的placeholder
tf_y=tf.placeholder(tf.int32,[None,10])

# RNN 的循環體結構,使用 LSTM
rnn_cell=tf.contrib.rnn.BasicLSTMCell(num_units=64)

outputs,(h_c,h_n)=tf.nn.dynamic_rnn(
    rnn_cell,
    image,
    initial_state=None,
    dtype=tf.float32,
    time_major=False,
)

output=tf.layers.dense(outputs[:,-1,:],10)

loss=tf.losses.softmax_cross_entropy(onehot_labels=tf_y,logits=output)

train_op=tf.train.AdamOptimizer(LR).minimize(loss)

# 預測精度
accuracy=tf.metrics.accuracy(labels=tf.argmax(tf_y,axis=1),predictions=tf.argmax(output,axis=1),)[1]

session=tf.Session()

init_op=tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())

session.run(init_op)

for step in range(1200):
    b_x,b_y=mnist.train.next_batch(BATCH_SIZE)
    _,loss_=session.run([train_op,loss],{tf_x:b_x,tf_y:b_y})
    if step%50==0:
        accuracy_=session.run(accuracy,{tf_x:test_x,tf_y:test_y})
        print("Train loss:%.4f"%loss_,"| Test accuracy:%.2f"%accuracy_)
        
#輸出測試集中的十個預測結果
test_output=session.run(output,{tf_x:test_x[:10]})
pred_y=np.argmax(test_output,1)
print(pred_y,"prediction number")
print(np.argmax(test_y[:10],1),"real number")
(55000, 784)
(55000, 10)

<Figure size 640x480 with 1 Axes>

WARNING:tensorflow:From <ipython-input-1-571f25e9d86b>:35: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
WARNING:tensorflow:From <ipython-input-1-571f25e9d86b>:42: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
WARNING:tensorflow:From E:\Anaconda\envs\mytensorflow\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py:162: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From <ipython-input-1-571f25e9d86b>:45: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.dense instead.
WARNING:tensorflow:From E:\Anaconda\envs\mytensorflow\lib\site-packages\tensorflow\python\ops\losses\losses_impl.py:209: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
Train loss:2.3271 | Test accuracy:0.10
Train loss:1.5117 | Test accuracy:0.26
Train loss:1.0649 | Test accuracy:0.38
Train loss:0.9420 | Test accuracy:0.46
Train loss:0.4853 | Test accuracy:0.51
Train loss:0.5550 | Test accuracy:0.56
Train loss:0.5579 | Test accuracy:0.59
Train loss:0.4607 | Test accuracy:0.61
Train loss:0.7924 | Test accuracy:0.63
Train loss:0.6559 | Test accuracy:0.65
Train loss:0.3282 | Test accuracy:0.66
Train loss:0.5112 | Test accuracy:0.67
Train loss:0.5746 | Test accuracy:0.68
Train loss:0.5322 | Test accuracy:0.69
Train loss:0.6505 | Test accuracy:0.70
Train loss:0.6257 | Test accuracy:0.71
Train loss:0.5075 | Test accuracy:0.71
Train loss:0.9287 | Test accuracy:0.72
Train loss:0.4810 | Test accuracy:0.73
Train loss:0.3196 | Test accuracy:0.73
Train loss:0.5486 | Test accuracy:0.74
Train loss:0.6377 | Test accuracy:0.74
Train loss:0.6051 | Test accuracy:0.74
Train loss:0.6930 | Test accuracy:0.74
[7 2 1 0 4 1 4 7 6 9] prediction number
[7 2 1 0 4 1 4 9 5 9] real number
相關文章
相關標籤/搜索