從鍋爐工到AI專家(10)

RNN循環神經網絡(Recurrent Neural Network)

如同word2vec中提到的,不少數據的原型,先後之間是存在關聯性的。關聯性的打破必然形成關鍵指徵的丟失,從而在後續的訓練和預測流程中下降準確率。
除了提過的天然語言處理(NLP)領域,自動駕駛前一時間點的雷達掃描數據跟後一時間點的掃描數據、音樂旋律的時間性、股票前一天跟後一天的數據,都屬於這類的典型案例。
所以在傳統的神經網絡中,每個節點,若是把上一次的運算結果記錄下來,在下一次數據處理的時候,跟上一次的運算結果結合在一塊兒混合運算,就能夠體現出上一次的數據對本次的影響。

如上圖所示,圖中每個節點就至關於神經網絡中的一個節點,t-1 、 t 、 t+1是指該節點在時間序列中的動做,你能夠理解爲第n批次的數據。
因此上面圖中的3個節點,在實現中實際是同1個節點。
指的是,在n-1批次數據到來的時候,節點進行計算,完成輸出,同時保留了一個state。
在下一批次數據到來的時候,state值跟新到來的數據一塊兒進行運算,再次完成輸出,再次保留一個state參與下一批次的運算,如此循環。這也是循環神經網絡名稱的由來。python

RNN算法存在一個問題,那就是同一節點在某一時間點所保存的狀態,隨着時間的增加,它所能形成的影響就越小,逐漸衰減至無。這對於一些長距離上下文相關的應用,仍然是不知足要求的。
這就又發展出了LSTM算法。git

LSTM長短時間記憶網絡(Long Short-Term Memory)


如圖所示:LSTM區別於RNN的地方,主要就在於它在算法中加入了一個判斷信息有用與否的「處理器」,這個處理器做用的結構被稱爲cell。
一個cell當中被放置了三個「門電路」,分別叫作輸入門、遺忘門和輸出門。一個信息進入LSTM的網絡當中,能夠根據規則來判斷是否有用。只有符合算法認證的信息纔會留下,不符的信息則經過遺忘門被遺忘。github

  • 遺忘門決定讓哪些信息繼續經過這個cell。
  • 輸入門決定讓多少新的信息加入到 cell狀態中來。
  • 輸出門決定咱們要輸出什麼樣的值。

經過這樣簡單的節點結構改善,就有效的解決了長時序依賴數據在神經網絡中的表現。算法

LSTM隨後還出現了很多變種,進一步增強了功能或者提升了效率。好比當前比較有名的GRU(Gated Recurrent Unit )是2014年提出的。GRU在不下降處理效果的同時,減小了一個門結構。只有重置門(reset gate)和更新門(update gate)兩個門,而且把細胞狀態和隱藏狀態進行了合併。這使得算法的實現更容易,結構更清晰,運算效率也有所提升。
目前的應用中,較多的使用是LSTM或者GRU。RNN網絡其實已經不多直接用到了。數據庫

實現一個RNN網絡

官方的RNN網絡教程是實現了一個NLP的應用,技術上很切合RNN的典型特徵。不過從程序邏輯上太複雜了,並且計算結果也很不直觀。
爲了能儘快的抓住RNN網絡的本質,本例仍然延續之前用過的MNIST程序,把其中的識別模型替換爲RNN-LSTM網絡,相信能夠更快的讓你們上手RNN-LSTM。
本例中的源碼來自aymericdamien的github倉庫,爲了更接近咱們原來的示例代碼,適當作了修改。在此對原做者表示感謝。
官方的課程建議在讀完這裏的內容以後再去學習,而且也很值得深刻的研究。
源碼:編程

#!/usr/bin/env python
# -*- coding=UTF-8 -*-

""" Recurrent Neural Network.

A Recurrent Neural Network (LSTM) implementation example using TensorFlow library.
This example is using the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/)

Links:
    [Long Short Term Memory](http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf)
    [MNIST Dataset](http://yann.lecun.com/exdb/mnist/).

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
"""

from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
#這裏指向之前下載的數據,節省下載時間
#使用時請將後面的路徑修改成本身數據所在路徑
mnist = input_data.read_data_sets("../mnist/data", one_hot=True)

