用PMML實現機器學習模型的跨平臺上線

    在機器學習用於產品的時候,咱們常常會遇到跨平臺的問題。好比咱們用Python基於一系列的機器學習庫訓練了一個模型,可是有時候其餘的產品和項目想把這個模型集成進去,可是這些產品不少只支持某些特定的生產環境好比Java,爲了上一個機器學習模型去大動干戈修改環境配置很不划算,此時咱們就能夠考慮用預測模型標記語言(Predictive Model Markup Language,如下簡稱PMML)來實現跨平臺的機器學習模型部署了。java

1. PMML概述

    PMML是數據挖掘的一種通用的規範,它用統一的XML格式來描述咱們生成的機器學習模型。這樣不管你的模型是sklearn,R仍是Spark MLlib生成的,咱們均可以將其轉化爲標準的XML格式來存儲。當咱們須要將這個PMML的模型用於部署的時候,可使用目標環境的解析PMML模型的庫來加載模型,並作預測。node

    能夠看出,要使用PMML,須要兩步的工做,第一塊是將離線訓練獲得的模型轉化爲PMML模型文件,第二塊是將PMML模型文件載入在線預測環境,進行預測。這兩塊都須要相關的庫支持。python

2. PMML模型的生成和加載相關類庫

    PMML模型的生成相關的庫須要看咱們使用的離線訓練庫。若是咱們使用的是sklearn,那麼可使用sklearn2pmml這個python庫來作模型文件的生成,這個庫安裝很簡單,使用"pip install sklearn2pmml"便可,相關的使用咱們後面會有一個demo。若是使用的是Spark MLlib, 這個庫有一些模型已經自帶了保存PMML模型的方法,惋惜並不全。若是是R,則須要安裝包"XML"和「PMML」。此外,JAVA庫JPMML能夠用來生成R,SparkMLlib,xgBoost,Sklearn的模型對應的PMML文件。github地址是:https://github.com/jpmml/jpmml。git

    加載PMML模型須要目標環境支持PMML加載的庫,若是是JAVA,則能夠用JPMML來加載PMML模型文件。相關的使用咱們後面會有一個demo。github

3. PMML模型生成和加載示例

    下面咱們給一個示例,使用sklearn生成一個決策樹模型,用sklearn2pmml生成模型文件,用JPMML加載模型文件,並作預測。算法

    完整代碼參見個人github:https://github.com/ljpzzz/machinelearning/blob/master/model-in-product/sklearn-jpmml數組

    首先是用用sklearn生成一個決策樹模型,因爲咱們是須要保存PMML文件,因此最好把模型先放到一個Pipeline數組裏面。這個數組裏面除了咱們的決策樹模型之外,還能夠有歸一化,降維等預處理操做,這裏做爲一個示例,咱們Pipeline數組裏面只有決策樹模型。代碼以下:less

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
from sklearn import tree
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn2pmml import sklearn2pmml

import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files/Java/jdk1.8.0_171/bin'

X=[[1,2,3,1],[2,4,1,5],[7,8,3,6],[4,8,4,7],[2,5,6,9]]
y=[0,1,0,2,1]
pipeline = PMMLPipeline([("classifier", tree.DecisionTreeClassifier(random_state=9))]);
pipeline.fit(X,y)

