How to implement a bidirectional RNN in TensorFlow
TensorFlow already provides an interface for bidirectional RNNs: tf.nn.bidirectional_dynamic_rnn(). Let's first look at how to use this interface.
bidirectional_dynamic_rnn(
    cell_fw,                  # forward RNN cell
    cell_bw,                  # backward RNN cell
    inputs,                   # the input sequence
    sequence_length=None,     # lengths of the sequences
    initial_state_fw=None,    # initial state of the forward rnn_cell
    initial_state_bw=None,    # initial state of the backward rnn_cell
    dtype=None,               # data type
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)
Return value: a tuple (outputs, output_states), where outputs is itself a tuple (output_fw, output_bw). If time_major=True, then output_fw and output_bw are both time-major; if time_major=False, they are batch-major. If you want to concatenate the two directions, simply use tf.concat(outputs, 2).
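For intuition, here is a minimal shape check (a sketch; the batch/time/depth/unit sizes 32/50/100/64 are made up for illustration):

import tensorflow as tf
from tensorflow.contrib import rnn

# batch=32, time=50, input depth=100; both cells have 64 units
inputs = tf.placeholder(tf.float32, [32, 50, 100])
outputs, states = tf.nn.bidirectional_dynamic_rnn(
    rnn.LSTMCell(64), rnn.LSTMCell(64), inputs, dtype=tf.float32)
print(outputs[0].shape)             # (32, 50, 64)  -- forward
print(outputs[1].shape)             # (32, 50, 64)  -- backward
print(tf.concat(outputs, 2).shape)  # (32, 50, 128) -- both directions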
How to use it:
Using bidirectional_dynamic_rnn is very similar to using dynamic_rnn:

1. Define the forward and backward rnn_cells.
2. Define the initial states of the forward and backward rnn_cells.
3. Prepare the input sequence.
4. Call bidirectional_dynamic_rnn.

import tensorflow as tf
from tensorflow.contrib import rnn

cell_fw = rnn.LSTMCell(10)
cell_bw = rnn.LSTMCell(10)
# zero_state also needs a dtype argument;
# batch_size, seq, and seq_length are assumed to be defined elsewhere
initial_state_fw = cell_fw.zero_state(batch_size, tf.float32)
initial_state_bw = cell_bw.zero_state(batch_size, tf.float32)
seq = ...
seq_length = ...
(outputs, states) = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, seq,
    sequence_length=seq_length,
    initial_state_fw=initial_state_fw,
    initial_state_bw=initial_state_bw)
out = tf.concat(outputs, 2)
# ....
A single-layer bidirectional RNN can be implemented easily with the method above, but a multi-layer bidirectional RNN cannot be built by simply passing a MultiRNNCell to bidirectional_dynamic_rnn.
To see why, we need to look at a snippet of the source code of bidirectional_dynamic_rnn.
with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
        output_fw, output_state_fw = dynamic_rnn(
            cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
            initial_state=initial_state_fw, dtype=dtype,
            parallel_iterations=parallel_iterations, swap_memory=swap_memory,
            time_major=time_major, scope=fw_scope)
This is only a small piece of the code, but it is enough to see that the bi-rnn is actually built on top of dynamic_rnn: each direction runs its own dynamic_rnn. If we passed a MultiRNNCell, each direction would be stacked independently, and the interaction between the two directions at every layer would be lost. So we can write a helper function ourselves that calls bidirectional_dynamic_rnn once per layer to implement a multi-layer bidirectional RNN. A stripped-down implementation follows below; if there are any mistakes, feel free to point them out.
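A minimal sketch of such a helper, assuming one cell per layer and per direction (the names stacked_bidirectional_rnn, cells_fw, and cells_bw are illustrative, not from the TensorFlow API):

import tensorflow as tf

def stacked_bidirectional_rnn(cells_fw, cells_bw, inputs, sequence_length):
    # cells_fw / cells_bw: lists of RNNCells, one pair per layer
    output = inputs
    states = None
    for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
        # a distinct scope per layer so the variables do not collide
        with tf.variable_scope("bidirectional_rnn_%d" % i):
            outputs, states = tf.nn.bidirectional_dynamic_rnn(
                cell_fw, cell_bw, output,
                sequence_length=sequence_length, dtype=tf.float32)
            # concatenate the fw/bw outputs so the next layer
            # sees both directions at every time step
            output = tf.concat(outputs, 2)
    # returns the top layer's concatenated outputs and the last layer's states
    return output, states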
Above we saw the code for the forward pass; now let's look at the implementation of the remaining backward part.
The backward pass actually just performs two reverses (a toy demonstration follows after the source snippet below):

1. First reverse: reverse the input sequence, then feed it through dynamic_rnn for one pass.
2. Second reverse: reverse the outputs returned by that dynamic_rnn, so that the forward and backward outputs are aligned in time.
def _reverse(input_, seq_lengths, seq_dim, batch_dim):
    if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_, seq_lengths=seq_lengths,
            seq_dim=seq_dim, batch_dim=batch_dim)
    else:
        return array_ops.reverse(input_, axis=[seq_dim])

with vs.variable_scope("bw") as bw_scope:
    inputs_reverse = _reverse(
        inputs, seq_lengths=sequence_length,
        seq_dim=time_dim, batch_dim=batch_dim)
    tmp, output_state_bw = dynamic_rnn(
        cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
        initial_state=initial_state_bw, dtype=dtype,
        parallel_iterations=parallel_iterations, swap_memory=swap_memory,
        time_major=time_major, scope=bw_scope)

output_bw = _reverse(
    tmp, seq_lengths=sequence_length,
    seq_dim=time_dim, batch_dim=batch_dim)

outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)

return (outputs, output_states)
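A toy check of the two-reverse trick (illustrative, not from the TensorFlow source; tf.cumsum stands in for the recurrent cell as a simple "causal" op):

import tensorflow as tf

x = tf.constant([[[1.], [2.], [3.], [0.]]])  # batch=1, time=4, depth=1
lengths = tf.constant([3])                   # the last step is padding
rev = tf.reverse_sequence(x, lengths, seq_axis=1, batch_axis=0)
# cumsum over the reversed sequence accumulates from the end;
# reversing again aligns step t with the suffix x[t:]
out = tf.reverse_sequence(tf.cumsum(rev, axis=1),
                          lengths, seq_axis=1, batch_axis=0)
with tf.Session() as sess:
    print(sess.run(out))  # [[[6.], [5.], [3.], [6.]]] -- padding copied through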
tf.reverse_sequence
Reverses a portion (the first seq_lengths elements) of each slice along the sequence axis.
reverse_sequence(
    input,            # the input tensor whose slices will be reversed
    seq_lengths,      # 1-D tensor of sequence lengths
    seq_axis=None,    # which axis is the sequence (time) axis
    batch_axis=None,  # which axis is the batch axis
    name=None,
    seq_dim=None,     # deprecated alias of seq_axis
    batch_dim=None    # deprecated alias of batch_axis
)
The examples on the official site are very good, so I'll paste them here directly:
# Given this:
batch_dim = 0
seq_dim = 1
input.dims = (4, 8, ...)
seq_lengths = [7, 2, 3, 5]

# then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...]
output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...]
output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...]
output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...]

# while entries past seq_lengths are copied through:
output[0, 7:, :, ...] = input[0, 7:, :, ...]
output[1, 2:, :, ...] = input[1, 2:, :, ...]
output[2, 3:, :, ...] = input[2, 3:, :, ...]
output[3, 5:, :, ...] = input[3, 5:, :, ...]
Example 2:
# Given this:
batch_dim = 2
seq_dim = 0
input.dims = (8, ?, 4, ...)
seq_lengths = [7, 2, 3, 5]

# then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...]
output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...]
output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...]
output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...]

# while entries past seq_lengths are copied through:
output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...]
output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...]
output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...]
output[5:, :, 3, :, ...] = input[5:, :, 3, :, ...]
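A runnable version of Example 1, as a quick sanity check (a sketch; the tf.range values are arbitrary):

import tensorflow as tf

x = tf.reshape(tf.range(32), (4, 8))  # batch=4, sequence length 8
out = tf.reverse_sequence(x, seq_lengths=[7, 2, 3, 5],
                          seq_axis=1, batch_axis=0)
with tf.Session() as sess:
    print(sess.run(out))
    # row 0: entries 0..6 reversed, entry 7 copied through;
    # row 1: entries 0..1 reversed, the rest copied through; etc.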