[原]數據科學教程: 如何使用 mlflow 管理數據科學工做流

背景

近年來,人工智能與數據科學領域發展迅速,傳統項目在演化中也愈來愈複雜了,如何管理大量的機器學習項目成爲一個難題。html

在真正的機器學習項目中,咱們須要在模型以外花費大量的時間。好比:python

  • 跟蹤實驗效果

機器學習算法有可配置的超參一般都是十幾個到幾十個不等,如何跟蹤這些參數、代碼以及數據在每一個實驗中的表現目前業界也沒有一個統一的解決方案,更多都是根據某個實驗進行單獨的開發。git

  • 部署ML模型

部署ML模型一般都須要將模型文件和線上環境Service/Spark Job/SDK(Java/Scala/C++)對接,而大部分數據科學家一般都不太熟悉這些工程開發語言。所以,將模型遷移到不一樣平臺是具備挑戰性的,它意味着數據科學家還須要考慮線上部署的性能問題,目前業界也缺乏比較通用的模型部署工具。github

目前,在大廠內部已經孵化出這樣的一些機器學習平臺,好比 Uber 的 Michelangelo、Google 的 TFX,可是他們都與大廠的基礎架構深度耦合,因此也沒有在開源社區流行起來。算法

在這樣的背景下, mlflow 框架橫空出世,它的出現旨在將效果追蹤、模型調參、模型訓練、模型代碼、模型發佈等模塊集中一處,提高數據科學工做流的管理效率。數據庫

clipboard.png

簡介

mlflow 將數據科學工做流分爲3個部分:json

  1. 模型追蹤:支持記錄和查詢實驗周圍的數據,如評估指標和參數
  2. 項目管理:如何將模型封裝在 pipeline 中,以便與可重複執行
  3. 模型發佈:管理模型部署並提供 RestFul API

clipboard.png

模型追蹤:

mlflow tracking 提供了一個入口,用於將機器學習的參數、代碼版本、代碼路徑、評估指標等統一管理,輸出到系統中可視化管理。一般咱們模型會迭代不少次,這樣每次輸出的結果就能夠集中對比效果的好壞。segmentfault

好比:api

library(mlflow)

# 記錄超參
my_int <- mlflow_param("my_int", 1, "integer")
my_num <- mlflow_param("my_num", 1.0, "numeric")

# 記錄指標
mlflow_log_metric("accuracy", 0.45)

# 記錄輸出文件(模型、feature importance圖)等
mlflow_log_atrifact("roc.png")
mlflow_log_artifact("model.pkl")

項目管理

clipboard.png

mlflow project 提供了打包可重用數據科學代碼的標準格式,項目經過本地文件/git管理代碼,經過 yaml 文件來描述。瀏覽器

name: FinanceR Project
conda_env: conda.yaml
entry_points:
main:
parameters:
data_file: path
regularization: {type: double, default: 0.1}
command: "python train.py -r {regularization} {data_file}"
validate:
parameters:
data_file: path
command: "python validate.py {data_file}"

codna 將提供統一的虛擬環境服務,經過 mlflow run 能夠任意執行項目的 pipeline

mlflow run example/project -P num_dimensions=5

mlflow run git@github.com:xxx/xxx.git -P num_dimensions=5

下面舉一個官網的具體例子:

舉例

初始化

devtools::install_github("mlflow/mlflow", subdir = "mlflow/R/mlflow")
mlflow::mlflow_install()

模型參數

# Sample R code showing logging model parameters
library(mlflow)

# Define parameters
my_int <- mlflow_param("my_int", 1, "integer")
my_num <- mlflow_param("my_num", 1.0, "numeric")
my_str <- mlflow_param("my_str", "a", "string")

# Log parameters
mlflow_log_param("param_int", my_int)
mlflow_log_param("param_num", my_num)
mlflow_log_param("param_str", my_str)

模型訓練

# Sample R code training a linear model
library(mlflow)

# Read parameters
column <- mlflow_log_param("column", 1)

# Log total rows
mlflow_log_metric("rows", nrow(iris))

# Train model
model <- lm(Sepal.Width ~ iris[[column]], iris)

# Log models intercept
mlflow_log_metric("intercept", model$coefficients[["(Intercept)"]])

線上實驗

library(mlflow)
# Create and activate the 「R-Test」 experiment
mlflow_create_experiment("R-Test")

mlflow_active_run()

啓動界面

mlflow_ui()

默認須要在瀏覽器中訪問 localhost:5000

clipboard.png

添加註釋

超參調優

clipboard.png

超參調優支持3種模式:

  • Random: 徹底隨機探索策略
  • Gpyopt: 基於高斯過程的探索策略
  • Hyperopt: 基於數據庫的分佈式探索方法
mlflow run -e random --experiment-id <hyperparam_experiment_id>  -P \
    training_experiment_id=<individual_runs_experiment_id> examples/r_wine --entry-point train.R

其中 train.R 爲

library(mlflow)

# read parameters
column <- mlflow_log_param("column", 1)

# log total rows
mlflow_log_metric("rows", nrow(iris))

# train model
model <- lm(
  Sepal.Width ~ x,
  data.frame(Sepal.Width = iris$Sepal.Width, x = iris[,column])
)

# log models intercept
mlflow_log_metric("intercept", model$coefficients[["(Intercept)"]])

# save model
mlflow_save_model(
  crate(~ stats::predict(model, .x), model)
)

模型部署

mlflow rfunc serve model

模型推斷

mlflow_rfunc_predict("model", data = data.frame(x = c(0.3, 0.2)))
## Warning in mlflow_snapshot_warning(): Running without restoring the
## packages snapshot may not reload the model correctly. Consider running
## 'mlflow_restore_snapshot()' or setting the 'restore' parameter to 'TRUE'.

## 3.400381396714573.40656987651099

##        1        2 
## 3.400381 3.406570

或者在命令行中調用

mlflow rfunc predict model data.json

總結

mlflow 的出現極大方便了煉丹師傅們的工做,提供了堪比 michelangelo 的用戶體驗,而且全面支持 sklearn、spark、pytorch、tensorflow、mxnet、mlr、xgboost、keras 等主流算法框架。更多 mlflow 的詳細資料能夠參見官方文檔

參考資料

做爲分享主義者(sharism),本人全部互聯網發佈的圖文均聽從CC版權,轉載請保留做者信息並註明做者 Harry Zhu 的 FinanceR專欄: https://segmentfault.com/blog...,若是涉及源代碼請註明GitHub地址: https://github.com/harryprince。微信號: harryzhustudio 商業使用請聯繫做者。
相關文章
相關標籤/搜索