當一個TensorFlow模型訓練出來的時候,爲了投入到實際應用,因此就須要部署到服務器上。因爲我本次所作的項目是一個javaweb的圖像識別項目。全部我就想去尋找一下java調用TensorFlow訓練模型的辦法。前端
因爲TensorFlow好久沒更新的緣故,網上的博客大都是18/19年的,而且是基於TensorFlow1.0的,對於如今使用的TensorFlow2.0不太友好。java
下面我簡述一下TensorFlow1.0時期的方法:python
須要將訓練的.h5模型轉換成.pb模型,而且須要本身定義.pb模型的輸入輸出參數。(pb模型是一種基於動態圖的模型)web
pb的生成代碼冗長、並且對初學者真滴不太友好json
相比之下.h5模型的生成代碼就一行flask
此外,這個生成pb模型的代碼是否能照搬使用,仍是一個問題,而且還可能報一些奇奇怪怪的錯誤。api
查閱資料發現java上的TensorFlow的jar包都是TensorFlow1.0的服務器
現狀:app
而且maven官網上的TensorFlow2.0的api已經更名成了tensorflow-core-api,而且網上相關方面的教程十分難找。因爲網上都是導入的1.0的包,本身導入2.0的包以後,詳細的調用教程能夠說是沒有。從上面也能夠看出來TensorFlow對java的調用也不怎麼重視了。因此這又給學習的途中徒增了不少困難。框架
用java直接調用訓練好的模型很困難,那麼咱們想辦法讓java調用python腳本,讓python腳本去調用.h5模型會不會更簡單呢?
代碼以下
package com.guard.service; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; public class api_service { public String recognize(String path){ //此處的path是圖片路徑 Process proc; String res = null; try { System.out.println("接受到的參數"+path); String[] cmd = new String[] { "python", "E:\\machine_learning\\predict.py", path}; proc = Runtime.getRuntime().exec(cmd); BufferedReader in = new BufferedReader(new InputStreamReader(proc.getInputStream())); String line = null; while ((line = in.readLine()) != null) { System.out.println(line); res = line; } in.close(); proc.waitFor(); } catch (IOException e) { e.printStackTrace(); } catch (InterruptedException e) { e.printStackTrace(); } System.out.println(res+">>>>>>>>>>>"); return res; } }
可是咱們能夠看出,這個實際上是用java在win上跑了這樣一個指令
雖然這個確實是一個好辦法,可是這個路徑參數須要事先知道服務器上的路徑,而且在協做開發的時候,每一個人的路徑和環境就不一樣,雖然該方法能用,可是我認爲還不夠好。
咱們能夠直接用python的flask框架,直接生成一個api接口,就能夠遠程直接調用TensorFlow訓練好的模型進行結果預測。
我的認爲,這種方法相較於用java調用命令行,這種方法仍是更加直觀的
而且flask僅僅須要加個@app.route的註解就能實現,可謂是十分方便
下面是模型調用代碼
model.py
import glob import sys import os import cv2 import numpy as np import tensorflow as tf import image_processing def model_ues(path): # 縮放圖片大小爲100*100 w = 100 h = 100 # 測試圖像的地址 (改成本身的) # path_test = "resource/test24.jpg" api_token = "fklasjfljasdlkfjlasjflasjfljhasdljflsdjflkjsadljfljsda" path_test = image_processing.download_img(path,api_token) # 建立保存圖像的空列表 imgs = [] img = cv2.imread(path_test) img = cv2.resize(img, (w, h)) # 將每張通過處理的圖像數據保存在以前建立的imgs空列表當中 imgs.append(img) imgs = np.asarray(imgs, np.float32) # print("shape of data:",imgs.shape) # 導入模型 model = tf.keras.models.load_model(r"resource/rice_0.93.h5") # 建立圖像標籤列表 rice_dict = {0: 'Rice blast', 1: 'Rice fleck', 2: 'Rice koji disease', 3: 'Sheath blight'} # 將圖像導入模型進行預測 prediction = model.predict_classes(imgs) # prediction = np.argmax(model.predict(imgs), axis=-1) # 繪製預測圖像 for i in range(np.size(prediction)): # 打印每張圖像的預測結果 print(rice_dict[prediction[i]]) return rice_dict[prediction[0]]
爲了實現圖片外連接受,下面是圖片下載腳本
image_processing.py
# coding: utf8 import requests import random def download_img(img_url, api_token): print (img_url) header = {"Authorization": "Bearer " + api_token} # 設置http header,視狀況加須要的條目,這裏的token是用來鑑權的一種方式 r = requests.get(img_url, headers=header, stream=True) print(r.status_code) # 返回狀態碼 file_img = 'resource/img.png' # file_img = 'resource/' print(file_img) if r.status_code == 200: open(file_img, 'wb').write(r.content) # 將內容寫入圖片 print("done") del r return file_img # if __name__ == '__main__': # # 下載要的圖片 # img_url = "https://z3.ax1x.com/2021/07/27/W5l6Qe.png" # api_token = "fklasjfljasdlkfjlasjflasjfljhasdljflsdjflkjsadljfljsda" # download_img(img_url, api_token)
主程序腳本
app.py
from flask import Flask,render_template, url_for, request, json,jsonify import model app = Flask(__name__) #設置編碼 app.config['JSON_AS_ASCII'] = False @app.route('/test') def hello_world(): return "hello world" @app.route('/predict', methods=['GET', 'POST']) def form_data(): my_path = request.form['path'] print(my_path) str = model.model_ues(my_path) print("http://127.0.0.1:5000/predict") return jsonify({'result':str,'msg':'200'}) if __name__ == '__main__': app.run()
雖然咱們可以經過postman進行測試接受到回傳的結果,可是咱們要怎麼用java實現呢??
1.使用postman生成大體代碼框架(postman生成的代碼可能不能直接運行)
這裏我選用的是java-okhttp的方法,但其實使用Unirest寫出來的代碼更加簡潔易懂。
public class Get_result { public String getResult(String path) throws IOException { // String path = "https://i.loli.net/2021/07/29/badDNR2OCironUf.jpg"; OkHttpClient client = new OkHttpClient().newBuilder() .build(); MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded"); RequestBody body = RequestBody.create(mediaType, "path="+path); Request request = new Request.Builder() .url("http://127.0.0.1:8000/predict") .method("POST", body) .addHeader("Content-Type", "application/x-www-form-urlencoded") .build(); Response response = client.newCall(request).execute(); String result = response.body().string(); System.out.println(result); } }
{ "msg": "200", "result": "Rice fleck" }
獲取到json數據以後,就須要對json數據進行解析
java上的解析原理是,先按照json編寫一個類,以後用Gson對接受到的數據按照這個類進行規範化
(這裏能夠用GsonFormatPlus插件來自動生成這個實體類)
//Rice_result.java---爲該json的實體類 package com.guard.tool; import lombok.Data; import lombok.NoArgsConstructor; @NoArgsConstructor @Data public class Rice_result { private String msg; private String result; }
下面是數據解析代碼(和上面的okhttp獲取json數據的代碼連起來看)
//json數據解析 Gson gson = new Gson(); java.lang.reflect.Type type = new TypeToken<Rice_result>(){}.getType(); Rice_result rice_result = gson.fromJson(result, type); System.out.println(rice_result); if("200".equals(rice_result.getMsg())){ // System.out.println(rice_result.getResult()); return Rice_result.convertdata(rice_result.getResult()); }else { // System.out.println("獲取結果出錯!!"); return "獲取結果出錯!!"; }
這樣的話就能夠進行json數據的解析了。
因爲須要使用java發送post請求給flask的預測端口,那麼就須要把本地上傳的數據作成圖鏈,把圖鏈做爲數據傳給flask的預測端口,從而來接收結果。
因爲前端js的知識大多遺忘,這裏就選用了用java來發送一個post請求,得到回傳的信息。
這裏我使用的是sm.ms的圖牀(該圖牀無需登陸,且速度快,算得上是一個好的選擇)
//sm.ms的使用方法,建議看官方文檔 package com.guard.tool; import com.google.gson.Gson; import com.google.gson.reflect.TypeToken; import okhttp3.*; import java.io.File; import java.io.IOException; public class CloudUpload { public String toUrl(String path) throws IOException { // String file_path = "E:/machine_learning/test8.jpg"; String file_path = path; OkHttpClient client = new OkHttpClient().newBuilder() .build(); MediaType mediaType = MediaType.parse("multipart/form-data"); RequestBody body = new MultipartBody.Builder().setType(MultipartBody.FORM) .addFormDataPart("smfile",file_path, RequestBody.create(MediaType.parse("application/octet-stream"), new File(file_path))) .addFormDataPart("format","json") .build(); Request request = new Request.Builder() .url("https://sm.ms/api/v2/upload") .method("POST", body) .addHeader("Content-Type", "multipart/form-data") .addHeader("Authorization", "TlxzRSaVJj0o7HFZOd9sgdf4Jl60RA00") //這裏的user-agent和Cookie須要本身打開網站,到網站的頁面去拿取 .addHeader("user-agent","Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36") .addHeader("Cookie", "SMMSrememberme=42417%3A10e8e9cb5281082b493fdee73381aeb2dca0bd3d; PHPSESSID=1gjog2em3ogof23vrqi79vd41m; SM_FC=runWNk3mPIiL8mzl%2FrlEfzM940LRKjLm182cm2qDrm4%3D") .build(); Response response = client.newCall(request).execute(); String result = response.body().string(); System.out.println(result); // String result = response.body().string(); Gson gson = new Gson(); java.lang.reflect.Type type = new TypeToken<Image_data>(){}.getType(); Image_data imge_data = gson.fromJson(result, type); System.out.println(imge_data); if (imge_data.getSuccess()){ System.out.println(imge_data.getData().getUrl()); return imge_data.getData().getUrl(); } else{ System.out.println("圖片已經上傳過一次!!"); System.out.println(imge_data.getImages()); return imge_data.getImages(); } } }
回傳的json結果--這個就須要使用上面的插件來進行處理
{ "success": true, "code": "success", "message": "Upload success.", "data": { "file_id": 0, "width": 192, "height": 454, "filename": "test25.jpg", "storename": "xICPNzFsfth5uJk.png", "size": 124993, "path": "/2021/08/01/xICPNzFsfth5uJk.png", "hash": "2exIdQGvBru46RKMyNjg3DhCTO", "url": "https://i.loli.net/2021/08/01/xICPNzFsfth5uJk.png", "delete": "https://sm.ms/delete/2exIdQGvBru46RKMyNjg3DhCTO", "page": "https://sm.ms/image/xICPNzFsfth5uJk" }, "RequestId": "9BFE9DEB-8370-44C8-A8AF-AAB2DB753A18" }
以上就是我此次在小組編寫<基於CNN圖像分類的水稻病蟲害識別>這個項目中的收穫。在此記錄下學習路上踩過的一些坑和一些解決方法。