在用PMML實現機器學習模型的跨平臺上線中,咱們討論了使用PMML文件來實現跨平臺模型上線的方法,這個方法固然也適用於tensorflow生成的模型,可是因爲tensorflow模型每每較大,使用沒法優化的PMML文件大多數時候很笨拙,所以本文咱們專門討論下tensorflow機器學習模型的跨平臺上線的方法。html
tensorflow模型的跨平臺上線的備選方案通常有三種:即PMML方式,tensorflow serving方式,以及跨語言API方式。java
PMML方式的主要思路在上一篇以及講過。這裏惟一的區別是轉化生成PMML文件須要用一個Java庫jpmml-tensorflow來完成,生成PMML文件後,跨語言加載模型和其餘PMML模型文件基本相似。git
tensorflow serving是tensorflow 官方推薦的模型上線預測方式,它須要一個專門的tensorflow服務器,用來提供預測的API服務。若是你的模型和對應的應用是比較大規模的,那麼使用tensorflow serving是比較好的使用方式。可是它也有一個缺點,就是比較笨重,若是你要使用tensorflow serving,那麼須要本身搭建serving集羣並維護這個集羣。因此爲了一個小的應用去作這個工做,有時候會以爲麻煩。github
跨語言API方式是本文要討論的方式,它會用tensorflow本身的Python API生成模型文件,而後用tensorflow的客戶端庫好比Java或C++庫來作模型的在線預測。下面咱們會給一個生成生成模型文件並用tensorflow Java API來作在線預測的例子。算法
咱們這裏給一個簡單的邏輯迴歸並生成邏輯迴歸tensorflow模型文件的例子。服務器
完整代碼參見個人github:https://github.com/ljpzzz/machinelearning/blob/master/model-in-product/tensorflow-java機器學習
首先,咱們生成了一個6特徵,3分類輸出的4000個樣本數據。maven
import numpy as np import matplotlib.pyplot as plt %matplotlib inline from sklearn.datasets.samples_generator import make_classification import tensorflow as tf X1, y1 = make_classification(n_samples=4000, n_features=6, n_redundant=0, n_clusters_per_class=1, n_classes=3)
接着咱們構建tensorflow的數據流圖,這裏要注意裏面的兩個名字,第一個是輸入x的名字input,第二個是輸出prediction_labels的名字output,這裏的這兩個名字能夠本身取,可是後面會用到,因此要保持一致。post
learning_rate = 0.01 training_epochs = 600 batch_size = 100 x = tf.placeholder(tf.float32, [None, 6],name='input') # 6 features y = tf.placeholder(tf.float32, [None, 3]) # 3 classes W = tf.Variable(tf.zeros([6, 3])) b = tf.Variable(tf.zeros([3])) # softmax迴歸 pred = tf.nn.softmax(tf.matmul(x, W) + b, name="softmax") cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1)) optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) prediction_labels = tf.argmax(pred, axis=1, name="output") init = tf.global_variables_initializer()
接着就是訓練模型了,代碼比較簡單,畢竟只是一個演示:學習
sess = tf.Session() sess.run(init) y2 = tf.one_hot(y1, 3) y2 = sess.run(y2) for epoch in range(training_epochs): _, c = sess.run([optimizer, cost], feed_dict={x: X1, y: y2}) if (epoch+1) % 10 == 0: print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(c)) print ("優化完畢!") correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y2, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) acc = sess.run(accuracy, feed_dict={x: X1, y: y2}) print (acc)
打印輸出我這裏就不寫了,你們能夠本身去試一試。接着就是關鍵的一步,存模型文件了,注意要用convert_variables_to_constants這個API來保存模型,不然模型參數不會隨着模型圖一塊兒存下來。
graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"]) tf.train.write_graph(graph, '.', 'rf.pb', as_text=False)
至此,咱們的模型文件rf.pb已經被保存下來了,下面就是要跨平臺上線了。
這裏咱們以Java平臺的模型上線爲例,C++的API上線我沒有用過,這裏就不寫了。咱們須要引入tensorflow的java庫到咱們工程的maven或者gradle文件。這裏給出maven的依賴以下,版本能夠根據實際狀況選擇一個較新的版本。
<dependency> <groupId>org.tensorflow</groupId> <artifactId>tensorflow</artifactId> <version>1.7.0</version> </dependency>
接着就是代碼了,這個代碼會比JPMML的要簡單,我給出了4個測試樣本的預測例子以下,必定要注意的是裏面的input和output要和訓練模型的時候對應的節點名字一致。
import org.tensorflow.*; import org.tensorflow.Graph; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; /** * Created by 劉建平pinard on 2018/7/1. */ public class TFjavaDemo { public static void main(String args[]){ byte[] graphDef = loadTensorflowModel("D:/rf.pb"); float inputs[][] = new float[4][6]; for(int i = 0; i< 4; i++){ for(int j =0; j< 6;j++){ if(i<2) { inputs[i][j] = 2 * i - 5 * j - 6; } else{ inputs[i][j] = 2 * i + 5 * j - 6; } } } Tensor<Float> input = covertArrayToTensor(inputs); Graph g = new Graph(); g.importGraphDef(graphDef); Session s = new Session(g); Tensor result = s.runner().feed("input", input).fetch("output").run().get(0); long[] rshape = result.shape(); int rs = (int) rshape[0]; long realResult[] = new long[rs]; result.copyTo(realResult); for(long a: realResult ) { System.out.println(a); } } static private byte[] loadTensorflowModel(String path){ try { return Files.readAllBytes(Paths.get(path)); } catch (IOException e) { e.printStackTrace(); } return null; } static private Tensor<Float> covertArrayToTensor(float inputs[][]){ return Tensors.create(inputs); } }
個人預測輸出是1,1,0,0,供你們參考。
對於tensorflow來講,模型上線通常選擇tensorflow serving或者client API庫來上線,前者適合於較大的模型和應用場景,後者則適合中小型的模型和應用場景。所以算法工程師使用在產品以前須要作好選擇和評估。
(歡迎轉載,轉載請註明出處。歡迎溝通交流: liujianping-ok@163.com)