在基於TensorFlow作fine-tuning或者遷移學習時,面臨的一個問題就是如何從已有的模型中,將其模型參數拷貝到自定義的新模型中。python
本文講述以下兩個問題:網絡
一、如何從ckpt模型文件中獲取權值的名字?session
二、如何將權值拷貝到新的變量中?app
具體見代碼註釋:函數
import tensorflow as tf #從ckpt文件中獲取variable變量的名字 def get_trainable_variables_name_from_ckpt(meta_graph_path,ckpt_path): #定義一個新的graph graph = tf.Graph() #將其設置爲默認圖: with graph.as_default(): with tf.Session() as session: #加載計算圖 saver = tf.train.import_meta_graph(meta_graph_path) #加載模型到session中關聯的graph中,即將模型文件中的計算圖加載到這裏的graph中 saver.restore(session,ckpt_path) v_names = [] #獲取session所關聯的圖中可被訓練的variable #使用tf.trainable_variables()獲取variable時,只有在該函數前面定義的variable纔會被獲取到 #在其後面定義不會被獲取到, for v in tf.trainable_variables(): v_names.append(v) return v_names #利用pywrap_tensorflow獲取ckpt文件中的全部變量,獲得的是variable名字與shape的一個map from tensorflow.python import pywrap_tensorflow def get_all_variables_name_from_ckpt(ckpt_path): reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path) all_var = reader.get_variable_to_shape_map() #reader.get_variable_to_dtype_map() return all_var #從cpkt文件中拷貝模型的參數到自定義的變量中 def copy_var_from_ckpt(session,dst_var_name,dst_var,ckpt_path,meta_graph_path): #定義一個新的graph graph = tf.Graph() #將其設置爲默認圖: with graph.as_default(): with tf.Session() as sess: #加載計算圖 saver = tf.train.import_meta_graph(meta_graph_path) #加載模型到session中關聯的graph中,即將模型文件中的計算圖加載到這裏的graph中 saver.restore(sess,ckpt_path) v_names = [] #獲取session所關聯的圖中可被訓練的variable #使用tf.trainable_variables()獲取variable時,只有在該函數前面定義的variable纔會被獲取到 #在其後面定義不會被獲取到, for v in tf.trainable_variables(): v_names.append(v) if dst_var_name in v_names: #獲取tensor tensor = graph.get_tensor_by_name(dst_var_name) #獲取tensor的值,即網絡中權值 weight = sess.run(tensor) #拷貝權值,注意,須要使用dst_var所在的session #使用assign操做來拷貝dst_var是一個variable,weight是一個array session.run(dst_var.assign(weight))