'''
To classify images using a recurrent neural network, we consider every image
row as a sequence of pixels. Because MNIST image shape is 28*28px, we will then
handle 28 sequences of 28 steps for every sample.
'''

# Training Parameters
#訓練梯度
learning_rate = 0.001
#訓練總步驟
training_steps = 10000
#每批次量
batch_size = 128
#每200步顯示一次訓練進度
display_step = 200

# Network Parameters
#下面兩個值實際就是28x28的圖片,可是分紅每組進入RNN的數據28個,
#而後一共28個批次(時序)的數據,利用這種方式,找出單方向相鄰兩個點之間的規律
#這種方式當時不如CNN的效果,但咱們這裏是爲了展現RNN的應用
num_input = 28 # MNIST data input (img shape: 28*28)
timesteps = 28 # timesteps
#LSTM網絡的參數,隱藏層數量
num_hidden = 128 # hidden layer num of features
#最終分爲10類,0-9十個字付
num_classes = 10 # MNIST total classes (0-9 digits)

# tf Graph input
#訓練數據輸入,跟MNIST相同
X = tf.placeholder("float", [None, timesteps, num_input])
Y = tf.placeholder("float", [None, num_classes])

# Define weights
#權重和偏移量
weights = tf.Variable(tf.random_normal([num_hidden, num_classes]))
biases = tf.Variable(tf.random_normal([num_classes]))


def RNN(x, weights, biases):

    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, timesteps, n_input)
    # Required shape: 'timesteps' tensors list of shape (batch_size, n_input)

    # Unstack to get a list of 'timesteps' tensors of shape (batch_size, n_input)
    #進入的數據是X[128(批量),784(28x28)]這樣的數據
    #下面函數轉換成x[128,28]的數組,數組長度是28
    #至關於一個[28,128,28]的張量
    x = tf.unstack(x, timesteps, 1)

    # Define a lstm cell with tensorflow
    #定義一個lstm Cell,其中有128個單元,這個數值能夠修改調優
    lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)

    # Get lstm cell output
    #使用單元計算x,最後得到輸出及狀態
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

    # Linear activation, using rnn inner loop last output
    #仍然是咱們熟悉的算法,這裏至關於該節點的激活函數(就是原來rule的位置)
    return tf.matmul(outputs[-1], weights) + biases

#使用RNN網絡定義一個算法模型
logits = RNN(X, weights, biases)
#預測算法
prediction = tf.nn.softmax(logits)

# Define loss and optimizer
#代價函數、優化器及訓練器,跟原來基本是相似的
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)

# Evaluate model (with test logits, for dropout to be disabled)
#使用上面定義的預測算法進行預測,跟樣本標籤相同即爲預測正確
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
#最後換算成正確率
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()

# Start training
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    for step in range(1, training_steps+1):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements
        #首先把數據從[128,784]轉換成[128,28,28]的形狀,這跟之前線性迴歸是不一樣的
        batch_x = batch_x.reshape((batch_size, timesteps, num_input))
        # Run optimization op (backprop)
        #逐批次訓練
        sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
        if step % display_step == 0 or step == 1:
            # Calculate batch loss and accuracy
            #每200個批次顯示一下進度,當前的代價值機正確率
            loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x,
                                                                 Y: batch_y})
            print("Step " + str(step) + ", Minibatch Loss= " + \
                  "{:.4f}".format(loss) + ", Training Accuracy= " + \
                  "{:.3f}".format(acc))

    print("Optimization Finished!")

    # Calculate accuracy for 128 mnist test images
    #訓練完成,使用測試組數據進行預測
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:", \
        sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))

跟原來的MNIST代碼對比,本源碼有如下幾個修改:數組

  • 常量在前面集中定義,這是編程習慣上的調整,跟TensorFlow及RNN-LSTM無關
  • 核心算法替換成了RNN,在RNN函數中實現,其中主要作了3個動做:
    • 首先把數據切成28個數據一個批次。原來從訓練集中讀取的數據是[128批次,784數據]的張量。
      隨後在主循環中改爲了:[128,28,28]的張量喂入RNN。註釋中有說明,這是利用RNN的特徵,試圖尋找每張圖片在單一方向上相鄰兩個點之間是否存在規律。
      RNN中第一個動做就是按照時序分紅28個批次。變成了[28,128,28]的樣式。
    • 隨後定義了一個基本的LSTM Cell,包含128個單元,這裏能夠理解爲神經網絡中的隱藏層。
    • 最後使用咱們熟悉的線性迴歸做用到每個輸出單元中去,在這裏,這個線性迴歸也至關於神經網絡中每一個節點的激活函數。
  • 交叉熵的計算又換了一種算法:softmax_cross_entropy_with_logits,同咱們前面用過的sparse_softmax_cross_entropy功能是接近的,基本能夠互相代換。
  • 隨後的訓練和預測,基本同原來的算法是相同。