sklearn2pmml(pipeline, ".\demo.pmml", with_repr = True)

    上面這段代碼作了一個很是簡單的決策樹分類模型,只有5個訓練樣本,特徵有4個,輸出類別有3個。實際應用時,咱們須要將模型調參完畢後纔將其放入PMMLPipeline進行保存。運行代碼後,咱們在當前目錄會獲得一個PMML的XML文件,能夠直接打開看,內容大概以下:dom

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
    <Header>
        <Application name="JPMML-SkLearn" version="1.5.3"/>
        <Timestamp>2018-06-24T05:47:17Z</Timestamp>
    </Header>
    <MiningBuildTask>
        <Extension>PMMLPipeline(steps=[('classifier', DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=9,
            splitter='best'))])</Extension>
    </MiningBuildTask>
    <DataDictionary>
        <DataField name="y" optype="categorical" dataType="integer">
            <Value value="0"/>
            <Value value="1"/>
            <Value value="2"/>
        </DataField>
        <DataField name="x3" optype="continuous" dataType="float"/>
        <DataField name="x4" optype="continuous" dataType="float"/>
    </DataDictionary>
    <TransformationDictionary>
        <DerivedField name="double(x3)" optype="continuous" dataType="double">
            <FieldRef field="x3"/>
        </DerivedField>
        <DerivedField name="double(x4)" optype="continuous" dataType="double">
            <FieldRef field="x4"/>
        </DerivedField>
    </TransformationDictionary>
    <TreeModel functionName="classification" missingValueStrategy="nullPrediction" splitCharacteristic="multiSplit">
        <MiningSchema>
            <MiningField name="y" usageType="target"/>
            <MiningField name="x3"/>
            <MiningField name="x4"/>
        </MiningSchema>
        <Output>
            <OutputField name="probability(0)" optype="continuous" dataType="double" feature="probability" value="0"/>
            <OutputField name="probability(1)" optype="continuous" dataType="double" feature="probability" value="1"/>
            <OutputField name="probability(2)" optype="continuous" dataType="double" feature="probability" value="2"/>
        </Output>
        <Node>
            <True/>
            <Node>
                <SimplePredicate field="double(x3)" operator="lessOrEqual" value="3.5"/>
                <Node score="1" recordCount="1.0">
                    <SimplePredicate field="double(x3)" operator="lessOrEqual" value="2.0"/>
                    <ScoreDistribution value="0" recordCount="0.0"/>
                    <ScoreDistribution value="1" recordCount="1.0"/>
                    <ScoreDistribution value="2" recordCount="0.0"/>
                </Node>
                <Node score="0" recordCount="2.0">
                    <True/>
                    <ScoreDistribution value="0" recordCount="2.0"/>
                    <ScoreDistribution value="1" recordCount="0.0"/>
                    <ScoreDistribution value="2" recordCount="0.0"/>
                </Node>
            </Node>
            <Node score="2" recordCount="1.0">
                <SimplePredicate field="double(x4)" operator="lessOrEqual" value="8.0"/>
                <ScoreDistribution value="0" recordCount="0.0"/>
                <ScoreDistribution value="1" recordCount="0.0"/>
                <ScoreDistribution value="2" recordCount="1.0"/>
            </Node>
            <Node score="1" recordCount="1.0">
                <True/>
                <ScoreDistribution value="0" recordCount="0.0"/>
                <ScoreDistribution value="1" recordCount="1.0"/>
                <ScoreDistribution value="2" recordCount="0.0"/>
            </Node>
        </Node>
    </TreeModel>
</PMML>

    能夠看到裏面就是決策樹模型的樹結構節點的各個參數,以及輸入值。咱們的輸入被定義爲x1-x4,輸出定義爲y。機器學習

    有了PMML模型文件,咱們就能夠寫JAVA代碼來讀取加載這個模型並作預測了。

    咱們建立一個Maven或者gradle工程,加入JPMML的依賴,這裏給出maven在pom.xml的依賴,gradle的結構是相似的。

    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>pmml-evaluator</artifactId>
        <version>1.4.1</version>
    </dependency>
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>pmml-evaluator-extension</artifactId>
        <version>1.4.1</version>
    </dependency>

    接着就是讀取模型文件並預測的代碼了,具體代碼以下:

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Created by 劉建平Pinard on 2018/6/24.
 */
