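In TensorFlow 1.x (the version used here), tf.train.Saver is used to save only the model's variables (tf.Variable), not the whole model. The same variables therefore have to be defined again in code before their values can be restored.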
Code for saving the model
import os
import tensorflow as tf
import numpy as np

# Save to file
# remember to define the same dtype and shape when restoring
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

# tf.initialize_all_variables() is no longer valid from
# 2017-03-02 if using tensorflow >= 0.12
if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
    init = tf.initialize_all_variables()
else:
    init = tf.global_variables_initializer()

saver = tf.train.Saver()

# make sure the target directory exists before saving
os.makedirs("save_model", exist_ok=True)

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "save_model/save_pp.ckpt")
    print("Save to path: ", save_path)
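The saver stores variable values keyed by variable name. It does not store the arithmetic `result = v1 + v2`. The sketch below illustrates that idea in plain Python with JSON. It is an analogy, not TensorFlow's binary checkpoint format, and the function `save_variables` and its table layout are hypothetical.

```python
import json
import os
import tempfile


def save_variables(variables, path):
    """Write a name -> (shape, values) table to `path` as JSON.

    Hypothetical stand-in for a checkpoint writer: only each variable's
    name, dtype, shape and values are stored. Derived tensors such as
    result = v1 + v2 are not stored; they are recomputed from the graph.
    """
    table = {}
    for name, (shape, values) in variables.items():
        # dtype is assumed to be float32, matching tf.constant(1.0, ...)
        table[name] = {"dtype": "float32", "shape": list(shape), "values": list(values)}
    with open(path, "w") as f:
        json.dump(table, f, sort_keys=True)
    return path


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "save_pp.json")
    save_variables({"v1": ([1], [1.0]), "v2": ([1], [2.0])}, path)
    with open(path) as f:
        print(sorted(json.load(f)))  # ['v1', 'v2']: keys are the variable names
```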
The save writes its files into the save_model directory.
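TF 1.x uses the V2 checkpoint format by default. A single-shard save to the prefix `save_model/save_pp.ckpt` should therefore leave four files in that directory:

- `checkpoint`: a small text file pointing at the latest save.
- `save_pp.ckpt.index`
- `save_pp.ckpt.data-00000-of-00001`
- `save_pp.ckpt.meta`: the serialized graph, written by default.

The helpers below are a hypothetical sketch, not a TensorFlow API. They let you check for these files before restoring. The sketch assumes one shard and the default `write_meta_graph=True`.

```python
import os


def expected_checkpoint_files(prefix):
    """Hypothetical helper: files a single-shard TF 1.x V2 save is assumed to write."""
    directory, base = os.path.split(prefix)
    names = ["checkpoint", base + ".index", base + ".meta",
             base + ".data-00000-of-00001"]
    return [os.path.join(directory, n) for n in sorted(names)]


def missing_checkpoint_files(prefix):
    """Return the expected files not found on disk (empty list = looks complete)."""
    return [p for p in expected_checkpoint_files(prefix) if not os.path.exists(p)]


if __name__ == "__main__":
    for path in expected_checkpoint_files("save_model/save_pp.ckpt"):
        print(path, "(missing)" if not os.path.exists(path) else "")
```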
Code for restoring the model
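################################################
import tensorflow as tf

# restore variables
# start from an empty graph so the variables are created as 'v1' and 'v2'
# (in the same script as the save code they would otherwise become 'v1_1', 'v2_1')
tf.reset_default_graph()

# redefine the same shape and same type for your variables
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

# no init step needed: restore() assigns the saved values
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, "./save_model/save_pp.ckpt")
    print("v:", sess.run(v1))
    print("result:", sess.run(result))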
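The restore code asks for the same names, shapes and dtypes as the save code. The pure-Python sketch below suggests why. It is an analogy, not TensorFlow's implementation, and `restore_variables` and the dict layout are hypothetical.

Restoring looks each defined variable up by name and copies its values. Two things can go wrong:

- A renamed variable is not found. For example, if `v1` is defined twice in one graph, TensorFlow creates the second one as `v1_1`.
- A variable whose shape or dtype changed cannot be filled.

```python
def restore_variables(checkpoint, definitions):
    """Hypothetical restore-by-name, sketching why names, shapes and dtypes must match.

    checkpoint:  name -> {"dtype": str, "shape": list, "values": list}
    definitions: name -> (dtype, shape) of the variables redefined before restoring
    """
    restored = {}
    for name, (dtype, shape) in definitions.items():
        if name not in checkpoint:
            # e.g. a variable that was created as 'v1_1' instead of 'v1'
            raise KeyError("variable %r not found in checkpoint" % name)
        entry = checkpoint[name]
        if entry["dtype"] != dtype or entry["shape"] != list(shape):
            raise ValueError("variable %r: checkpoint has %s %s, definition has %s %s"
                             % (name, entry["dtype"], entry["shape"], dtype, list(shape)))
        restored[name] = list(entry["values"])
    return restored


if __name__ == "__main__":
    ckpt = {
        "v1": {"dtype": "float32", "shape": [1], "values": [1.0]},
        "v2": {"dtype": "float32", "shape": [1], "values": [2.0]},
    }
    v = restore_variables(ckpt, {"v1": ("float32", [1]), "v2": ("float32", [1])})
    # result is not in the checkpoint; it is recomputed from the restored values
    result = [a + b for a, b in zip(v["v1"], v["v2"])]
    print("v:", v["v1"])        # v: [1.0]
    print("result:", result)    # result: [3.0]
```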
Error message

Not yet resolved.