在tensorflow/nmt項目中,訓練數據和推斷數據的輸入使用了新的Dataset API,應該是tensorflow 1.2以後引入的API,方便數據的操做。若是你還在使用老的Queue和Coordinator的方式,建議升級高版本的tensorflow而且使用Dataset API。python
本教程將從訓練數據和推斷數據兩個方面,詳解解析數據的具體處理過程,你將看到文本數據如何轉化爲模型所須要的實數,以及中間的張量的維度是怎麼樣的,batch_size和其餘超參數又是如何做用的。git
先來看看訓練數據的處理。訓練數據的處理比推斷數據的處理稍微複雜一些,弄懂了訓練數據的處理過程,就能夠很輕鬆地理解推斷數據的處理。
訓練數據的處理代碼位於nmt/utils/iterator_utils.py文件內的get_iterator
函數。github
咱們先來看看這個函數所須要的參數是什麼意思:網絡
參數 | 解釋 |
---|---|
src_dataset |
源數據集 |
tgt_dataset |
目標數據集 |
src_vocab_table |
源數據單詞查找表,就是個單詞和int類型數據的對應表 |
tgt_vocab_table |
目標數據單詞查找表,就是個單詞和int類型數據的對應表 |
batch_size |
批大小 |
sos |
句子開始標記 |
eos |
句子結尾標記 |
random_seed |
隨機種子,用來打亂數據集的 |
num_buckets |
桶數量 |
src_max_len |
源數據最大長度 |
tgt_max_len |
目標數據最大長度 |
num_parallel_calls |
併發處理數據的併發數 |
output_buffer_size |
輸出緩衝區大小 |
skip_count |
跳過數據行數 |
num_shards |
將數據集分片的數量,分佈式訓練中有用 |
shard_index |
數據集分片後的id |
reshuffle_each_iteration |
是否每次迭代都從新打亂順序 |
上面的解釋,若是有不清楚的,能夠查看我以前一片介紹超參數的文章:
tensorflow_nmt的超參數詳解併發
咱們首先搞清楚幾個重要的參數是怎麼來的。src_dataset
和tgt_dataset
是咱們的訓練數據集,他們是逐行一一對應的。好比咱們有兩個文件src_data.txt
和tgt_data.txt
分別對應訓練數據的源數據和目標數據,那麼它們的Dataset如何建立的呢?其實利用Dataset API很簡單:app
src_dataset=tf.data.TextLineDataset('src_data.txt') tgt_dataset=tf.data.TextLineDataset('tgt_data.txt')
這就是上述函數中的兩個參數src_dataset
和tgt_dataset
的由來。dom
src_vocab_table
和tgt_vocab_table
是什麼呢?一樣顧名思義,就是這兩個分別表明源數據詞典的查找表和目標數據詞典的查找表,實際上查找表就是一個字符串到數字的映射關係。固然,若是咱們的源數據和目標數據使用的是同一個詞典,那麼這兩個查找表的內容是如出一轍的。很容易想到,確定也有一種數字到字符串的映射表,這是確定的,由於神經網絡的數據是數字,而咱們須要的目標數據是字符串,所以它們之間確定有一個轉換的過程,這個時候,就須要咱們的reverse_vocab_table來做用了。分佈式
咱們看看這兩個表是怎麼構建出來的呢?代碼很簡單,利用tensorflow庫中定義的lookup_ops便可:函數
def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab): """Creates vocab tables for src_vocab_file and tgt_vocab_file.""" src_vocab_table = lookup_ops.index_table_from_file( src_vocab_file, default_value=UNK_ID) if share_vocab: tgt_vocab_table = src_vocab_table else: tgt_vocab_table = lookup_ops.index_table_from_file( tgt_vocab_file, default_value=UNK_ID) return src_vocab_table, tgt_vocab_table
咱們能夠發現,建立這兩個表的過程,就是將詞典中的每個詞,對應一個數字,而後返回這些數字的集合,這就是所謂的詞典查找表。效果上來講,就是對詞典中的每個詞,從0開始遞增的分配一個數字給這個詞。fetch
那麼到這裏你有可能會有疑問,咱們詞典中的詞和咱們自定義的標記sos
等是否是有可能被映射爲同一個整數而形成衝突?這個問題該如何解決?聰明如你,這個問題是存在的。那麼咱們的項目是如何解決的呢?很簡單,那就是將咱們自定義的標記當成詞典的單詞,而後加入到詞典文件中,這樣一來,lookup_ops
操做就把標記當成單詞處理了,也就就解決了衝突!
具體的過程,本文後面會有一個例子,能夠爲您呈現具體過程。
若是咱們指定了share_vocab
參數,那麼返回的源單詞查找表和目標單詞查找表是同樣的。咱們還能夠指定一個default_value,在這裏是UNK_ID
,實際上就是0
。若是不指定,那麼默認值爲-1
。這就是查找表的建立過程。若是你想具體的知道其代碼實現,能夠跳轉到tensorflow的C++核心部分查看代碼(使用PyCharm或者相似的IDE)。
該函數處理訓練數據的主要代碼以下:
if not output_buffer_size: output_buffer_size = batch_size * 1000 src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32) tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32) src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset)) src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index) if skip_count is not None: src_tgt_dataset = src_tgt_dataset.skip(skip_count) src_tgt_dataset = src_tgt_dataset.shuffle( output_buffer_size, random_seed, reshuffle_each_iteration) src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: ( tf.string_split([src]).values, tf.string_split([tgt]).values), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # Filter zero length input sequences. src_tgt_dataset = src_tgt_dataset.filter( lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0)) if src_max_len: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src[:src_max_len], tgt), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) if tgt_max_len: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src, tgt[:tgt_max_len]), num_parallel_calls=num_parallel_calls)