tfa.seq2seq.TrainingSampler is a simple training sampler that reads the ground-truth inputs directly.
When sampler.initialize(input_tensors) is called, it takes the time_step=0 slice from every batch entry, concatenates them into one tensor, and returns it.
Each subsequent call to sampler.next_inputs advances time_step by one and returns the corresponding slice, again concatenated across the batch.
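To make that slicing behaviour concrete, here is a toy stand-in written with NumPy. The class name `SlicingSampler` is hypothetical (it is not part of TFA), and it deliberately ignores TrainingSampler's `time_major` and `sequence_length` options; it only mimics the batch-major slicing described above:

```python
import numpy as np


class SlicingSampler:
    """Toy stand-in (hypothetical, not the TFA class) that mimics
    TrainingSampler's time-axis slicing for batch-major inputs."""

    def initialize(self, inputs):
        # inputs: [batch_size, max_time, feature_length]
        self.inputs = inputs
        self.max_time = inputs.shape[1]
        finished = np.zeros(inputs.shape[0], dtype=bool)  # nothing finished yet
        return finished, inputs[:, 0, :]                  # time_step = 0 slice

    def next_inputs(self, time_step):
        next_t = time_step + 1
        finished = np.full(self.inputs.shape[0], next_t >= self.max_time)
        if finished.all():
            # No real next input on the last step; return zeros as a placeholder.
            return finished, np.zeros_like(self.inputs[:, 0, :])
        return finished, self.inputs[:, next_t, :]
```

The real TrainingSampler carries cell_output/cell_state through the same calls, but they do not affect which slice is returned.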
A modified version of the example from the official docs:
```python
import tensorflow_addons as tfa
import tensorflow as tf


def tfa_seq2seq_TrainingSampler_test():
    batch_size = 2
    max_time = 3
    word_vector_len = 4
    hidden_size = 5

    sampler = tfa.seq2seq.TrainingSampler()
    cell = tf.keras.layers.LSTMCell(hidden_size)

    # [batch_size, time_step, feature_length]
    input_tensors = tf.random.uniform([batch_size, max_time, word_vector_len])

    initial_finished, initial_inputs = sampler.initialize(input_tensors)
    cell_input = initial_inputs
    cell_state = cell.get_initial_state(initial_inputs)

    for time_step in tf.range(max_time):
        cell_output, cell_state = cell(cell_input, cell_state)
        sample_ids = sampler.sample(time_step, cell_output, cell_state)
        finished, cell_input, cell_state = sampler.next_inputs(
            time_step, cell_output, cell_state, sample_ids)
        if tf.reduce_all(finished):
            break
        print(time_step)


if __name__ == '__main__':
    tfa_seq2seq_TrainingSampler_test()
```
Taking the code above as an example:
```python
# Suppose the input takes the values shown below.
# Dimensions: [batch_size, time_step, feature_length (i.e. word_vector_length)]
input_tensors = tf.Tensor(
    [[[0.9346709  0.13170087 0.6356932  0.13167298]
      [0.4919318  0.44602418 0.49046385 0.28244007]
      [0.9263021  0.9984634  0.10324025 0.653986  ]]

     [[0.8260417  0.269673   0.37965262 0.86320114]
      [0.88838446 0.28112316 0.5868691  0.4174199 ]
      [0.61980057 0.2420206  0.17553246 0.9765543 ]]], shape=(2, 3, 4), dtype=float32)
```
After sampler.initialize(input_tensors) runs, we get the following sampling result: the time_step=0 row from each of the two batch entries, concatenated together.
```python
initial_inputs = tf.Tensor(
    [[0.9346709 0.13170087 0.6356932  0.13167298]
     [0.8260417 0.269673   0.37965262 0.86320114]], shape=(2, 4), dtype=float32)
```
After the first call to sampler.next_inputs, we get the following sampling result: the time_step=1 row from each batch entry, concatenated together.
```python
# Returned as the new cell_input (the original text mislabels this as initial_inputs).
cell_input = tf.Tensor(
    [[0.4919318  0.44602418 0.49046385 0.28244007]
     [0.88838446 0.28112316 0.5868691  0.4174199 ]], shape=(2, 4), dtype=float32)
```
After the second call to sampler.next_inputs, we get the following sampling result: the time_step=2 row from each batch entry, concatenated together.
```python
cell_input = tf.Tensor(
    [[0.9263021  0.9984634  0.10324025 0.653986  ]
     [0.61980057 0.2420206  0.17553246 0.9765543 ]], shape=(2, 4), dtype=float32)
```
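As a sanity check, each of the three sampled tensors above is just the time_step=t slice of input_tensors; the sketch below reproduces them with plain NumPy indexing (no TFA sampler needed):

```python
import numpy as np

# The concrete input tensor from the walkthrough above:
# [batch_size=2, time_step=3, feature_length=4]
input_tensors = np.array(
    [[[0.9346709,  0.13170087, 0.6356932,  0.13167298],
      [0.4919318,  0.44602418, 0.49046385, 0.28244007],
      [0.9263021,  0.9984634,  0.10324025, 0.653986]],
     [[0.8260417,  0.269673,   0.37965262, 0.86320114],
      [0.88838446, 0.28112316, 0.5868691,  0.4174199],
      [0.61980057, 0.2420206,  0.17553246, 0.9765543]]], dtype=np.float32)

# Slicing the time axis reproduces initialize() and both next_inputs() results.
for t in range(input_tensors.shape[1]):
    print(f"time_step={t}:\n{input_tensors[:, t, :]}")
```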
Meaning of sample_ids: for each batch entry, the index of the largest logit in the RNN output.
```python
# When the LSTMCell output is as follows,
cell_output = tf.Tensor(
    [[-0.07552935  0.07034459  0.12033001 -0.1792231   0.05634112]
     [-0.10488522  0.06370427  0.17486209 -0.10092633  0.09584342]], shape=(2, 5), dtype=float32)
# the logit at index 2 is clearly the largest in both batch entries.
sample_ids = tf.Tensor([2 2], shape=(2,), dtype=int32)
```
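In other words, the sampling step here is an argmax over the last axis; the following NumPy sketch reproduces the sample_ids above from the printed cell_output:

```python
import numpy as np

cell_output = np.array(
    [[-0.07552935, 0.07034459, 0.12033001, -0.1792231,  0.05634112],
     [-0.10488522, 0.06370427, 0.17486209, -0.10092633, 0.09584342]])

# Index of the largest logit in each batch entry.
sample_ids = np.argmax(cell_output, axis=-1)
print(sample_ids)  # [2 2]
```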
https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/Sampler?hl=zh-cn (tfa.seq2seq.Sampler | TensorFlow Addons)
https://tensorflow.google.cn/addons/api_docs/python/tfa/seq2seq/TrainingSampler (tfa.seq2seq.TrainingSampler | TensorFlow Addons)