tensorflow與流計算結合使用:tf+java+flink

tensorflow與java結合

【導讀】<br /> 隨着TensorFlow的普及,愈來愈多的行業但願將Github中大量已有的TensorFlow代碼和模型集成到本身的業務系統中,如何在常見的編程語言(Java、NodeJS等)中使用TensorFlow成爲了一個比較常見的問題。專知成員Hujun給你們詳細介紹了在Java中使用TensorFlow的兩種方法,並着重介紹如何用TensorFlow官方Java API調用已有TensorFlow模型的方法。java

1.Java調用TensorFlow的方法:<br />

  • 使用Java調用TensorFlow大體有兩種方法:

1. 直接使用TensorFlow官方API調用訓練好的pb模型node

2. (推薦) 使用KerasServer託管TensorFlow/Keras代碼及模型:python

雖然使用TensorFlow官方Java API能夠直接對接訓練好的pb模型,但在實際使用中,依然存在着與跨語種對接相關的繁瑣代碼。例如雖然已有使用Python編寫好的基於TensorFlow的文本分類代碼,但TensorFlow Java API的輸入須要是量化的文本,這樣咱們又須要用Java從新實如今Python代碼中已經實現的分詞、從字符串到索引的轉換等預處理操做(這些操做同時依賴於Python代碼依賴的單詞表等數據)。另外,因爲Java沒有numpy支持,在構建多維數組做爲輸入時,使用的依然是相似循環的操做,很是繁瑣。c++

KerasServer支持restful交互,所以能夠支持用任何程序語言調用TensorFlow/ Keras。因爲KerasServer的服務端提供Python API, 所以能夠直接將已有的TensorFlow/Keras Python代碼和模型轉換爲KerasServer API,供Java/c/c++/C#/ Python/ NodeJS/Browser Javascript等調用,而不須要再其餘語種中進行繁瑣的數據預處理操做。git

例如,Java可直接將須要分類的文本數據提交給KerasServer,KerasServer可利用已有的Python代碼對字符串進行分詞、預處理等操做。github

本教程介紹如何用TensorFlow官方Java API調用TensorFlow(Python)訓練好的模型。教程的代碼可在專知的Github項目中找到: https://github.com/ZhuanZhiCode/TensorFlow-Java-Examplesapache

2. tensorflow模型保存及java調用該模型

tensorflow模型及其保存

  • 首先:本地須要安裝好tensorflow相關的軟件。
  • 而後編寫模型代碼
#coding=utf-8
import tensorflow as tf


# 定義圖
x = tf.placeholder(tf.float32, name="x")
y = tf.get_variable("y", initializer=10.0)
z = tf.log(x + y, name="z")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # 進行一些訓練代碼,此處省略
    # xxxxxxxxxxxx

    # 顯示圖中的節點
    print([n.name for n in sess.graph.as_graph_def().node])
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names=["z"])

    # 保存圖爲pb文件
    with open('model.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())

通過上面的代碼,pb模型成功保存了。接下來將使用java將該模型加載並運行起來編程

java加載模型

此處使用的是maven項目,因此須要先將tensorflow依賴加載上去api

<dependency>
   <groupId>org.tensorflow</groupId>
   <artifactId>tensorflow</artifactId>
   <version>1.5.0</version>
</dependency>
....

模型的執行與Python相似,依然是導入圖,創建Session,指定輸入(feed)和輸出(fetch)。數組

import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;
import java.io.IOException;

public class DemoImportGraph {

    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            //導入圖
            byte[] graphBytes = IOUtils.toByteArray(new 
            FileInputStream("model.pb"));
            graph.importGraphDef(graphBytes);

            //根據圖創建Session
            try(Session session = new Session(graph)){
                //至關於TensorFlow Python中的sess.run(z, 
feed_dict = {'x': 10.0})
                float z = session.runner()
                //此處的'x'是模型的輸入;'z'是模型的輸出
                        .feed("x", Tensor.create(10.0f))
                        .fetch("z").run().get(0).floatValue();
                System.out.println("z = " + z);
            }
        }

    }
}

例子2(該模型有兩個輸入)

#tensorflow模型:

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

path = './model/'
with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    # 這裏的輸出須要加上name屬性
    op = tf.add(xy, b, name='op_to_store')

    sess.run(tf.global_variables_initializer())

    # convert_variables_to_constants 須要指定output_node_names,list(),能夠多個
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])

    # 測試 OP
    feed_dict = {x: 10, y: 3}
    print(sess.run(op, feed_dict))

    # 寫入序列化的 PB 文件
    with tf.gfile.FastGFile(path+'model_3.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())
//java代碼以下:
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;
import java.io.IOException;

/**
 * Created on 2019-07-03
 *
 * @author :hao.li
 */
public class reload_3 {
    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            //導入圖
            byte[] graphBytes = IOUtils.toByteArray(new FileInputStream("/Users/lixuewei/workspace/private/tensorflow-java/src/main/resources/model_3.pb"));
            graph.importGraphDef(graphBytes);

            //根據圖創建Session
            try(Session session = new Session(graph)){
                //至關於TensorFlow Python中的sess.run(z, feed_dict = {'x': 10.0})
                Tensor<?> tensor = session.runner()
                        .feed("x", Tensor.create(10))
                        .feed("y", Tensor.create(3))
                        .fetch("op_to_store").run().get(0);
                System.out.println(tensor.intValue());
            }
        }

    }
}