運算結果:bash

Step 9000, Minibatch Loss= 0.4518, Training Accuracy= 0.859
Step 9200, Minibatch Loss= 0.4717, Training Accuracy= 0.852
Step 9400, Minibatch Loss= 0.5074, Training Accuracy= 0.859
Step 9600, Minibatch Loss= 0.4006, Training Accuracy= 0.883
Step 9800, Minibatch Loss= 0.3571, Training Accuracy= 0.875
Step 10000, Minibatch Loss= 0.3069, Training Accuracy= 0.906
Optimization Finished!
Testing Accuracy: 0.8828125

訓練的結果並非很高,由於對於圖像識別,RNN並非很好的算法,這裏只是演示一個基本的RNN-LSTM模型。網絡

自動寫詩

上面的例子讓你們對於RNN/LSTM作了入門。實際上RNN/LSTM並不適合用於圖像識別,一個典型的LSTM應用案例應當是NLP。咱們下面再舉一個這方面的案例。
本節是一個利用唐詩數據庫,訓練一個RNN/LSTM網絡,隨後利用訓練好的網絡自動寫詩的案例。
源碼來自互聯網,做者:斗大的熊貓,在此表示感謝。
爲了適應python2.x+TensorFlow1.4.1的運行環境,另外也爲了你們讀起來方便把訓練部分跟生成部分集成到了一塊兒,所以源碼有所修改。也建議你們去原做者的博客去讀一讀相關的文章,會頗有收穫,在引文中也有直接的連接。
源碼講解:app

  • 首先是唐詩的數據庫,能夠在此連接下載到:全唐詩(43030首)
  • readPoetry()函數中,讀取了所有的唐詩,分離並拋棄掉標題部分,由於這部分每每不符合詩詞的通常格式,參與詩詞的訓練沒有意義。
    隨後對詩詞進行基本的歸一化,諸如剔除空格、根據字數分類。原詩中包含說明、介紹、引用等不署於詩詞的部分,由於這部分數據徹底不規範不能自動處理,因此這樣的詩詞幹脆剔除掉不參與訓練。
    最後獲得的樣本集,每首詩保持了中間的逗號和句號,用於體現逗號、句號跟以前的字的規律。此外認爲在開頭和結尾增長了"["和"]"字符。用於體現每首詩第一個字和最後一個字跟相鄰字之間的規律。
  • 接着把詩文向量化,就是上一篇word2vec的工做。但這個源碼估計爲了下降工做量,沒有進行分詞,程序假定每一個字就是一個詞,多字詞的關係會被丟失,但這在後面「自動寫詩」的環節會比較容易處理,不然可能形成每句詩中由於詞語的存在而字數不一樣。另一點就是沒有把同義詞在向量空間中拉近相關的距離,這裏也是爲了簡化操做。也能夠說還存在改進的空間。
  • genTrainData()以64首詩爲一個批次,生成了訓練數據集x_batches/y_batches,由於整體算詩詞的數據集比較小。這裏沒有動態逐批次生成,而是一次生成到兩個數組中去。在訓練結束生成古詩的時候,這部分實際是沒有用的,但訓練跟生成集成在同一個程序中,就忽略這點工做了。須要注意的是,生成古詩的時候,批次會設定爲1,由於是經過一個漢字預測下一個漢字。
  • neural_network()函數中定義了RNN/LSTM網絡,實際上這個主函數考慮了使用RNN / LSTM / GRU三種網絡的構建選擇,能夠任意選擇其一。在這裏使用了python函數能夠跟變量同樣賦值並調用的特性,讀源碼的時候能夠注意一下。
    與上一個例子還有一點不一樣,就是這裏使用了兩層的RNN網絡,回憶一下多層神經網絡,理解這個概念應當不難。這項工做是由tf.nn.rnn_cell.MultiRNNCell函數完成的。
    tf.get_variable()函數也是定義TensorFlow變量,咱們以前一直使用tf.Variable(),二者功能相似,前者更適合在做用域的管理下共享變量。
    接着要介紹的是個重點:tf.nn.dynamic_rnn,咱們前面說過,由於是時序輸入的計算模式,因此輸入數據能夠是不等長的,這是RNN網絡的特徵之一。咱們以前全部的案例,每一個訓練批次的數據必須是定長,上一個RNN案例中也使用了rnn.static_rnn,這表示使用定長的數據集。
    後面的激活函數再次是咱們熟悉的softmax,此次等因而把上面數字化以後的唐詩中的漢字作成一個庫,分類到其中之一,即爲推測出的下一個字。
    總結一下模型部分:唐詩數字化的時候,完整的保留了每首詩開頭文字、結尾文字、每句的結尾文字之間的關係。所創建的RNN模型,實際上會以上一個文字,預測下一個文字,甚至標點符號都是預測而獲得的。
  • 隨後的訓練部分train_neural_network()沒有太多新概念,要注意的是每次調用模型的訓練,會保留其last_state,並在下個批次訓練的時候,迭代進去。這是咱們前面講RNN模型的時候說過的。而這種模式,是在以前的各類模型中沒有出現過的。
  • gen_poetry()自動生成詩句是一個很完整的預測,初始的值會是一個字符"[",表示一個詩的開始,咱們樣本中,每首詩的開始都是人爲增長的「[」字符。RNN模型確定不會對這麼高頻的規律搞錯。這種模式生成的古詩雖然遠遠比不上人的做品,但可讀性仍是比較好的。
  • 藏頭詩部分gen_poetry_with_head(),這部分生成的會比較牽強。緣由是,人爲指定的藏頭詩第一個字,不可能恰好吻合唐詩數據庫中每句第一個字的規律,所以直接預測出來,極可能沒有完成一句話,就已是句號或者逗號。
    程序只能根據預置的句長(這裏指定七言),跳過逗號、句號以及結束符號「]」,跳過以後再次從新生成,其實已經不符合一句話中的規律,但爲了達到藏頭詩的效果,也只能如此。
  • 訓練模型使用的批次是64。生成時候所使用的預測模型批次是1,由於使用一個漢字去預測後一個。這個在main()中會自動調整。

