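The previous chapter on Doc2Vec pointed out that Doc2Vec only captures a text's topic information by adding a Doc_id. It does not really account for word order or contextual semantics, and n-grams only address this locally. Is there another way? Staying with general-purpose text vectors, skip-thought uses an encoder-decoder to learn sentence vectors that carry both context and word order. My heavily modified implementation is here ( ´▽`): github-DSXiangLi-Embedding-skip_thought
As the name suggests, Skip-Thought follows the skip-gram approach. If you are not familiar with it, see "The All-Powerful Embedding 1 - Word2vec Model Explained & Code Implementation".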
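skip-gram uses the center word to predict the surrounding words; skip-thought uses the middle sentence to predict the previous sentence and the next sentence. The idea really is that blunt. The implementation questions are how to extract the information in a sentence and which loss function to use. The authors chose an encoder-decoder to extract sentence information and used log-perplexity, the usual choice for translation models, as the loss.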
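One remark: different models trained on different samples produce text vectors that encode different information. The assumption behind word2vec, for instance, is that words with similar contexts (the surrounding words within window_size) are more similar (closer in vector space). The skip-thought authors' assumption about text vectors is that whatever best reconstructs the neighboring sentences is the information contained in the current sentence. In other words, sentences whose neighbors are similar end up closer in vector space.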
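The first time I read this I thought, wow, makes perfect sense! But the more I thought about it, the stranger the task looked. word2vec skip-gram can get away with this because, given the center word, the choice of words within window_size is fairly limited. You hand me one sentence and expect me to predict every single word of the sentences before and after it, and this is supposed to converge? You what?! No rush, there seems to be a twist later~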
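The Encoder extracts the information in the middle sentence into a fixed-length vector output_state, and the Decoder iteratively generates the previous (or next) sentence from output_state. An Encoder-Decoder can use any recurrent cell; here the authors chose GRU-GRU.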
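A quick recap of the GRU cell. A GRU has two gates, which measure from two angles how relevant the sequence history is to the current token: \(\Gamma_r\), the reset gate, controls how much history takes part in recomputing the state, and \(\Gamma_u\), the update gate, controls how much history flows directly into the current state. I recommend the blog post "Illustrated Guide to LSTM's and GRU's: A step by step explanation".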
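For reference, the two gates can be written out as below. This is a sketch in one common GRU formulation, not a transcription of the paper: the weight names \(W_r, W_u, W_h\), the biases, and the choice to let \(\Gamma_u\) weight the previous state (matching the description above) are my assumptions, and some references swap the roles of \(\Gamma_u\) and \(1-\Gamma_u\).

```latex
\begin{aligned}
\Gamma_r &= \sigma\left(W_r\,[h^{<t-1>}, x^{<t>}] + b_r\right) \\
\Gamma_u &= \sigma\left(W_u\,[h^{<t-1>}, x^{<t>}] + b_u\right) \\
\tilde{h}^{<t>} &= \tanh\left(W_h\,[\Gamma_r \odot h^{<t-1>}, x^{<t>}] + b_h\right) \\
h^{<t>} &= \Gamma_u \odot h^{<t-1>} + (1 - \Gamma_u) \odot \tilde{h}^{<t>}
\end{aligned}
```

With \(\Gamma_u\) close to 1 the previous state is copied through almost unchanged, and with \(\Gamma_r\) close to 0 the candidate \(\tilde{h}^{<t>}\) is computed almost entirely from the current token.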
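The Encoder runs a GRU over a sequence of length T and compresses it into \(h^{<T>}\) of size hidden_size. This \(h^{<T>}\) is also the general-purpose vector that skip-thought ultimately produces for each sentence.
The Decoder starts from \(h^{<T>}\) and predicts every word of the next/previous sentence, one step at a time. The Decoder is slightly more involved than the Encoder because the input is handled differently in training and in prediction.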
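Training uses 100% Teacher Forcing: besides the previous cell's hidden state, each cell also receives the embedding of the true previous token of the sentence being predicted, as shown in the figure.
At prediction time the true sequence is unknown, so the decoder instead uses the previous cell's output to predict a token and then feeds that predicted token's embedding in as the next input, as shown in the figure.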
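For translation models, the benefit of Teacher Forcing during training is faster convergence, and it keeps prediction errors from compounding as they are fed forward. The downside is that the decoder behaves differently in training and in prediction (Exposure Bias), and that the decoded output is constrained by the training samples. The most common remedy is Scheduled Sampling: during training, each step is fed the teacher-forced input with probability P and the model's own predicted output with probability 1-P. But! skip-thought does not use this remedy. Why? Here comes the twist V(^_^)V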
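The Scheduled Sampling idea can be sketched in a few lines. This is a simplified illustration, not code from the paper or from TensorFlow: `choose_decoder_inputs` and its arguments are hypothetical names, and the model's predictions are passed in as a precomputed list, whereas a real decoder would produce them step by step.

```python
import random


def choose_decoder_inputs(gold_tokens, predicted_tokens, teacher_forcing_prob, rng=None):
    """Pick, per step, the token that is fed into the next decoder cell.

    gold_tokens[t] is the true token from step t, predicted_tokens[t] is what the
    model predicted at step t. With probability teacher_forcing_prob the gold
    token is used (teacher forcing), otherwise the model's own prediction.
    """
    rng = rng or random.Random()
    inputs = []
    for gold, predicted in zip(gold_tokens, predicted_tokens):
        use_gold = rng.random() < teacher_forcing_prob
        inputs.append(gold if use_gold else predicted)
    return inputs


gold = ["i", "love", "you"]
predicted = ["i", "like", "you"]
# teacher_forcing_prob=1.0 is plain teacher forcing, which is what skip-thought uses
print(choose_decoder_inputs(gold, predicted, 1.0))
```

Setting the probability to 1.0 recovers pure teacher forcing, and 0.0 makes the decoder consume only its own predictions, as it does at inference time.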
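Once you see the teacher forcing without any sampling, the earlier puzzle is solved. skip-thought does not really use only the middle sentence to predict its neighbors. Conditioned on the middle sentence's output_state, it uses the first T-1 words of the neighboring sentence to predict the T-th word (it feels just one step away from missing-value imputation). The encoder only has to pack as much sentence information as possible into output_state, so that the output state generalizes across different neighboring sentences. How well the decoder actually predicts is of no concern to the model, because what skip-thought outputs at inference is the encoder's output state. So Scheduled Sampling is naturally unnecessary.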
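The skip-thought Decoder has two further peculiarities:
For the loss, the authors use the language-model log-perplexity and sum the losses of the previous and next sentences to get the loss function.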
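Written out, the summed objective looks roughly as follows. The notation is my own shorthand: \(h_i\) stands for the encoder's final state \(h^{<T>}\) of sentence \(s_i\), \(w_{i\pm1}^{t}\) is the t-th word of the neighboring sentence, and \(w_{i\pm1}^{<t}\) are the words before it, which is exactly what teacher forcing feeds in.

```latex
\mathcal{L}(s_{i-1}, s_i, s_{i+1}) =
  -\sum_{t} \log P\left(w_{i+1}^{t} \mid w_{i+1}^{<t}, h_i\right)
  -\sum_{t} \log P\left(w_{i-1}^{t} \mid w_{i-1}^{<t}, h_i\right)
```

Each term is an ordinary per-token cross entropy, so in practice the loss is a masked softmax cross entropy summed over the two decoders.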
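Another interesting point in the paper is vocabulary expansion, that is, how to extend the word embeddings beyond the training vocabulary. The authors learn a linear mapping between the word2vec embeddings and the skip-thought word embeddings: take the words that appear in both vocabularies and run a regression on their embeddings, \(X_{skipthought} \sim W \cdot X_{word2vec}\). Any word that is outside the training sample but inside word2vec can then be mapped through W to obtain its skip-thought word vector.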
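As a sketch of that regression, assuming an unregularized least-squares fit from word2vec space into skip-thought space: stack the n shared words as rows of \(X_{word2vec} \in \mathbb{R}^{n \times d_w}\) and \(X_{skipthought} \in \mathbb{R}^{n \times d_s}\). These matrix names and shapes are my notation, not the paper's.

```latex
\begin{aligned}
W^{*} &= \arg\min_{W \in \mathbb{R}^{d_s \times d_w}}
         \sum_{k=1}^{n} \left\lVert x_{skipthought}^{(k)} - W\, x_{word2vec}^{(k)} \right\rVert_2^2 \\
      &= X_{skipthought}^{\top} X_{word2vec}
         \left(X_{word2vec}^{\top} X_{word2vec}\right)^{-1}
         \quad \text{(when } X_{word2vec}^{\top} X_{word2vec} \text{ is invertible)} \\
\hat{x}_{skipthought} &= W^{*}\, x_{word2vec}
         \quad \text{for a word seen only by word2vec}
\end{aligned}
```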
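Wouldn't it be better to initialize the skip-thought word vectors directly from word2vec/GloVe embeddings? In the implementation below I do exactly that: the embedding is initialized from word2vec, and words outside word2vec are initialized with random.uniform(-0.1,0.1).
For generating the final text vector the authors give several options. Following the principle that brute force works wonders, option 3 naturally performs best.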
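I took some liberties and modified the paper quite heavily... Some details are worlds apart from the paper, so the code is useful for learning how an encoder-decoder is implemented, but there is no guarantee that it fully reproduces the skip-thought results... Only the core of the code is kept below; the complete code is at github-DSXiangLi-Embedding-skip_thought. It is built on the tensorflow seq2seq framework. If you are not familiar with it, read the seq2seq code walkthrough later in this post first~
The paper treats \((s_{i-1}, s_i, s_{i+1})\) as one sample, where \(s_i\) is the encoder source and \(s_{i-1}\) and \(s_{i+1}\) are the decoder targets. Here I simply split it into two samples, \((s_i,s_{i-1})\) and \((s_i,s_{i+1})\).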
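The encoder source needs no extra processing. The decoder source, however, needs a start token and an end token added around the sequence in Train and Eval to mark where it begins and ends, and a start token in Predict to mark the beginning. Finally word_table maps each token to its token_id, the sequences are padded to the same length, and we are done.
The Dataset also includes the logic for fetching word2vec embeddings; words outside word2vec default to random.uniform(-0.1,0.1).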
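The preprocessing steps above can be sketched in plain Python, independent of tf.data. The token strings `<go>`, `<eos>`, `<pad>`, `<unk>` and the tiny `vocab` are made up for illustration; the real code takes its markers from `special_token` and builds the lookup with a word table.

```python
SEQ_START, SEQ_END, PAD, UNK = "<go>", "<eos>", "<pad>", "<unk>"


def add_markers(tokens, prepend, append):
    """Optionally add the start marker in front and the end marker at the back."""
    out = list(tokens)
    if prepend:
        out = [SEQ_START] + out
    if append:
        out = out + [SEQ_END]
    return out


def to_padded_ids(batch, vocab, max_len=None):
    """Map tokens to ids and right-pad every sequence to the same length."""
    max_len = max_len or max(len(seq) for seq in batch)
    pad_id, unk_id = vocab[PAD], vocab[UNK]
    ids, lengths = [], []
    for seq in batch:
        row = [vocab.get(tok, unk_id) for tok in seq][:max_len]
        lengths.append(len(row))  # true length before padding
        ids.append(row + [pad_id] * (max_len - len(row)))
    return ids, lengths


vocab = {PAD: 0, UNK: 1, SEQ_START: 2, SEQ_END: 3, "i": 4, "love": 5, "you": 6}
# decoder source in Train/Eval: both markers are added
decoder_batch = [add_markers(s.split(), True, True) for s in ["i love you", "love you"]]
ids, lengths = to_padded_ids(decoder_batch, vocab)
print(ids, lengths)
```

Keeping the true lengths next to the padded ids matters later: they drive both `sequence_length` in the RNN and the mask used in the loss.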
import gensim.downloader
import numpy as np
import tensorflow as tf  # TF 1.x


class SkipThoughtDataset(BaseDataset):
    def __init__(self, data_file, dict_file, epochs, batch_size, buffer_size, min_count, max_count,
                 special_token, max_len):
        ...

    def parse_example(self, line, prepend, append):
        features = {}
        tokens = tf.string_split([tf.string_strip(line)]).values
        if prepend:
            tokens = tf.concat([[self.special_token.SEQ_START], tokens], 0)
        if append:
            tokens = tf.concat([tokens, [self.special_token.SEQ_END]], 0)
        features['tokens'] = tokens
        features['seq_len'] = tf.size(tokens)
        return features

    ...

    def make_source_dataset(self, file_path, data_type, is_predict, word_table_func):
        prepend, append = self.prepend_append_logic(data_type, is_predict)
        dataset = tf.data.TextLineDataset(file_path).\
            map(lambda x: self.parse_example(x, prepend, append),
                num_parallel_calls=tf.data.experimental.AUTOTUNE).\
            map(lambda x: word_table_func(x),
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
        return dataset

    def build_dataset(self, is_predict=0):
        def input_fn():
            word_table_func = self.word_table_lookup(self.build_wordtable())
            _ = self.build_tokentable()  # initialize here to ensure lookup table is in the same graph

            encoder_source = self.make_source_dataset(self.data_file['encoder'], 'encoder',
                                                      is_predict, word_table_func)
            decoder_source = self.make_source_dataset(self.data_file['decoder'], 'decoder',
                                                      is_predict, word_table_func)

            dataset = tf.data.Dataset.zip((encoder_source, decoder_source)).\
                filter(self.sample_filter_logic)

            if not is_predict:
                dataset = dataset.\
                    repeat(self.epochs)

                dataset = dataset.\
                    padded_batch(
                        batch_size=self.batch_size,
                        padded_shapes=self.padded_shape,
                        padding_values=self.padding_values,
                        drop_remainder=True
                    ).\
                    prefetch(tf.data.experimental.AUTOTUNE)
            else:
                dataset = dataset.batch(1)

            return dataset
        return input_fn

    def load_pretrain_embedding(self):
        if self.embedding is None:
            word_vector = gensim.downloader.load(PretrainModel)
            embedding = []
            for i in self._dictionary.keys():
                try:
                    embedding.append(word_vector.get_vector(i))
                except KeyError:
                    # word not in the pretrained model: random init
                    embedding.append(np.random.uniform(low=-0.1, high=0.1, size=300))
            self.embedding = np.array(embedding, dtype=np.float32)
        return self.embedding
The Encoder is fairly standard: pick the cell type, iterate with dynamic_rnn, and return output and state.
def gru_encoder(input_emb, input_len, params):
    gru_cell = build_rnn_cell('gru', params)

    # state: batch_size * hidden_size, output: batch_size * max_len * hidden_size
    output, state = tf.nn.dynamic_rnn(
        cell=gru_cell,              # a single RNN cell
        inputs=input_emb,           # batch_size * max_len * feature_size
        sequence_length=input_len,  # batch_size, true length of each sequence
        initial_state=None,
        dtype=params['dtype'],
        time_major=False            # False: inputs are batch-major, max_len is the second dim
    )
    return ENCODER_OUTPUT(output=output, state=state)
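The Decoder can be split into the helper, the decoder, and the final dynamic_decode. There are a few places where it is easy to trip up.
This implements the point mentioned above about passing the encoder's output_state directly into every decoder cell: the encoder state is simply concatenated with the embedding input and used as the input.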
from tensorflow.contrib import seq2seq  # TF 1.x


def get_helper(encoder_output, input_emb, input_len, batch_size, embedding, mode, params):
    if mode == tf.estimator.ModeKeys.TRAIN:
        if params['conditional']:
            # conditional train helper with encoder output state as direct input
            # Reshape encoder state as auxiliary input: 1 * batch_size * hidden -> batch_size * max_len * hidden
            decoder_length = tf.shape(input_emb)[1]
            state_shape = tf.shape(encoder_output.state)
            encoder_state = tf.tile(
                tf.reshape(encoder_output.state, [state_shape[1], state_shape[0], state_shape[2]]),
                [1, decoder_length, 1])
            input_emb = tf.concat([encoder_state, input_emb], axis=-1)

        helper = seq2seq.TrainingHelper(
            inputs=input_emb,              # batch_size * max_len-1 * emb_size
            sequence_length=input_len - 1,  # exclude last token
            time_major=False,
            name='training_helper'
        )
    else:
        helper = seq2seq.GreedyEmbeddingHelper(
            embedding=embedding_func(embedding),
            start_tokens=tf.fill([batch_size], params['start_token']),
            end_token=params['end_token']
        )
    return helper


def get_decoder(decoder_cell, encoder_output, input_emb, input_len, embedding, output_layer, mode, params):
    batch_size = tf.shape(encoder_output.output)[0]

    if params['beam_width'] > 1:
        # With beam search, multiple predictions are used at each time step
        decoder = seq2seq.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=embedding_func(embedding),
            initial_state=encoder_output,
            beam_width=params['beam_width'],
            start_tokens=tf.fill([batch_size], params['start_token']),
            end_token=params['end_token'],
            output_layer=output_layer
        )
    else:
        helper = get_helper(encoder_output, input_emb, input_len, batch_size, embedding, mode, params)

        decoder = seq2seq.BasicDecoder(
            cell=decoder_cell,
            helper=helper,
            initial_state=encoder_output.state,
            output_layer=output_layer
        )
    return decoder


def gru_decoder(encoder_output, input_emb, input_len, embedding, params, mode):
    gru_cell = build_rnn_cell('gru', params)

    if mode == tf.estimator.ModeKeys.TRAIN:
        max_iteration = None
    elif mode == tf.estimator.ModeKeys.EVAL:
        max_iteration = tf.reduce_max(input_len)  # decode max sequence length (= padded_length) in EVAL
    else:
        max_iteration = params['max_decode_iter']  # decode pre-defined max_decode_iter in predict

    # used for infer helper sample or train loss calculation
    output_layer = tf.layers.Dense(units=params['vocab_size'])

    decoder = get_decoder(gru_cell, encoder_output, input_emb, input_len, embedding, output_layer, mode, params)

    output, state, seq_len = seq2seq.dynamic_decode(decoder=decoder,
                                                    output_time_major=False,
                                                    impute_finished=True,
                                                    maximum_iterations=max_iteration)

    return DECODER_OUTPUT(output=output, state=state, seq_len=seq_len)
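For the loss I wrote my own version of sequence_loss and split it into two parts: computing the per-token loss, and aggregating it along different dimensions. It seems to me that tf.contrib.seq2seq.sequence_loss only targets training and is not friendly to eval: TrainingHelper guarantees that source and target have the same length, but GreedyEmbeddingHelper at inference gives no guarantee on the output length (maybe I misunderstood something; if so, please correct me (o^^o)). So the eval branch gets special handling as well.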
def sequence_loss(logits, target, mask, mode):
    with tf.variable_scope('Sequence_loss_matrix'):
        n_class = tf.shape(logits)[2]
        decode_len = tf.shape(logits)[1]  # used for infer only, max_len is determined by decoder
        logits = tf.reshape(logits, [-1, n_class])

        if mode == tf.estimator.ModeKeys.TRAIN:
            # In train, target drops the leading start token
            target = tf.reshape(target[:, 1:], [-1])  # (batch * (padded_len-1)) * 1
        elif mode == tf.estimator.ModeKeys.EVAL:
            # In eval, target has padded_len, logits have decode_len
            target = tf.reshape(target[:, :decode_len], [-1])  # batch * (decode_len) * 1
        else:
            raise Exception('sequence loss is only used in train or eval, not in pure prediction')

        loss_mat = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=logits)
        loss_mat = tf.multiply(loss_mat, tf.reshape(mask, [-1]))  # apply padded mask on output loss
    return loss_mat


def agg_sequence_loss(loss_mat, mask, axis):
    with tf.variable_scope('Loss_{}'.format(axis)):
        if axis == 'scaler':
            loss = tf.reduce_sum(loss_mat)
            n_sample = tf.reduce_sum(mask)
            loss = loss / n_sample
        else:
            loss_mat = tf.reshape(loss_mat, tf.shape(mask))  # (batch_size * max_len) * 1 -> batch_size * max_len
            if axis == 'batch':
                loss = tf.reduce_sum(loss_mat, axis=1)  # batch
                n_sample = tf.reduce_sum(mask, axis=1)  # batch
                loss = tf.math.divide_no_nan(loss, n_sample)  # batch
            elif axis == 'time':
                loss = tf.reduce_sum(loss_mat, axis=0)  # max_len
                n_sample = tf.reduce_sum(mask, axis=0)  # max_len
                loss = tf.math.divide_no_nan(loss, n_sample)  # max_len
            else:
                raise Exception('Only scaler/batch/time are supported in axis param')
    return loss
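To make the three aggregation modes concrete, here is a small pure-Python sketch on nested lists. It mirrors the idea of masking and then averaging over real tokens only; `aggregate_loss` and its `"scalar"`/`"batch"`/`"time"` axis names are my own illustrative choices, not the functions above.

```python
def divide_no_nan(num, den):
    """Return num / den, or 0.0 when the denominator is zero."""
    return num / den if den != 0 else 0.0


def aggregate_loss(loss_mat, mask, axis):
    """loss_mat and mask are batch_size x max_len lists; mask is 1 for real tokens, 0 for padding."""
    masked = [[l * m for l, m in zip(lrow, mrow)] for lrow, mrow in zip(loss_mat, mask)]
    if axis == "scalar":  # one number: mean loss per real token
        return divide_no_nan(sum(map(sum, masked)), sum(map(sum, mask)))
    if axis == "batch":   # one number per sample
        return [divide_no_nan(sum(lrow), sum(mrow)) for lrow, mrow in zip(masked, mask)]
    if axis == "time":    # one number per time step
        return [divide_no_nan(sum(lcol), sum(mcol))
                for lcol, mcol in zip(zip(*masked), zip(*mask))]
    raise ValueError("axis must be one of scalar/batch/time")


loss_mat = [[1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0]]
mask = [[1, 1, 0],   # first sample has 2 real tokens
        [1, 0, 0]]   # second sample has 1 real token
for axis in ("scalar", "batch", "time"):
    print(axis, aggregate_loss(loss_mat, mask, axis))
```

The last time step contains only padding, which is why the per-time and per-batch averages need a division that tolerates a zero count.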
With the encoder, decoder, and loss all ready, we just put them together. The embedding is initialized from the word2vec vectors loaded earlier.
class QuickThought(object):
    def __init__(self, params):
        self.params = params
        self.init()

    def init(self):
        with tf.variable_scope('embedding', reuse=tf.AUTO_REUSE):
            self.embedding = tf.get_variable(dtype=self.params['dtype'],
                                             initializer=tf.constant(self.params['pretrain_embedding']),
                                             name='word_embedding')
            add_layer_summary(self.embedding.name, self.embedding)

    def build_model(self, features, labels, mode):
        encoder_output = self._encode(features)
        decoder_output = self._decode(encoder_output, labels, mode)
        loss_output = self.compute_loss(decoder_output, labels, mode)
        ...

    def _encode(self, features):
        with tf.variable_scope('encoding'):
            encoder = ENCODER_FAMILY[self.params['encoder_type']]
            seq_emb_input = tf.nn.embedding_lookup(self.embedding, features['tokens'])  # batch_size * max_len * emb_size
            encoder_output = encoder(seq_emb_input, features['seq_len'], self.params)  # batch_size
        return encoder_output

    def _decode(self, encoder_output, labels, mode):
        with tf.variable_scope('decoding'):
            decoder = DECODER_FAMILY[self.params['decoder_type']]
            if mode == tf.estimator.ModeKeys.TRAIN:
                seq_emb_output = tf.nn.embedding_lookup(self.embedding, labels['tokens'])  # batch_size * max_len * emb_size
                input_len = labels['seq_len']
            elif mode == tf.estimator.ModeKeys.EVAL:
                seq_emb_output = None
                input_len = labels['seq_len']
            else:
                seq_emb_output = None
                input_len = None

            decoder_output = decoder(encoder_output, seq_emb_output, input_len,
                                     self.embedding, self.params, mode)
        return decoder_output

    def compute_loss(self, decoder_output, labels, mode):
        with tf.variable_scope('compute_loss'):
            mask = sequence_mask(decoder_output, labels, self.params, mode)
            loss_mat = sequence_loss(logits=decoder_output.output.rnn_output,
                                     target=labels['tokens'],
                                     mask=mask,
                                     mode=mode)
            loss = []
            for axis in ['scaler', 'batch', 'time']:
                loss.append(agg_sequence_loss(loss_mat, mask, axis))
        return SEQ_LOSS_OUTPUT(loss_id=loss_mat,
                               loss_scaler=loss[0],
                               loss_per_batch=loss[1],
                               loss_per_time=loss[2])
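I started using seq2seq without really understanding it and ended up staring at shape mismatch errors for what felt like forever, so let's just walk through the tf implementation properly. Only the core of the code is kept below; the complete official code lives in tf.contrib.seq2seq.
The encoding part is just a dynamic_rnn. Let's look at the inputs first.
The main dynamic_rnn function really only handles the input and output data, including: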
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None):
    flat_input = nest.flatten(inputs)

    if not time_major:
        # (B,T,D) => (T,B,D)
        flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
        flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)

    batch_size = _best_effort_input_batch_size(flat_input)
    state = cell.zero_state(batch_size, dtype)

    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)

    (outputs, final_state) = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)

    if not time_major:
        # (T,B,D) => (B,T,D)
        outputs = nest.map_structure(_transpose_batch_time, outputs)

    return (outputs, final_state)
The core computation all happens in _dynamic_rnn_loop, which is a while_loop, so three things have to be defined: [loop_var, body, condition].
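The output assembled from output_ta has shape (batch, max_len, hidden_units) once it is transposed back to batch-major. For a vanilla RNN or a GRU the state is simply the last valid output, so its shape is (1, batch, hidden_units) when viewed as a stack of states (each individual state tensor is (batch, hidden_units)). An LSTM carries two hidden states, one for passing information forward and one for output, so here the state has shape (2, batch, hidden_units).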
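The core of the loop is lambda: cell(input_t, state), that is, the computation of the chosen recurrent cell. When sequence_length is given, the extra work done by _rnn_step is to copy through (zero_output, last_state) for sequences that have already been fully consumed.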
def _time_step(time, output_ta_t, state):
    input_t = tuple(ta.read(time) for ta in input_ta)
    input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
    call_cell = lambda: cell(input_t, state)

    if sequence_length is not None:
        (output, new_state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=state_size,
            skip_conditionals=True)
    else:
        (output, new_state) = call_cell()

    # Pack state if using state tuples
    output = nest.flatten(output)

    output_ta_t = tuple(
        ta.write(time, out) for ta, out in zip(output_ta_t, output))

    return (time + 1, output_ta_t, new_state)
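The copy-through behaviour is easy to see in a toy version. The sketch below is not TensorFlow code: `rnn_step_with_copy_through`, `toy_cell` and `run` are hypothetical helpers that work on plain Python numbers, with one scalar state per batch row.

```python
def rnn_step_with_copy_through(time, sequence_length, state, cell, x_t, zero_output):
    """Run the cell only for rows that still have input left.

    Rows with time >= length emit zero_output and keep their previous state.
    """
    outputs, new_states = [], []
    for length, s, x in zip(sequence_length, state, x_t):
        if time < length:
            out, new_s = cell(x, s)
        else:
            out, new_s = zero_output, s  # copy through
        outputs.append(out)
        new_states.append(new_s)
    return outputs, new_states


def toy_cell(x, s):
    """A stand-in recurrent cell: the new state is a running sum, output equals state."""
    new_s = s + x
    return new_s, new_s


def run(inputs_time_major, sequence_length):
    state = [0] * len(sequence_length)
    all_outputs = []
    for time, x_t in enumerate(inputs_time_major):
        out, state = rnn_step_with_copy_through(time, sequence_length, state, toy_cell, x_t, 0)
        all_outputs.append(out)
    return all_outputs, state


# two sequences padded to length 3; the second one has true length 1
outputs, final_state = run([[1, 1], [1, 1], [1, 1]], sequence_length=[3, 1])
print(outputs, final_state)
```

Because of the copy-through, the final state of the shorter row is the state after its last real token rather than something computed from padding.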
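The loop stops at loop_bound = min(time_steps, max(1, max_sequence_length)), where time_steps is the max_len dimension of the input, i.e. the padding length, and max_sequence_length is the longest true length in the input batch. With per-batch padding the two values should be the same.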
time_steps = input_shape[0]

if sequence_length is not None:
    min_sequence_length = math_ops.reduce_min(sequence_length)
    max_sequence_length = math_ops.reduce_max(sequence_length)
else:
    max_sequence_length = time_steps

loop_bound = math_ops.minimum(
    time_steps, math_ops.maximum(1, max_sequence_length))

_, output_final_ta, final_state = control_flow_ops.while_loop(
    cond=lambda time, *_: time < loop_bound,
    body=_time_step,
    loop_vars=(time, output_ta, state),
    parallel_iterations=parallel_iterations,
    maximum_iterations=time_steps,
    swap_memory=swap_memory)
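Decoding has three main components: the Decoder, the Helper, and dynamic_decode. There are also BeamSearch and Attention, which are special enough to live on their own; I will cover those when we need them.
BasicDecoder has 2 main interfaces.
initialize concatenates what the helper's initialization returns with initial_state, that is, the output_state from the encoder's last step. What the helper returns is covered later.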
def initialize(self, name=None):
    return self._helper.initialize() + (self._initial_state,)

The step method does the following:
class BasicDecoderOutput(
        collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))):
    pass


class BasicDecoder(decoder.Decoder):
    """Basic sampling decoder."""

    def __init__(self, cell, helper, initial_state, output_layer=None):
        ...

    def step(self, time, inputs, state, name=None):
        with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
            # run one step of the recurrent cell
            cell_outputs, cell_state = self._cell(inputs, state)
            # optional projection, e.g. hidden_size -> vocab_size
            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)
            # let the helper pick the token id for this step
            sample_ids = self._helper.sample(
                time=time, outputs=cell_outputs, state=cell_state)
            # let the helper prepare the input of the next step
            (finished, next_inputs, next_state) = self._helper.next_inputs(
                time=time,
                outputs=cell_outputs,
                state=cell_state,
                sample_ids=sample_ids)
        outputs = BasicDecoderOutput(cell_outputs, sample_ids)
        return (outputs, next_state, next_inputs, finished)
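So the BasicDecoder implementation only covers the part that consumes the previous step; everything that sets up the next step lives in the Helper. Let's look at what the Helper's next_inputs and sample interfaces actually do.
We mainly look at two helpers, one for training and one for prediction. Each implements 3 main interfaces.
TrainingHelper is used for training. Its sample interface is not actually used, and next_inputs receives sample_ids as part of unused_kwargs.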
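During preprocessing the decoder input sequence gets a start_token that marks the beginning of the sequence, corresponding to the \(<go>\) marker in the figure above. Adding the start_token also creates the offset between source and target, so that the first T-1 tokens are used to predict the T-th token. For example, the source is [\(<go>\), I, love, you] and the target is [I, love, you, \(<eos>\)].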
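The source/target offset can be shown with a tiny helper. `make_teacher_forcing_pair` is a hypothetical function for illustration; in the code earlier in this post the same effect comes from feeding `input_len - 1` steps to the helper and slicing the target with `target[:, 1:]`.

```python
def make_teacher_forcing_pair(tokens, start="<go>", end="<eos>"):
    """Build decoder input and target from one sentence.

    Both come from the same marked sequence: the input drops the last token,
    the target drops the first, so input[t] is used to predict target[t].
    """
    full = [start] + list(tokens) + [end]
    return full[:-1], full[1:]


source, target = make_teacher_forcing_pair(["I", "love", "you"])
for step, (inp, tgt) in enumerate(zip(source, target)):
    print(step, inp, "->", tgt)
```

Both lists have the same length, which is the property the training loss relies on.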
class TrainingHelper(Helper):
    def __init__(self, inputs, sequence_length, time_major=False, name=None):
        ...

    def initialize(self, name=None):
        with ops.name_scope(name, "TrainingHelperInitialize"):
            finished = math_ops.equal(0, self._sequence_length)
            all_finished = math_ops.reduce_all(finished)
            next_inputs = control_flow_ops.cond(
                all_finished, lambda: self._zero_inputs,
                lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
            return (finished, next_inputs)

    def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
        """next_inputs_fn for TrainingHelper."""
        with ops.name_scope(name, "TrainingHelperNextInputs",
                            [time, outputs, state]):
            next_time = time + 1
            finished = (next_time >= self._sequence_length)
            all_finished = math_ops.reduce_all(finished)

            def read_from_ta(inp):
                return inp.read(next_time)

            # next input is the true token at the next position (teacher forcing)
            next_inputs = control_flow_ops.cond(
                all_finished, lambda: self._zero_inputs,
                lambda: nest.map_structure(read_from_ta, self._input_tas))
            return (finished, next_inputs, state)
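GreedyEmbeddingHelper is used for prediction.
initialize returns (finished, next_inputs).
sample returns sample_id.
It takes each decoder cell's output, picks the token with the highest probability, and uses it as the input to the next decoder cell. This is also why output_layer is needed, as mentioned above: the hidden_size -> vocab_size projection has to happen before the softmax can be computed.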
class GreedyEmbeddingHelper(Helper):
    def __init__(self, embedding, start_tokens, end_token):
        # (setup of self._embedding_fn and self._batch_size is elided in this excerpt)
        self._start_tokens = ops.convert_to_tensor(
            start_tokens, dtype=dtypes.int32, name="start_tokens")
        self._end_token = ops.convert_to_tensor(
            end_token, dtype=dtypes.int32, name="end_token")
        self._start_inputs = self._embedding_fn(self._start_tokens)
        ...

    def sample(self, time, outputs, state, name=None):
        # greedy choice: the id with the largest logit
        sample_ids = math_ops.cast(
            math_ops.argmax(outputs, axis=-1), dtypes.int32)
        return sample_ids

    def initialize(self, name=None):
        finished = array_ops.tile([False], [self._batch_size])
        return (finished, self._start_inputs)

    def next_inputs(self, time, outputs, state, sample_ids, name=None):
        finished = math_ops.equal(sample_ids, self._end_token)
        all_finished = math_ops.reduce_all(finished)
        # next input is the embedding of the token that was just predicted
        next_inputs = control_flow_ops.cond(
            all_finished,
            lambda: self._start_inputs,
            lambda: self._embedding_fn(sample_ids))
        return (finished, next_inputs, state)
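Stripped of the TensorFlow machinery, greedy decoding for a single sequence is a short loop. The sketch below uses hypothetical stand-ins: `step_fn` plays the role of cell plus output layer and returns logits over the vocabulary, and `embed_fn` plays the role of the embedding lookup.

```python
def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def greedy_decode(step_fn, embed_fn, initial_state, start_id, end_id, max_steps):
    """Feed the start token, then keep feeding back the most likely token."""
    state = initial_state
    inputs = embed_fn(start_id)
    decoded = []
    for _ in range(max_steps):
        logits, state = step_fn(inputs, state)
        token_id = argmax(logits)        # the role of sample()
        decoded.append(token_id)
        if token_id == end_id:           # finished once the end token is produced
            break
        inputs = embed_fn(token_id)      # the role of next_inputs()
    return decoded


VOCAB_SIZE = 4


def toy_step(inputs, state):
    """Toy model: always predicts 'input id + 1' and counts steps in the state."""
    next_id = (inputs + 1) % VOCAB_SIZE
    logits = [1.0 if i == next_id else 0.0 for i in range(VOCAB_SIZE)]
    return logits, state + 1


print(greedy_decode(toy_step, lambda token_id: token_id, 0, start_id=0, end_id=3, max_steps=10))
```

Nothing forces the loop to produce a sequence of a particular length, which is why a maximum number of steps is needed at prediction time.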
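With the tools for consuming one step and setting up the next all in place, the only thing left for predicting a sequence is the loop, hence dynamic_decode. It really just does the job of a while_loop, so we meet the three loop siblings again: [loop_vars, condition, body].
loop_vars=[initial_time, initial_outputs_ta, initial_state, initial_inputs, initial_finished, initial_sequence_lengths]
condition: checks whether every finished flag is True; the loop stops once all sequences are done
body: the core computation of the loop
step: calls the Decoder to run the decode computation for each step
finished: determined by three checks (I have never used tracks_own_finished, so I skip it here, haha); the other two, as the code below shows, are the decoder's own finished flag OR-ed with the previous finished flag, and whether time + 1 has reached maximum_iterations
sequence_length: records the actual length of each predicted sequence; a sequence that finishes at this step gets length time + 1
impute_finished: if a sequence has already finished, later outputs are filled with 0 and later states are no longer computed but pass the current state through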
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
    (next_outputs, decoder_state, next_inputs,
     decoder_finished) = decoder.step(time, inputs, state)

    # a sequence stays finished once it has finished (tracks_own_finished branch omitted)
    next_finished = math_ops.logical_or(decoder_finished, finished)
    if maximum_iterations is not None:
        next_finished = math_ops.logical_or(
            next_finished, time + 1 >= maximum_iterations)

    # record the length of sequences that finish at this step
    next_sequence_lengths = array_ops.where(
        math_ops.logical_and(math_ops.logical_not(finished), next_finished),
        array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
        sequence_lengths)

    # Zero out output values past finish
    if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs,
            zero_outputs)
    else:
        emit = next_outputs

    # Copy through states past finish
    def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
            pass_through = True
        else:
            new.set_shape(cur.shape)
            pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

    if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
    else:
        next_state = decoder_state

    outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    outputs_ta, emit)
    return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
            next_sequence_lengths)
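The bookkeeping of finished flags and sequence lengths can be simulated without any tensors. In this sketch `decode_loop` is a hypothetical function: the helper's per-step finished flags are passed in as a precomputed table `step_finished[t][b]`, and the strings "out"/"zero" only mark whether a real output or an imputed zero would be written.

```python
def decode_loop(step_finished, maximum_iterations=None):
    """Simulate the finished / sequence_length logic of a decode loop.

    step_finished[t][b] is True when the helper reports row b as finished at step t.
    Without maximum_iterations the table must eventually finish every row.
    """
    batch_size = len(step_finished[0])
    finished = [False] * batch_size
    sequence_lengths = [0] * batch_size
    emitted = []
    time = 0
    while not all(finished):
        decoder_finished = step_finished[time]
        next_finished = [d or f for d, f in zip(decoder_finished, finished)]
        if maximum_iterations is not None and time + 1 >= maximum_iterations:
            next_finished = [True] * batch_size
        # rows that finish at this step get length time + 1
        sequence_lengths = [
            time + 1 if (not f and nf) else length
            for f, nf, length in zip(finished, next_finished, sequence_lengths)]
        # impute_finished: rows that were already finished emit zeros
        emitted.append(["zero" if f else "out" for f in finished])
        finished = next_finished
        time += 1
    return sequence_lengths, emitted


table = [[False, True],    # row 1 produces its end token at step 0
         [False, False],
         [True, False]]    # row 0 produces its end token at step 2
print(decode_loop(table))
print(decode_loop(table, maximum_iterations=2))
```

Note that the step on which a row finishes still emits a real output; only the steps after it are zeroed.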
Comments, complaints, and feedback are all welcome~