Excerpted from: https://blog.csdn.net/u011500062/article/details/51728830
1. Example
```python
import os

import numpy as np
# The example uses the TensorFlow 1.x graph API (placeholder, Session, Saver).
# tf.compat.v1 provides that API under TensorFlow 1.14+ and 2.x alike.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Target relationship to learn: y = 4x + 4
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

# Trainable parameters of the linear model y_predict = w * x + b
w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b

loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False  # set to True on the first run to train and save checkpoints
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = './checkpoint_dir/'

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if isTrain:
        # Saver.save fails if the parent directory does not exist.
        os.makedirs(checkpoint_dir, exist_ok=True)
        for i in range(train_steps):
            sess.run(train, feed_dict={x: x_data})
            if (i + 1) % checkpoint_steps == 0:
                saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i + 1)
    else:
        # Reads checkpoint_dir/checkpoint to find the most recent checkpoint.
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Restored successfully")
        else:
            print("No checkpoint found in", checkpoint_dir)
    print(sess.run(w))
    print(sess.run(b))
```
2. Output
3. Explanation
During training, the variables are saved once every checkpoint_steps steps, into the folder checkpoint_dir.
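The save schedule of the training loop can be traced in plain Python, without TensorFlow. The helper below is hypothetical (`checkpoint_prefixes` is not a TensorFlow function); it assumes, as TF 1.x's `Saver.save` does, that passing `global_step` appends `-<step>` to the given prefix.

```python
def checkpoint_prefixes(train_steps, checkpoint_steps, base_prefix):
    """List the checkpoint prefixes the example's training loop would write.

    Mirrors the loop above: a save happens whenever
    (i + 1) % checkpoint_steps == 0, with global_step=i + 1, and the
    step number is appended to the prefix as "-<step>".
    """
    prefixes = []
    for i in range(train_steps):
        if (i + 1) % checkpoint_steps == 0:
            prefixes.append("%s-%d" % (base_prefix, i + 1))
    return prefixes


if __name__ == "__main__":
    # Same settings as the example: 100 steps, a save every 50 steps.
    for prefix in checkpoint_prefixes(100, 50, "./checkpoint_dir/model.ckpt"):
        print(prefix)
```

With the article's settings this prints `./checkpoint_dir/model.ckpt-50` and `./checkpoint_dir/model.ckpt-100`. On disk each prefix typically corresponds to several files (such as `.index`, `.data-00000-of-00001` and `.meta`), plus one shared `checkpoint` file that records the latest save.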
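During testing, ckpt.model_checkpoint_path gives the location of the saved model. You do not need to supply the model's file name yourself: tf.train.get_checkpoint_state reads the checkpoint file in checkpoint_dir, finds out which checkpoint is the most recent and what it is called, and saver.restore then loads the variables from it.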
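The lookup described above can be sketched in plain Python. `latest_checkpoint_path` is a hypothetical, simplified stand-in for `tf.train.get_checkpoint_state(...).model_checkpoint_path`: it assumes the `checkpoint` file is in its usual text format (lines such as `model_checkpoint_path: "model.ckpt-100"`), reads only that one field, and does not handle escaped characters inside paths.

```python
import os
import re
import tempfile

# Matches the line naming the latest checkpoint; the anchored pattern does
# not match the "all_model_checkpoint_paths" lines.
_LATEST = re.compile(r'\s*model_checkpoint_path:\s*"(.*)"\s*$')


def latest_checkpoint_path(checkpoint_dir):
    """Return the path recorded as model_checkpoint_path, or None."""
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None  # nothing has been saved yet
    with open(state_file, encoding="utf-8") as f:
        for line in f:
            match = _LATEST.match(line)
            if match:
                path = match.group(1)
                # Relative entries are resolved against the checkpoint directory.
                if not os.path.isabs(path):
                    path = os.path.join(checkpoint_dir, path)
                return path
    return None


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        print(latest_checkpoint_path(d))  # None: no checkpoint file yet
        with open(os.path.join(d, "checkpoint"), "w", encoding="utf-8") as f:
            f.write('model_checkpoint_path: "model.ckpt-100"\n'
                    'all_model_checkpoint_paths: "model.ckpt-50"\n'
                    'all_model_checkpoint_paths: "model.ckpt-100"\n')
        print(latest_checkpoint_path(d))  # <d>/model.ckpt-100
```

TensorFlow itself parses this file as a `CheckpointState` protocol buffer, which also exposes `all_model_checkpoint_paths`; `tf.train.latest_checkpoint(checkpoint_dir)` returns the latest path directly.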