其他的部分相信憑藉註釋和之前的經驗應當能看懂了:

#!/usr/bin/env python
# -*- coding=UTF-8 -*-

# source from: 
#  http://blog.topspeedsnail.com/archives/10542
# poetry.txt from:
#  https://pan.baidu.com/s/1o7QlUhO
# revised: andrew
#  https://formoon.github.io
#  add python 2.x support and tf 1.4.1 support
#------------------------------------------------------------------#

import collections
import numpy as np
import tensorflow as tf
import argparse
import codecs
import os,time
import sys
reload(sys)
sys.setdefaultencoding('utf-8')

#-------------------------------數據預處理---------------------------#

poetry_file ='poetry.txt'

# 詩集
poetrys = []
def readPoetry():
    global poetrys
    #with open(poetry_file, "r", encoding='utf-8',) as f:
    with codecs.open(poetry_file, "r","utf-8") as f:
        for line in f:
            try:
                content = line.strip().split(':')[1]
                #title, content = line.strip().split(':')
                content = content.replace(' ','')
                if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content:
                    continue
                if len(content) < 5 or len(content) > 79:
                    continue
                content = '[' + content + ']'
                poetrys.append(content)
            except Exception as e:
                pass
    # 按詩的字數排序
    poetrys = sorted(poetrys,key=lambda line: len(line))

#for item in poetrys:
#    print(item)

# 統計每一個字出現次數
readPoetry()
all_words = []
for poetry in poetrys:
    all_words += [word for word in poetry]
#    print poetry
#    for word in poetry:
#        print(word)
#        all_words += word
counter = collections.Counter(all_words)
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
words, _ = zip(*count_pairs)
#print words

