[Introduction] As TensorFlow becomes more widespread, more and more industries want to integrate the large amount of existing TensorFlow code and models on GitHub into their own business systems, so how to use TensorFlow from common programming languages (Java, NodeJS, etc.) has become a frequent question. Zhuanzhi member Hujun gives a detailed introduction to two ways of using TensorFlow from Java, focusing on how to call an existing TensorFlow model with the official TensorFlow Java API.
1. Call a trained pb model directly with the official TensorFlow Java API
2. (Recommended) Use KerasServer to host TensorFlow/Keras code and models
Although the official TensorFlow Java API can work directly with a trained pb model, in practice a lot of tedious cross-language glue code is still needed. For example, even when we already have Python text-classification code built on TensorFlow, the TensorFlow Java API expects numericized (vectorized) text as input, so the preprocessing that already exists in the Python code, such as tokenization and the string-to-index conversion, has to be re-implemented in Java (and these steps also depend on data such as the vocabulary that the Python code relies on). In addition, since Java has no numpy support, building multi-dimensional arrays as input still requires loop-style code, which is very tedious.
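As a purely illustrative sketch of that last point: padding a few tokenized sentences into a fixed-size batch, which numpy handles in one line, needs explicit loops in Java. The token indices and padding length below are made up, the point is only the amount of boilerplate involved.

import java.util.Arrays;
import java.util.List;

public class BuildInputByHand {
    public static void main(String[] args) {
        // Hypothetical token-index sequences produced by some tokenizer
        List<int[]> sequences = Arrays.asList(
                new int[]{12, 7, 301},
                new int[]{5, 88});
        int maxLen = 4; // assumed fixed padding length
        int[][] batch = new int[sequences.size()][maxLen];
        for (int i = 0; i < sequences.size(); i++) {
            int[] seq = sequences.get(i);
            for (int j = 0; j < Math.min(seq.length, maxLen); j++) {
                batch[i][j] = seq[j];
            }
            // remaining positions stay 0 and act as padding
        }
        System.out.println(Arrays.deepToString(batch));
    }
}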
KerasServer supports RESTful interaction, so TensorFlow/Keras can be called from any programming language. Because the KerasServer server side provides a Python API, existing TensorFlow/Keras Python code and models can be turned directly into a KerasServer API and called from Java/C/C++/C#/Python/NodeJS/browser JavaScript, without redoing the tedious data preprocessing in another language.
For example, Java can submit the text to be classified directly to KerasServer, and KerasServer can use the existing Python code to tokenize and preprocess the strings.
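The snippet below is only a sketch of what such a call could look like from the Java side; the URL, port and JSON payload are hypothetical, since the actual request format depends on the service you define with KerasServer's Python API.

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class KerasServerClientSketch {
    public static void main(String[] args) throws Exception {
        // Hypothetical endpoint and JSON body; the server-side Python code
        // is responsible for tokenization and preprocessing of the raw text.
        String body = "{\"text\": \"the raw text to classify\"}";
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/predict"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.body());
    }
}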
This tutorial shows how to use the official TensorFlow Java API to call a model trained with TensorFlow (Python). The code for the tutorial can be found in Zhuanzhi's GitHub project: https://github.com/ZhuanZhiCode/TensorFlow-Java-Examples
#coding=utf-8
import tensorflow as tf

# Define the graph
x = tf.placeholder(tf.float32, name="x")
y = tf.get_variable("y", initializer=10.0)
z = tf.log(x + y, name="z")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Training code would go here; omitted
    # xxxxxxxxxxxx

    # Show the nodes in the graph
    print([n.name for n in sess.graph.as_graph_def().node])
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names=["z"])

    # Save the graph as a pb file
    with open('model.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
With the code above, the pb model has been saved successfully. Next we will load and run this model from Java.
This is a Maven project, so the TensorFlow dependency needs to be added first.
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.5.0</version>
</dependency>
....
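Note that the Java examples below also use IOUtils from Apache Commons IO to read the pb file, so a commons-io dependency (not shown in the snippet above) needs to be on the classpath as well.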
Running the model is similar to Python: import the graph, create a Session, and specify the inputs (feed) and outputs (fetch).
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;
import java.io.IOException;

public class DemoImportGraph {

    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            // Import the graph
            byte[] graphBytes = IOUtils.toByteArray(new FileInputStream("model.pb"));
            graph.importGraphDef(graphBytes);

            // Create a Session from the graph
            try (Session session = new Session(graph)) {
                // Equivalent to sess.run(z, feed_dict={'x': 10.0}) in TensorFlow Python
                float z = session.runner()
                        // 'x' is the model input; 'z' is the model output
                        .feed("x", Tensor.create(10.0f))
                        .fetch("z").run().get(0).floatValue();
                System.out.println("z = " + z);
            }
        }
    }
}
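If the names to pass to feed() and fetch() are unknown, the node names of the frozen graph can also be listed from the Java side. The sketch below assumes the same model.pb and uses the operations() iterator of the Graph class from the 1.x Java API to print every operation name.

import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Operation;

import java.io.FileInputStream;
import java.io.IOException;
import java.util.Iterator;

public class ListGraphNodes {
    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            byte[] graphBytes = IOUtils.toByteArray(new FileInputStream("model.pb"));
            graph.importGraphDef(graphBytes);
            // Print every operation name in the frozen graph, so we know
            // which names can be used for feed()/fetch().
            Iterator<Operation> ops = graph.operations();
            while (ops.hasNext()) {
                System.out.println(ops.next().name());
            }
        }
    }
}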
# TensorFlow model:
import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

path = './model/'

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    # The output here needs a name attribute
    op = tf.add(xy, b, name='op_to_store')

    sess.run(tf.global_variables_initializer())

    # convert_variables_to_constants requires output_node_names, a list(); multiple names are allowed
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])

    # Test the op
    feed_dict = {x: 10, y: 3}
    print(sess.run(op, feed_dict))

    # Write the serialized PB file
    with tf.gfile.FastGFile(path + 'model_3.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())
// The Java code is as follows:
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;
import java.io.IOException;

/**
 * Created on 2019-07-03
 *
 * @author :hao.li
 */
public class reload_3 {

    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            // Import the graph
            byte[] graphBytes = IOUtils.toByteArray(new FileInputStream("/Users/lixuewei/workspace/private/tensorflow-java/src/main/resources/model_3.pb"));
            graph.importGraphDef(graphBytes);

            // Create a Session from the graph
            try (Session session = new Session(graph)) {
                // Equivalent to sess.run(op, feed_dict={x: 10, y: 3}) in TensorFlow Python
                Tensor<?> tensor = session.runner()
                        .feed("x", Tensor.create(10))
                        .feed("y", Tensor.create(3))
                        .fetch("op_to_store").run().get(0);
                System.out.println(tensor.intValue());
            }
        }
    }
}
import tensorflow as tf
import numpy as np
import os

tf.app.flags.DEFINE_integer('training_iteration', 302,
                            'number of training iterations.')
tf.app.flags.DEFINE_integer('model_version', 1, 'version number of the model.')
tf.app.flags.DEFINE_string('work_dir', 'model/', 'Working directory.')
FLAGS = tf.app.flags.FLAGS

sess = tf.InteractiveSession()

x = tf.placeholder('float', shape=[None, 3], name="x")
y_ = tf.placeholder('float', shape=[None, 1])
w = tf.get_variable('w', shape=[3, 1], initializer=tf.truncated_normal_initializer)
b = tf.get_variable('b', shape=[1], initializer=tf.zeros_initializer)
sess.run(tf.global_variables_initializer())
y = tf.add(tf.matmul(x, w), b, name="y")

ms_loss = tf.reduce_mean((y - y_) ** 2)
train_step = tf.train.GradientDescentOptimizer(0.005).minimize(ms_loss)

train_x = np.random.randn(1000, 3)
# let the model learn the equation of y = x1 * 1 + x2 * 2 + x3 * 3
train_y = np.sum(train_x * np.array([1, 2, 3]) + np.random.randn(1000, 3) / 100, axis=1).reshape(-1, 1)

train_loss = []
for i in range(FLAGS.training_iteration):
    loss, _ = sess.run([ms_loss, train_step], feed_dict={x: train_x, y_: train_y})
    train_loss.append(loss)

export_path_base = FLAGS.work_dir
export_path = os.path.join(
    tf.compat.as_bytes(export_path_base),
    tf.compat.as_bytes(str(FLAGS.model_version)))
print('Exporting trained model to', export_path)

# SavedModelBuilder takes the path where the model will be saved, export_path below
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'input': tensor_info_x},
        outputs={'output': tensor_info_y},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

# This second step is required: it attaches a tag to the model so that the model
# can be looked up by that tag when it is loaded later. The tag used here is "serve".
builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        'prediction': prediction_signature,
    },
    legacy_init_op=legacy_init_op)

builder.save()
print('Training error %g' % loss)
print('Done exporting!')
print('Done training!')
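To double-check what was actually exported, TensorFlow's saved_model_cli tool can be used, for example "saved_model_cli show --dir model/1 --all"; it prints the tags, the signature names (here 'prediction') and the underlying tensor names. Note that the Java code below feeds and fetches by the graph tensor names "x" and "y", not by the signature keys 'input'/'output'.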
import org.tensorflow.SavedModelBundle;

public class TensorflowUtils {

    public static SavedModelBundle loadmodel(String modelpath) {
        SavedModelBundle bundle = SavedModelBundle.load(modelpath, "serve");
        return bundle;
    }
}
Main:
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

import java.util.Arrays;

public class Model {

    SavedModelBundle bundle = null;

    public void init() {
        String classpath = this.getClass().getResource("/").getPath() + "1";
        bundle = TensorflowUtils.loadmodel(classpath);
    }

    public double getResult(float[][] arr) {
        Tensor tensor = Tensor.create(arr);
        Tensor<?> result = bundle.session().runner().feed("x", tensor).fetch("y").run().get(0);
        float[][] resultValues = (float[][]) result.copyTo(new float[1][1]);
        result.close();
        return resultValues[0][0];
    }

    public static void main(String[] args) {
        Model model = new Model();
        model.init();
        float[][] arr = new float[1][3];
        arr[0][0] = 1f;
        arr[0][1] = 0.5f;
        arr[0][2] = 2.0f;
        System.out.println(model.getResult(arr));
        System.out.println(Arrays.toString("他".getBytes()));
    }
}
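A note on resources: in the 1.x Java API, SavedModelBundle and Tensor hold native memory and implement AutoCloseable, so it is worth closing them explicitly. Below is a minimal sketch of the same call using try-with-resources; it assumes the exported model lives under model/1 relative to the working directory.

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

public class ModelWithCleanup {

    public static void main(String[] args) {
        float[][] arr = {{1f, 0.5f, 2.0f}};
        // The bundle and the tensors are backed by native memory,
        // so try-with-resources releases them deterministically.
        try (SavedModelBundle bundle = SavedModelBundle.load("model/1", "serve");
             Tensor<?> input = Tensor.create(arr);
             Tensor<?> result = bundle.session().runner()
                     .feed("x", input)
                     .fetch("y")
                     .run().get(0)) {
            float[][] values = (float[][]) result.copyTo(new float[1][1]);
            // With the training data above (y ≈ x1*1 + x2*2 + x3*3),
            // this should print a value close to 8.0
            System.out.println(values[0][0]);
        }
    }
}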
Cannot find TensorFlow native library for OS: darwin, architecture: x86_64.
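This error usually means that the TensorFlow JNI native library for the current platform could not be found on the classpath. With the 1.x Maven artifacts it is typically resolved by depending on org.tensorflow:tensorflow (which pulls in libtensorflow and libtensorflow_jni) or by adding org.tensorflow:libtensorflow_jni explicitly, in a version that ships a darwin x86_64 build.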
result:
See the code for the detailed usage:
This way of combining Flink with TensorFlow has the following problems: