機器學習在全球範圍內愈來愈受歡迎和使用。 它已經完全改變了某些應用程序的構建方式,而且可能會繼續成爲咱們平常生活中一個巨大的(而且正在增長的)部分。沒有什麼包裝且機器學習並不簡單。 它對許多人來講彷佛很是複雜並經常使人生畏。像谷歌這樣的公司將本身的機器學習概念與開發人員聯繫起來,在谷歌幫助下讓他們逐漸邁出第一步,故TensorFlow的框架誕生了。java
TensorFlow是由谷歌使用Python和C++開發的開源機器學習框架。它能夠幫助開發人員輕鬆獲取數據,準備和訓練模型,預測將來狀態,以及執行大規模機器學習。有了它,咱們能夠訓練和運行深度神經網絡的內容,諸如光學字符識別,圖像識別/分類,天然語言處理等。數組
TensorFlow基於計算圖,你能夠將其想象爲具備節點和邊的經典圖。每一個節點被稱爲操做,它們將零個或多個張量輸入併產生零個或多個張量輸出。 操做能夠很是簡單,例如基本的添加,但它們也能夠很是複雜。張量被描繪爲圖的邊緣,而且是核心數據單元。 當咱們將它們提供給操做時,咱們在這些張量上執行不一樣的功能。 它們能夠具備單個或多個維度,有時也稱爲它們的等級(標量:等級0,向量:等級1,矩陣:等級2)。這些數據受到操做的影響經過張量傳遞到計算圖中,故而稱爲TensorFlow。張量能夠以任意數量的維度存儲數據,而且有三種主要類型的張量:佔位符,變量和常量。安全
使用Maven,安裝TensorFlow就像包含依賴項同樣簡單:服務器
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.13.1</version>
</dependency>複製代碼
若是你的設備支持GPU功能,能夠添加如下依賴:網絡
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow</artifactId>
<version>1.13.1</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni_gpu</artifactId>
<version>1.13.1</version>
</dependency>複製代碼
你可使用TensorFlow對象來檢查當前操做的TensorFlow的版本。session
System.out.println(TensorFlow.version());複製代碼
Java API TensorFlow提供包含在org.tensorflow包中。 它目前是實驗性的,所以不能保證其穩定性。須要注意的是TensorFlow惟一徹底支持的語言是Python,Java API幾乎沒有什麼功能。API向咱們介紹了新的類,接口,枚舉和異常。框架
經過API引入的新類是:機器學習
若是咱們將全部這些與Python中的tf模塊進行比較將發現存在明顯的區別。 Java API沒有幾乎相同的功能,至少目前如此。分佈式
如前所述,TensorFlow基於計算圖 - 其中org.tensorflow.Graph是Java的實現。注意:它的實例是線程安全的,儘管咱們須要在完成它以後顯式釋放Graph使用的資源。函數
讓咱們從一個空圖開始:
Graph graph = new Graph();複製代碼
該對象是空的,因此這個圖表意義不大。 要對它作任何操做,咱們首先須要使用Operations加載它。咱們使用opBuilder()方法來加載它,它返回一個OperationBuilder對象,一旦咱們調用.build()方法,它就會將操做添加到咱們的圖形中。
讓咱們在圖表中添加一個常量:
Operation x = graph.opBuilder("Const", "x")
.setAttr("dtype", DataType.FLOAT)
.setAttr("value", Tensor.create(3.0f))
.build(); 複製代碼
佔位符是變量的「類型」,聲明時沒有賦值,他們的值將在以後進行分配。 這容許咱們使用沒有任何實際數據的操做來構建圖形:
Operation y = graph.opBuilder("Placeholder", "y")
.setAttr("dtype", DataType.FLOAT)
.build();複製代碼
最後爲了解決這個問題,咱們須要添加某些函數。 這些能夠像乘法,除法或加法同樣簡單,也能夠像矩陣乘法同樣複雜。 和以前同樣,咱們使用.opBuilder()方法定義函數:
Operation xy = graph.opBuilder("Mul", "xy")
.addInput(x.output(0))
.addInput(y.output(0))
.build(); 複製代碼
注意:咱們使用input(0)做爲張量能夠有多個輸出。
遺憾的是,Java API尚未包含任何容許像Python中同樣可視化圖形的工具。
如前所述,Session是Graph的驅動程序。 它封裝了執行Operation和Graph計算張量(tensors)的環境。這意味着咱們構建的圖(graph)中的張量(tensors)實際上並無任何值,由於咱們沒有在會話(session)中運行圖形(graph)。咱們首先將圖表添加到會話(session)中:
Session session = new Session(graph);複製代碼
咱們的操做知識簡單地將x於y相乘,爲了運行咱們的圖(graph)並獲得計算結果,咱們須要使用fetch()獲取到xy的操做併爲其提供x和y的值:
Tensor tensor = session.runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0);
System.out.println(tensor.floatValue());複製代碼
運行這段代碼將產生的結果以下:
10.0f複製代碼
這可能聽起來有點奇怪,但因爲Python是惟一受到良好支持的語言,所以Java API仍然沒有保存模型的功能。這意味着Java API僅用於服務用例,至少在TensorFlow徹底支持以前。 目前至少咱們可使用SavedModelBundle類在Python中訓練和保存模型,而後使用Java加載它們來爲它們提供服務:
SavedModelBundle model = SavedModelBundle.load("./model", "serve");
Tensor tensor = model.session().runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0);
System.out.println(tensor.floatValue());複製代碼
TensorFlow是一個功能強大且普遍使用的框架。 它不斷獲得改進,並最近被引入新語言:包括Java和JavaScript。儘管Java API尚未像TensorFlow在Python中那麼多的功能,但它仍然能夠做爲向Java開發人員介紹TensorFlow的一個很好的開始。
原文連接:https://stackabuse.com/how-to-use-tensorflow-with-java/
做 者:David Landup
譯 者:klein
------
9月福利,關注公衆號後臺回覆:004,領取8月翻譯集錦!往期福利回覆:001,002, 003便可領取!