# 取前多少個經常使用字
words = words[:len(words)] + (' ',)
# 每一個字映射爲一個數字ID
word_num_map = dict(zip(words, range(len(words))))
#print(word_num_map)
# 把詩轉換爲向量形式,參考word2vec
to_num = lambda word: word_num_map.get(word, len(words))
poetrys_vector = [ list(map(to_num, poetry)) for poetry in poetrys]
#[[314, 3199, 367, 1556, 26, 179, 680, 0, 3199, 41, 506, 40, 151, 4, 98, 1],
#[339, 3, 133, 31, 302, 653, 512, 0, 37, 148, 294, 25, 54, 833, 3, 1, 965, 1315, 377, 1700, 562, 21, 37, 0, 2, 1253, 21, 36, 264, 877, 809, 1]
#....]

# 每次取64首詩進行訓練
batch_size = 64
n_chunk = len(poetrys_vector) // batch_size
x_batches = []
y_batches = []
def genTrainData(b):
    global batch_size,n_chunk,x_batches,y_batches,poetrys_vector
    batch_size=b
    for i in range(n_chunk):
        start_index = i * batch_size
        end_index = start_index + batch_size

        batches = poetrys_vector[start_index:end_index]
        length = max(map(len,batches))
        xdata = np.full((batch_size,length), word_num_map[' '], np.int32)
        for row in range(batch_size):
            xdata[row,:len(batches[row])] = batches[row]
        ydata = np.copy(xdata)
        ydata[:,:-1] = xdata[:,1:]
        """
        xdata ydata
        [6,2,4,6,9] [2,4,6,9,9]
        [1,4,2,8,5] [4,2,8,5,5]
        """
        x_batches.append(xdata)
        y_batches.append(ydata)


#---------------------------------------RNN--------------------------------------#

# 定義RNN
def neural_network(input_data, model='lstm', rnn_size=128, num_layers=2):
    if model == 'rnn':
        cell_fun = tf.nn.rnn_cell.BasicRNNCell
    elif model == 'gru':
        cell_fun = tf.nn.rnn_cell.GRUCell
    elif model == 'lstm':
        cell_fun = tf.nn.rnn_cell.BasicLSTMCell

    cell = cell_fun(rnn_size, state_is_tuple=True)
    cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

    initial_state = cell.zero_state(batch_size, tf.float32)

    with tf.variable_scope('rnnlm'):
        softmax_w = tf.get_variable("softmax_w", [rnn_size, len(words)+1])
        softmax_b = tf.get_variable("softmax_b", [len(words)+1])
        with tf.device("/cpu:0"):
            embedding = tf.get_variable("embedding", [len(words)+1, rnn_size])
            inputs = tf.nn.embedding_lookup(embedding, input_data)

    outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state, scope='rnnlm')
    output = tf.reshape(outputs,[-1, rnn_size])

    logits = tf.matmul(output, softmax_w) + softmax_b
    probs = tf.nn.softmax(logits)
    return logits, last_state, probs, cell, initial_state
#訓練
def train_neural_network():
    global datafile
    input_data = tf.placeholder(tf.int32, [64, None])
    output_targets = tf.placeholder(tf.int32, [64, None])
    
    logits, last_state, _, _, _ = neural_network(input_data)
    targets = tf.reshape(output_targets, [-1])
    #loss = tf.nn.seq2seq.sequence_loss_by_example([logits], [targets], [tf.ones_like(targets, dtype=tf.float32)], len(words))
    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([logits], [targets], [tf.ones_like(targets, dtype=tf.float32)], len(words))
    cost = tf.reduce_mean(loss)
    learning_rate = tf.Variable(0.0, trainable=False)
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(zip(grads, tvars))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        #saver = tf.train.Saver(tf.all_variables())
        saver = tf.train.Saver()

        for epoch in range(50):
            sess.run(tf.assign(learning_rate, 0.002 * (0.97 ** epoch)))
            n = 0
            for batche in range(n_chunk):
                train_loss, _ , _ = sess.run([cost, last_state, train_op], feed_dict={input_data: x_batches[n], output_targets: y_batches[n]})
                n += 1
                print(epoch, batche, train_loss)
            if epoch % 7 == 0:
                #保存的數據,文件名中有批次的標誌
                saver.save(sess, datafile, global_step=epoch)

#-------------------------------生成古詩---------------------------------#
# 使用訓練完成的模型
 
