TensorFlow 同時調用多個預訓練好的模型

在某些任務中，我們需要針對不同的情況訓練多個不同的神經網絡模型。這時候，在測試階段，我們就需要調用多個預訓練好的模型分別進行預測。

弄明白了如何調用單個模型，其實調用多個模型也就順理成章了。我們只需要建立多個圖（graph），然後每個圖導入一個模型，再針對每個圖創建一個會話（session），分別進行預測即可。

# Run two independently pre-trained models side by side with the TF 1.x
# graph/session API: each model lives in its own tf.Graph and is driven by
# its own tf.Session. Call load_model() and load_model_2() once, then use
# test_1() / test_2() for predictions.
import os

import numpy as np
import tensorflow as tf

# Checkpoint directories of the two models. Each must contain META_FILE plus
# the checkpoint files located by tf.train.latest_checkpoint().
MODEL_PATH_1 = r'F:/resnet/model/new0.25-0.35/'
MODEL_PATH_2 = r'F:/resnet/model/new0.25-0.352/'
META_FILE = 'model-10.meta'

# One graph per model: both imported networks use the same tensor names
# ("X", "tst", "tanh"), so they must not share a graph.
g1 = tf.Graph()
g2 = tf.Graph()

# One session per graph; a session can only run ops from the graph it is
# bound to.
sess1 = tf.Session(graph=g1)
sess2 = tf.Session(graph=g2)

# Input, flag and output tensors of each model; set by load_model*().
X_1 = None
tst_1 = None
yhat_1 = None
X_2 = None
tst_2 = None
yhat_2 = None


def _restore(sess, modelpath):
    """Import the meta graph from *modelpath* into ``sess.graph`` and restore weights.

    Args:
        sess: Session whose graph receives the imported model.
        modelpath: Directory containing META_FILE and the checkpoint files.

    Returns:
        The ``(X, tst, yhat)`` tensors, looked up by the names
        "X:0", "tst:0" and "tanh:0".

    Raises:
        ValueError: if *modelpath* holds no checkpoint.
    """
    with sess.as_default(), sess.graph.as_default():
        saver = tf.train.import_meta_graph(os.path.join(modelpath, META_FILE))
        checkpoint = tf.train.latest_checkpoint(modelpath)
        if checkpoint is None:
            # latest_checkpoint() returns None rather than raising; fail with
            # a message that names the offending directory.
            raise ValueError('No checkpoint found in %r' % (modelpath,))
        saver.restore(sess, checkpoint)
        graph = sess.graph
        return (graph.get_tensor_by_name("X:0"),
                graph.get_tensor_by_name("tst:0"),
                graph.get_tensor_by_name("tanh:0"))


def _predict(sess, X, tst, yhat, txtdata):
    """Feed *txtdata* to one model and return its output as a column list.

    *txtdata* is converted to a NumPy array and reshaped to
    (-1, 41, 41, 41, 3), so it may hold one or more examples.
    The fetched ``yhat`` is reshaped to (-1, 1) and returned as nested lists.
    """
    data = np.array(txtdata).reshape(-1, 41, 41, 41, 3)
    # NOTE(review): "tst" is fed True; presumably a test/inference-mode flag
    # (e.g. for batch norm or dropout) -- confirm against the training graph.
    output = sess.run(yhat, feed_dict={X: data, tst: True})
    return output.reshape(-1, 1).tolist()


def load_model(sess=None, modelpath=MODEL_PATH_1):
    """Load the first pre-trained model and its parameters into ``sess1``.

    Args:
        sess: Unused. Kept only so existing ``load_model(sess)`` calls keep
            working. The model is always restored into ``sess1``, because
            ``test_1`` runs against ``sess1``.
        modelpath: Checkpoint directory (defaults to MODEL_PATH_1).
    """
    global X_1, tst_1, yhat_1
    X_1, tst_1, yhat_1 = _restore(sess1, modelpath)
    print('Successfully load the model_1!')


def load_model_2(modelpath=MODEL_PATH_2):
    """Load the second pre-trained model and its parameters into ``sess2``.

    Args:
        modelpath: Checkpoint directory (defaults to MODEL_PATH_2).
    """
    global X_2, tst_2, yhat_2
    X_2, tst_2, yhat_2 = _restore(sess2, modelpath)
    print('Successfully load the model_2!')


def test_1(txtdata):
    """Predict with model 1 (loaded by ``load_model``).

    Args:
        txtdata: Array-like holding 41*41*41*3 values per example. The
            original note reads "Array in C"; presumably the data is passed
            in from C code.

    Returns:
        The model output flattened to a list of single-element lists; per the
        original author, the normal of a face.

    Raises:
        RuntimeError: if ``load_model`` has not been called yet.
    """
    if yhat_1 is None:
        raise RuntimeError('model_1 is not loaded; call load_model() first')
    return _predict(sess1, X_1, tst_1, yhat_1, txtdata)


def test_2(txtdata):
    """Predict with model 2 (loaded by ``load_model_2``).

    Args:
        txtdata: Array-like holding 41*41*41*3 values per example. The
            original note reads "Array in C"; presumably the data is passed
            in from C code.

    Returns:
        The model output flattened to a list of single-element lists; per the
        original author, the normal of a face.

    Raises:
        RuntimeError: if ``load_model_2`` has not been called yet.
    """
    if yhat_2 is None:
        raise RuntimeError('model_2 is not loaded; call load_model_2() first')
    return _predict(sess2, X_2, tst_2, yhat_2, txtdata)

最後，本程序只是為了說明問題，拋磚引玉，代碼有很多冗餘之處，不要模仿！

獲取更多精彩,請關注「seniusen」!
seniusen