public class PMMLDemo {
    private Evaluator loadPmml(){
        PMML pmml = new PMML();
        InputStream inputStream = null;
        try {
            inputStream = new FileInputStream("D:/demo.pmml");
        } catch (IOException e) {
            e.printStackTrace();
        }
        if(inputStream == null){
            return null;
        }
        InputStream is = inputStream;
        try {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        } catch (SAXException e1) {
            e1.printStackTrace();
        } catch (JAXBException e1) {
            e1.printStackTrace();
        }finally {
            //關閉輸入流
            try {
                is.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
        pmml = null;
        return evaluator;
    }
    private int predict(Evaluator evaluator,int a, int b, int c, int d) {
        Map<String, Integer> data = new HashMap<String, Integer>();
        data.put("x1", a);
        data.put("x2", b);
        data.put("x3", c);
        data.put("x4", d);
        List<InputField> inputFields = evaluator.getInputFields();
        //過模型的原始特徵,從畫像中獲取數據,做爲模型輸入
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);
        List<TargetField> targetFields = evaluator.getTargetFields();

        TargetField targetField = targetFields.get(0);
        FieldName targetFieldName = targetField.getName();

        Object targetFieldValue = results.get(targetFieldName);
        System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue);
        int primitiveValue = -1;
        if (targetFieldValue instanceof Computable) {
            Computable computable = (Computable) targetFieldValue;
            primitiveValue = (Integer)computable.getResult();
        }
        System.out.println(a + " " + b + " " + c + " " + d + ":" + primitiveValue);
        return primitiveValue;
    }
    public static void main(String args[]){
        PMMLDemo demo = new PMMLDemo();
        Evaluator model = demo.loadPmml();
        demo.predict(model,1,8,99,1);
        demo.predict(model,111,89,9,11);

    }
}

    代碼裏有兩個函數,第一個loadPmml是加載模型的,第二個predict是讀取預測樣本並返回預測值的。個人代碼運行結果以下:

target: y value: {result=2, probability_entries=[0=0.0, 1=0.0, 2=1.0], entityId=5, confidence_entries=[]}
1 8 99 1:2
target: y value: {result=1, probability_entries=[0=0.0, 1=1.0, 2=0.0], entityId=6, confidence_entries=[]}
111 89 9 11:1

    也就是樣本(1,8,99,1)被預測爲類別2,而(111,89,9,11)被預測爲類別1。

    以上就是PMML生成和加載的一個示例,使用起來其實門檻並不高,也很簡單。

4. PMML總結與思考

    PMML的確是跨平臺的利器,可是是否是就沒有缺點呢?確定是有的!

    第一個就是PMML爲了知足跨平臺,犧牲了不少平臺獨有的優化,因此不少時候咱們用算法庫本身的保存模型的API獲得的模型文件,要比生成的PMML模型文件小不少。同時PMML文件加載速度也比算法庫本身獨有格式的模型文件加載慢不少。

    第二個就是PMML加載獲得的模型和算法庫本身獨有的模型相比,預測會有一點點的誤差,固然這個誤差並不大。好比某一個樣本,用sklearn的決策樹模型預測爲類別1,可是若是咱們把這個決策樹落盤爲一個PMML文件,並用JAVA加載後,繼續預測剛纔這個樣本,有較小的機率出現預測的結果不爲類別1.

    第三個就是對於超大模型,好比大規模的集成學習模型,好比xgboost, 隨機森林,或者tensorflow,生成的PMML文件很容易獲得幾個G,甚至上T,這時使用PMML文件加載預測速度會很是慢,此時推薦爲模型創建一個專有的環境,就沒有必要去考慮跨平臺了。

    此外,對於TensorFlow,不推薦使用PMML的方式來跨平臺。可能的方法一是TensorFlow serving,本身搭建預測服務,可是會稍有些複雜。另外一個方法就是將模型保存爲TensorFlow的模型文件,並用TensorFlow獨有的JAVA庫加載來作預測。

    咱們在下一篇會討論用python+tensorflow訓練保存模型,並用tensorflow的JAVA庫加載作預測的方法和實例。

 

(歡迎轉載,轉載請註明出處。歡迎溝通交流: liujianping-ok@163.com)  

相關文章
相關標籤/搜索