Tensorflow在訓練好的模型上進行測試

【轉載自 https://blog.csdn.net/sinat_35821976/article/details/80765145】網絡

  Tensorflow能夠使用訓練好的模型對新的數據進行測試,有兩種方法:第一種方法是調用模型和訓練在同一個py文件中,中狀況比較簡單;第二種是訓練過程和調用模型過程分別在兩個py文件中。本文將講解第二種方法。函數

模型的保存
  tensorflow提供可保存訓練模型的接口,使用起來也不是很難,直接上代碼講解:測試

#網絡結構
w1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units]))
y = tf.nn.softmax(tf.matmul(w1, x) + b1)
tf.add_to_collection('network-output', y)

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
#損失函數與優化函數
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(rate).minimize(cross_entropy)

saver = tf.train.Saver()
with tf.Session() as sess:  
        sess.run(init)  
        saver.save(sess,"save/model.ckpt")  
        train_step.run({x: train_x, y_: train_y})

以上代碼就完成了模型的保存,值得注意的是下面這行代碼fetch

tf.add_to_collection('network-output', y)

這行代碼保存了神經網絡的輸出,這個在後面使用導入模型過程當中起到關鍵做用。優化

模型的導入

  模型訓練並保存後就能夠導入來評估模型在測試集上的表現,網上不少文章只用簡單的四則運算來作例子,讓人看的頭大。仍是先上代碼:spa

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('./model.ckpt.meta')
    saver.restore(sess, './model.ckpt')# .data文件
    pred = tf.get_collection('network-output')[0]

    graph = tf.get_default_graph()
    x = graph.get_operation_by_name('x').outputs[0]
    y_ = graph.get_operation_by_name('y_').outputs[0]

    y = sess.run(pred, feed_dict={x: test_x, y_: test_y})

講解一下關鍵的代碼,首先是pred = tf.get_collection('pred_network')[0],這行代碼得到訓練過程當中網絡輸出的「接口」,簡單理解就是,經過tf.get_collection() 這個方法獲取了整個網絡結構。得到網絡結構後咱們就須要餵它對應的數據y = sess.run(pred, feed_dict={x: test_x, y_: test_y}) 在訓練過程當中咱們的輸入是.net

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')

所以導入模型後所需的輸入也要與之對應可以使用如下代碼得到:rest

    x = graph.get_operation_by_name('x').outputs[0]
    y_ = graph.get_operation_by_name('y_').outputs[0]

使用模型的最後一步就是輸入測試集,而後按照訓練好的網絡進行評估code

    sess.run(pred, feed_dict={x: test_x, y_: test_y})

理解下這行代碼,sess.run() 的函數原型爲orm

run(fetches, feed_dict=None, options=None, run_metadata=None)

Tensorflow對 feed_dict 執行fetches操做,所以在導入模型後的運算就是,按照訓練的網絡計算測試輸入的數據。

相關文章
相關標籤/搜索