轉載:tensorflow保存訓練後的模型

訓練完一個模型後,爲了之後重複使用,一般咱們須要對模型的結果進行保存。若是用Tensorflow去實現神經網絡,所要保存的就是神經網絡中的各項權重值。建議能夠使用Saver類保存和加載模型的結果。

一、使用tf.train.Saver.save()方法保存模型網絡

tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True)dom

  • sess: 用於保存變量操做的會話。
  • save_path: String類型,用於指定訓練結果的保存路徑。
  • global_step: 若是提供的話,這個數字會添加到save_path後面,用於構建checkpoint文件。這個參數有助於咱們區分不一樣訓練階段的結果。

二、使用tf.train.Saver.restore方法價值模型函數

tf.train.Saver.restore(sess, save_path)this

  • sess: 用於加載變量操做的會話。
  • save_path: 同保存模型是用到的的save_path參數。

下面經過一個代碼演示這兩個函數的使用方法rest

import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b


loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if isTrain:
        for i in xrange(train_steps):
            sess.run(train, feed_dict={x: x_data})
            if (i + 1) % checkpoint_steps == 0:
                saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
    else:
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            pass
        print(sess.run(w))
        print(sess.run(b))
相關文章
相關標籤/搜索