Tensorflow系列——Saver的用法

摘抄自:https://blog.csdn.net/u011500062/article/details/51728830/dom

一、實例測試

 1 import tensorflow as tf
 2 import numpy as np
 3 
 4 x = tf.placeholder(tf.float32, shape=[None, 1])
 5 y = 4 * x + 4
 6 
 7 w = tf.Variable(tf.random_normal([1], -1, 1))
 8 b = tf.Variable(tf.zeros([1]))
 9 y_predict = w * x + b
10 
11 loss = tf.reduce_mean(tf.square(y - y_predict))
12 optimizer = tf.train.GradientDescentOptimizer(0.5)
13 train = optimizer.minimize(loss)
14 
15 isTrain = False
16 train_steps = 100
17 checkpoint_steps = 50
18 checkpoint_dir = './checkpoint_dir/'
19 
20 saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
21 x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
22 
23 with tf.Session() as sess:
24     sess.run(tf.initialize_all_variables())
25     if isTrain:
26         for i in range(train_steps):
27             sess.run(train, feed_dict={x: x_data})
28             if (i + 1) % checkpoint_steps == 0:
29                 saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i + 1) 30     else:
31         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
33         if ckpt and ckpt.model_checkpoint_path: 34  saver.restore(sess, ckpt.model_checkpoint_path) 35             print("Restore Sucessfully")
36         else:
37             pass
38         print(sess.run(w))
39         print(sess.run(b))

二、運行結果this

 

 

三、解釋spa

訓練階段,每通過checkpoint_steps 步保存一次變量,保存的文件夾爲checkpoint_dir.net

測試階段,ckpt.model_checkpoint_path:表示模型存儲的位置,不須要提供模型的名字,它會去查看checkpoint文件,看看最新的是誰,叫作什麼,而後載入變量3d

相關文章
相關標籤/搜索