def gen_poetry():
    global datafile
    input_data = tf.placeholder(tf.int32, [1, None])
    output_targets = tf.placeholder(tf.int32, [1, None])

    def to_word(weights):
        t = np.cumsum(weights)
        s = np.sum(weights)
        sample = int(np.searchsorted(t, np.random.rand(1)*s))
        return words[sample]

    _, last_state, probs, cell, initial_state = neural_network(input_data)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver()
        #讀取最後一個批次的訓練數據
        saver.restore(sess, datafile+"-49")

        state_ = sess.run(cell.zero_state(1, tf.float32))

        x = np.array([list(map(word_num_map.get, '['))])
        [probs_, state_] = sess.run([probs, last_state], feed_dict={input_data: x, initial_state: state_})
        word = to_word(probs_)
        #word = words[np.argmax(probs_)]
        poem = ''
        while word != ']':
            poem += word
            if word == ',' or word=='。':
                poem += '\n'
            x = np.zeros((1,1))
            x[0,0] = word_num_map[word]
            [probs_, state_] = sess.run([probs, last_state], feed_dict={input_data: x, initial_state: state_})
            word = to_word(probs_)
            #word = words[np.argmax(probs_)]
        return poem
 

#-------------------------------生成藏頭詩---------------------------------#
def gen_poetry_with_head(head,phase):
    global datafile
    input_data = tf.placeholder(tf.int32, [1, None])
    output_targets = tf.placeholder(tf.int32, [1, None])

    def to_word(weights):
        t = np.cumsum(weights)
        s = np.sum(weights)
        sample = int(np.searchsorted(t, np.random.rand(1)*s))
        return words[sample]

    _, last_state, probs, cell, initial_state = neural_network(input_data)

    with tf.Session() as sess:
#        sess.run(tf.initialize_all_variables())
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver()
        saver.restore(sess, datafile+"-49")

        state_ = sess.run(cell.zero_state(1, tf.float32))
        poem = ''
        i = 0
        p = 0
        head=unicode(head,"utf-8");
        for word in head:
            while True:
                if word != ',' and word != '。' and word != ']':
                    poem += word
                    p += 1
                    if p == phase:
                        p = 0
                        break
                else:
                    word='['
                x = np.array([list(map(word_num_map.get, word))])
                [probs_, state_] = sess.run([probs, last_state], feed_dict={input_data: x, initial_state: state_})
                word = to_word(probs_)
            if i % 2 == 0:
                poem += ',\n'
            else:
                poem += '。\n'
            i += 1
        return poem

FLAGS = None 
datafile='./data/module-49'
def datafile_exist():
    return os.path.exists(datafile+"-49.index")

def main(_):
#    if FLAGS.train or (not datafile_exist()):
    if FLAGS.train:
        genTrainData(64)
        print("poems: ",len(poetrys))
        train_neural_network()
        exit()
    if datafile_exist():
        genTrainData(1)
        if FLAGS.generate:
            print(gen_poetry())
        else:
            print(gen_poetry_with_head(FLAGS.head,7))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-a','--head', type=str, default='大寒將至',
                      help='poetry with appointed head char')
    parser.add_argument('-t','--train', action='store_true',default=False,
                      help='Force do train')
    parser.add_argument('-g','--generate', action='store_true',default=False,
                      help='Force do train')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

使用方法:
-a參數是指定藏頭詩開始的字;
-g參數直接自動生成;
-t強制開始訓練。(注意訓練的時間仍是比較長的)

生成的效果請看:

> ./poetry.py -g
沉眉默去迎風雪,
江上才風著故人。
手把柯子不看淚,
笑逢太守也憐君。
秋風不定紅鈿囀,
茶雪欹眠愁斷人。
語苦微成求不死,
醉看花發漸盈衣。

#藏頭詩
> ./poetry.py -a "春節快樂"
春奔桃芳水路猶,
節似鳥飛酒綠出。
快龜縷日發春時,
樂見來還日只相。

至少有了個古詩的樣子了。

(待續...)

引文及參考

TensorFlow練習3: RNN, Recurrent Neural Networks
TensorFlow練習7: 基於RNN生成古詩詞
如何用TensorFlow構建RNN?這裏有一份極簡的教程
(譯)理解 LSTM 網絡 (Understanding LSTM Networks by colah)

相關文章
相關標籤/搜索