tensorflow中一種融合多個模型的方法

1.使用場景

假設咱們有訓練好的模型A,B,C,咱們但願使用A,B,C中的部分或者所有變量,合成爲一個模型D,用於初始化或其餘目的,就須要融合多個模型的方法dom

 

2.如何實現

咱們能夠先聲明模型D,再建立多個Saver實例,分別從模型A,B,C的保存文件(checkpoint文件)中讀取所需的變量值,來達成這一目的,下面是示例代碼:spa

首先建立一個只包含w1,w2兩個變量的模型,初始化後保存:rest

 1 def train_model1():
 2     w1 = tf.get_variable("w1", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 3     w2 = tf.get_variable("w2", shape=[3, 1], initializer=tf.truncated_normal_initializer(), trainable=True)
 4     x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
 5     a1 = tf.matmul(x, w1)
 6     input = np.random.rand(3200, 3)
 7     sess = tf.InteractiveSession()
 8     sess.run(tf.global_variables_initializer())
 9     saver1 = tf.train.Saver([w1,w2])
10     for i in range(0, 1):
11         w1_var,w2_var = sess.run([w1,w2], feed_dict={x: input[i * 32:(i + 1) * 32]})
12         print w1_var
13         print w2_var
14         print '=' * 30
15     saver1.save(sess, 'save1-exp')

而後再建立一個只包含w2,w3兩個變量的模型,也是初始化後保存:code

 1 def train_model2():
 2     w2 = tf.get_variable("w2", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 3     w3 = tf.get_variable("w3", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 4     x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
 5     a2 = tf.matmul(x, w2 * w3)
 6     input = np.random.rand(3200, 3)
 7     sess = tf.InteractiveSession()
 8     sess.run(tf.global_variables_initializer())
 9     saver2 = tf.train.Saver([w2,w3])
10     for i in range(0, 1):
11         w2_var, w3_var = sess.run([w2, w3], feed_dict={x: input[i * 32:(i + 1) * 32]})
12         print w2_var
13         print w3_var
14         print '=' * 30
15     saver2.save(sess, 'save2-exp')

最後咱們建立一個包含w1,w2,w3變量的模型,從上面兩個保存的ckp文件中恢復:orm

 1 def restore_model():
 2     w1 = tf.get_variable("w1", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 3     w2 = tf.get_variable("w2", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 4     w3 = tf.get_variable("w3", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 5     x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
 6     a1 = tf.matmul(x, w1)
 7     a2 = tf.matmul(x, w2 * w3)
 8     loss = tf.reduce_mean(tf.square(a1 - a2))
 9     sess = tf.InteractiveSession()
10     sess.run(tf.global_variables_initializer())
11     saver1 = tf.train.Saver([w1,w2])
12     saver1.restore(sess, 'save1-exp')
13     saver2 = tf.train.Saver([w2, w3])
14     saver2.restore(sess, 'save2-exp')
15     saver3 = tf.train.Saver(tf.trainable_variables())
16     input = np.random.rand(3200, 3)
17     w1_var, w2_var, w3_var = sess.run([w1, w2, w3], feed_dict={x: input[0:32]})
18     print w1_var
19     print w2_var
20     print w3_var
21     print '=' * 30
22     saver3.save(sess, 'save3-exp')

而後保存,即完成了咱們的目標blog

 

3.注意事項

3.1 取的模型中有同名變量

假設同名變量爲a,這種狀況下,從不一樣模型中恢復的a是按照讀取順序覆蓋到a中的,若是但願只讀取特定ckpt保存的變量值,在建立讀取其餘ckpt的saver時,不要把a加入到var_list中get

3.2 模型D中有部分變量不在A,B,C中

這種狀況,恢復時會報錯,須要指定var_list,只恢復當前cpkt中保存的變量input

相關文章
相關標籤/搜索