Original article: https://blog.csdn.net/qq_20135597/article/details/88980975
---------------------------------------------------------------------------------------------
TensorFlow provides two kinds of RNN interface: a static RNN and a dynamic RNN.
1. Static interface: static_rnn
Mainly uses tf.contrib.rnn.static_rnn.
x = tf.placeholder("float", [None, n_steps, n_input])
# Split the input into a list of n_steps tensors, one per time step
x1 = tf.unstack(x, n_steps, 1)
lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
# Build the unrolled, fixed-length RNN graph
outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, x1, dtype=tf.float32)
# Classify using the output of the last time step
pred = tf.contrib.layers.fully_connected(outputs[-1], n_classes, activation_fn=None)
A static RNN builds a network of fixed length (n_steps) in the graph. This leads to:
Disadvantages: building the graph takes longer and uses more memory, and sequences longer than the specified n_steps cannot be fed in.
Advantages: the model carries the intermediate states of the sequence, which makes debugging easier.
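As a minimal sketch of this point (TF 1.x API; the n_steps/n_input/n_hidden values here are illustrative, not from the original), static_rnn returns a Python list with one output tensor per time step, so every intermediate step is directly accessible:

import tensorflow as tf

n_steps, n_input, n_hidden = 28, 28, 128
x = tf.placeholder("float", [None, n_steps, n_input])
x1 = tf.unstack(x, n_steps, 1)                 # list of n_steps tensors, each [batch, n_input]
cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)
outputs, states = tf.contrib.rnn.static_rnn(cell, x1, dtype=tf.float32)
print(len(outputs))        # 28 -- one output tensor per time step
print(outputs[10].shape)   # (?, 128) -- the intermediate output at step 10 can be inspected or reused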
2. Dynamic interface: dynamic_rnn
Mainly uses tf.nn.dynamic_rnn.
x = tf.placeholder("float", [None, n_steps, n_input])
lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
# dynamic_rnn keeps the time axis inside a single tensor: outputs has shape [batch, n_steps, n_hidden]
outputs, _ = tf.nn.dynamic_rnn(lstm_cell, x, dtype=tf.float32)
# Transpose to [n_steps, batch, n_hidden] so that outputs[-1] is the last time step
outputs = tf.transpose(outputs, [1, 0, 2])
pred = tf.contrib.layers.fully_connected(outputs[-1], n_classes, activation_fn=None)
When tf.nn.dynamic_rnn is executed, it uses a loop to build the graph dynamically. This means:
Advantages:
- graph creation is faster and uses less memory;
- it can also handle batches of variable size.
Disadvantages:
- only the final state is available in the model.
A dynamic RNN builds the RNN for only a single step of the sequence; the rest of the sequence data is passed through that same RNN in a loop at run time.
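A rough sketch of the consequence (variable names are illustrative, TF 1.x API): the loop-built graph returns one stacked output tensor for all steps plus only the final state, rather than a per-step list:

import tensorflow as tf

n_steps, n_input, n_hidden = 28, 28, 128
x = tf.placeholder("float", [None, n_steps, n_input])
cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)
# The cell is built once; dynamic_rnn iterates over the time axis with an internal loop
outputs, final_state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
print(outputs.shape)        # (?, 28, 128) -- one tensor holding all time steps
print(final_state.h.shape)  # (?, 128)     -- only the last state is returned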
Differences:
1. Different inputs and outputs:
dynamic_rnn allows the batches fed in different iterations to have different sequence lengths, although all data within a single batch must still have the same length. For example, the first run can feed data of shape=[batch_size, 10], the second run shape=[batch_size, 12], the third run shape=[batch_size, 8], and so on.
static_rnn cannot do this: it requires the batch fed at every run to keep the same shape [batch_size, max_seq] throughout all iterations.
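A minimal sketch of feeding different lengths across runs (hypothetical names; TF 1.x API), using a placeholder whose time dimension is left as None so each sess.run can supply a different sequence length without rebuilding the graph:

import tensorflow as tf
import numpy as np

n_input, n_hidden = 28, 64
# The time dimension is None, so every run may feed a different sequence length
x = tf.placeholder(tf.float32, [None, None, n_input])
cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)
outputs, _ = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Three runs with lengths 10, 12, and 8 -- same graph each time
    for seq_len in (10, 12, 8):
        batch = np.random.rand(4, seq_len, n_input).astype(np.float32)
        out = sess.run(outputs, feed_dict={x: batch})
        print(out.shape)   # (4, 10, 64), then (4, 12, 64), then (4, 8, 64)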
2. Different training behaviour:
See reference 1 for details.
1. Static multi-layer RNN
import tensorflow as tf
# Import the MNIST dataset
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("c:/user/administrator/data/", one_hot=True)

n_input = 28     # MNIST data input (img shape: 28*28)
n_steps = 28     # timesteps
n_hidden = 128   # hidden layer num of features
n_classes = 10   # MNIST classes (digits 0-9, 10 classes in total)
batch_size = 128

tf.reset_default_graph()
# tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

gru = tf.contrib.rnn.GRUCell(n_hidden*2)
lstm_cell = tf.contrib.rnn.LSTMCell(n_hidden)
mcell = tf.contrib.rnn.MultiRNNCell([lstm_cell, gru])

x1 = tf.unstack(x, n_steps, 1)
outputs, states = tf.contrib.rnn.static_rnn(mcell, x1, dtype=tf.float32)
pred = tf.contrib.layers.fully_connected(outputs[-1], n_classes, activation_fn=None)

learning_rate = 0.001
# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

training_iters = 100000
display_step = 10

# Launch the session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step = 1
    # Keep training until reach max iterations
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 sequences of 28 elements
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            # Calculate batch accuracy
            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
            # Calculate batch loss
            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step*batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print(" Finished!")

    # Calculate accuracy for mnist test images
    test_len = 100
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
2. Dynamic multi-layer RNN
import tensorflow as tf
# Import the MNIST dataset
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("c:/user/administrator/data/", one_hot=True)

n_input = 28     # MNIST data input (img shape: 28*28)
n_steps = 28     # timesteps
n_hidden = 128   # hidden layer num of features
n_classes = 10   # MNIST classes (digits 0-9, 10 classes in total)
batch_size = 128

tf.reset_default_graph()
# tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

gru = tf.contrib.rnn.GRUCell(n_hidden*2)
lstm_cell = tf.contrib.rnn.LSTMCell(n_hidden)
mcell = tf.contrib.rnn.MultiRNNCell([lstm_cell, gru])

outputs, states = tf.nn.dynamic_rnn(mcell, x, dtype=tf.float32)   # (?, 28, 256)
# (28, ?, 256): 28 time steps; take the last one, outputs[-1] = (?, 256)
outputs = tf.transpose(outputs, [1, 0, 2])
pred = tf.contrib.layers.fully_connected(outputs[-1], n_classes, activation_fn=None)

learning_rate = 0.001
# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

training_iters = 100000
display_step = 10

# Launch the session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step = 1
    # Keep training until reach max iterations
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 sequences of 28 elements
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            # Calculate batch accuracy
            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
            # Calculate batch loss
            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step*batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print(" Finished!")

    # Calculate accuracy for mnist test images
    test_len = 100
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
References:
1. https://www.jianshu.com/p/1b1ea45fab47
2. What's the difference between tensorflow dynamic_rnn and rnn?
------------------------------------------------------------------------