3. TensorFlow serving加載的模型格式在java中直接加載

  1. python代碼
import tensorflow as  tf
import  numpy as np
import  os
tf.app.flags.DEFINE_integer('training_iteration', 302,
'number of training iterations.')
tf.app.flags.DEFINE_integer('model_version', 1, 'version number of the model.')
tf.app.flags.DEFINE_string('work_dir', 'model/', 'Working directory.')
FLAGS = tf.app.flags.FLAGS
sess = tf.InteractiveSession()
x = tf.placeholder('float', shape=[None, 3],name="x")
y_ = tf.placeholder('float', shape=[None, 1])
w = tf.get_variable('w', shape=[3, 1], initializer=tf.truncated_normal_initializer)
b = tf.get_variable('b', shape=[1], initializer=tf.zeros_initializer)
 
sess.run(tf.global_variables_initializer())
 
y = tf.add(tf.matmul(x, w) , b,name="y")
 
ms_loss = tf.reduce_mean((y - y_) ** 2)
 
train_step = tf.train.GradientDescentOptimizer(0.005).minimize(ms_loss)
 
train_x = np.random.randn(1000, 3)
# let the model learn the equation of y = x1 * 1 + x2 * 2 + x3 * 3
train_y = np.sum(train_x * np.array([1, 2, 3]) + np.random.randn(1000, 3) / 100, axis=1).reshape(-1, 1)
 
train_loss = []
for i in range(FLAGS.training_iteration):
    loss, _ = sess.run([ms_loss, train_step], feed_dict={x: train_x, y_: train_y})
    train_loss.append(loss)
export_path_base = FLAGS.work_dir
export_path = os.path.join(
    tf.compat.as_bytes(export_path_base),
    tf.compat.as_bytes(str(FLAGS.model_version)))
print('Exporting trained model to', export_path)
# SavedModelBuilder裏面放的是保存模型的路徑,以下的export_path
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
 
tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
 
prediction_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'input': tensor_info_x},
        outputs={'output': tensor_info_y},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
 
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
#第二步必須要有,它是給你的模型貼上一個標籤,這樣再次調用的時候就能夠根據標籤來找。我給它起的標籤名是"serve"
builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        'prediction':
            prediction_signature,
    },
    legacy_init_op=legacy_init_op)
 
builder.save()
print('Training error %g' % loss)
 
print('Done exporting!')
 
print('Done training!')
  1. java代碼: TensorflowUtils:
import org.tensorflow.SavedModelBundle;
public class TensorflowUtils {
    public  static SavedModelBundle  loadmodel(String modelpath){
         SavedModelBundle bundle=SavedModelBundle.load(modelpath,"serve");
         return bundle;
    }
}

Main:

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import java.util.Arrays;
 
public class Model {
 
    SavedModelBundle bundle = null;
    public void init(){
        String  classpath=this.getClass().getResource("/").getPath()+"1" ;
        bundle=TensorflowUtils.loadmodel(classpath);
    }
 
    public  double  getResult(float[][] arr){
        Tensor  tensor=Tensor.create(arr);
        Tensor<?>  result= bundle.session().runner().feed("x",tensor).fetch("y").run().get(0);
        float[][] resultValues = (float[][])result.copyTo(new float[1][1]);
        result.close();
        return resultValues[0][0];
    }
    public static void main(String[] args){
        Model  model =new Model();
        model.init();
        float[][] arr=new  float[1][3];
        arr[0][0]=1f;
        arr[0][1]=0.5f;
        arr[0][2]=2.0f;
        System.out.println(model.getResult(arr));
        System.out.println(Arrays.toString("他".getBytes()));
        }
 }

github

踩坑

Cannot find TensorFlow native library for OS: darwin, architecture: x86_64.<br /> result:<br />

Flink實時調用tensorflow模型

具體使用見代碼:github

該flink與tensorflow相結合的方式,有如下幾個問題:

  1. 當併發數>1,須要考慮如何控制數據按照時間順序進入到多個task中進行預測?
  2. ....

參考

  1. 【乾貨】使用TensorFlow官方Java API調用TensorFlow模型(附代碼)
  2. Java調用Keras、Tensorflow模型
  3. tensorflow模型部署系列————單機java部署(附代碼)
  4. convert_variables_to_constants
  5. TensorFlow serving加載的模型格式在java中直接加載
  6. 在java中調用訓練好的TensorFlow模型
相關文章
相關標籤/搜索