【TensorFlow系列】【二】如何從ckpt文件中拷貝權值到新的變量中

在基於TensorFlow作fine-tuning或者遷移學習時,面臨的一個問題就是如何從已有的模型中,將其模型參數拷貝到自定義的新模型中。python

本文講述以下兩個問題:網絡

一、如何從ckpt模型文件中獲取權值的名字?session

二、如何將權值拷貝到新的變量中?app

 

具體見代碼註釋:函數

import tensorflow as tf

#從ckpt文件中獲取variable變量的名字
def get_trainable_variables_name_from_ckpt(meta_graph_path,ckpt_path):
    #定義一個新的graph
    graph = tf.Graph()
    #將其設置爲默認圖:
    with graph.as_default():
        with tf.Session() as session:
            #加載計算圖
            saver = tf.train.import_meta_graph(meta_graph_path)
            #加載模型到session中關聯的graph中,即將模型文件中的計算圖加載到這裏的graph中
            saver.restore(session,ckpt_path)
            v_names = []
            #獲取session所關聯的圖中可被訓練的variable
            #使用tf.trainable_variables()獲取variable時,只有在該函數前面定義的variable纔會被獲取到
            #在其後面定義不會被獲取到,
            for v in tf.trainable_variables():
                v_names.append(v)
            return v_names
#利用pywrap_tensorflow獲取ckpt文件中的全部變量,獲得的是variable名字與shape的一個map
from tensorflow.python import pywrap_tensorflow
def get_all_variables_name_from_ckpt(ckpt_path):
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    all_var = reader.get_variable_to_shape_map()
    #reader.get_variable_to_dtype_map()
    return all_var


#從cpkt文件中拷貝模型的參數到自定義的變量中
def copy_var_from_ckpt(session,dst_var_name,dst_var,ckpt_path,meta_graph_path):
    #定義一個新的graph
    graph = tf.Graph()
    #將其設置爲默認圖:
    with graph.as_default():
        with tf.Session() as sess:
            #加載計算圖
            saver = tf.train.import_meta_graph(meta_graph_path)
            #加載模型到session中關聯的graph中,即將模型文件中的計算圖加載到這裏的graph中
            saver.restore(sess,ckpt_path)
            v_names = []
            #獲取session所關聯的圖中可被訓練的variable
            #使用tf.trainable_variables()獲取variable時,只有在該函數前面定義的variable纔會被獲取到
            #在其後面定義不會被獲取到,
            for v in tf.trainable_variables():
                v_names.append(v)
            if dst_var_name in v_names:
                #獲取tensor
                tensor = graph.get_tensor_by_name(dst_var_name)
                #獲取tensor的值,即網絡中權值
                weight = sess.run(tensor)
                #拷貝權值,注意,須要使用dst_var所在的session
                #使用assign操做來拷貝dst_var是一個variable,weight是一個array
                session.run(dst_var.assign(weight))
相關文章
相關標籤/搜索