以前寫過一篇TensorFlow Java 環境的搭建 TensorFlow Java+eclipse下環境搭建,今天看看TensorFlow Java API 的簡單說明 和操做。算法
由 Google 開源,是一個深度學習庫, 是一套使用數據流圖 (data flow graphics)進行數據計算的軟件庫(software library) 和應用接口(API),並以此做爲基礎加上其它功能的庫和開發工具成爲一套進行機器學習、特別是深度學習(deep learning)的應用程序開發框架 (framework)。 ---------------谷歌開發技術推广部 大中華區主管 欒躍 (Bill Luan)數組
支持CNN、RNN和LSTM算法,是目前在 Image,NLP (神經語言學)最流行的深度神經網絡模型。bash
基於Python,寫的很快而且具備可讀性。網絡
在多GPU系統上的運行更爲順暢。session
代碼編譯效率較高。框架
社區發展的很是迅速而且活躍。eclipse
可以生成顯示網絡拓撲結構和性能的可視化圖機器學習
TensorFlow是用數據流圖(data flow graphs)技術來進行數值計算的工具
邊:用於傳送節點之間的多維數組,即張量( tensor )post
節點:表示數學運算操做符 用operation表示,簡稱op
public class HelloTF {
public static void main(String[] args) throws Exception {
try (Graph g = new Graph(); Session s = new Session(g)) {
// 使用佔位符構造一個圖,添加兩個浮點型的張量
Output x = g.opBuilder("Placeholder", "x").setAttr("dtype", DataType.FLOAT).build().output(0);//建立一個OP
Output y = g.opBuilder("Placeholder", "y").setAttr("dtype", DataType.FLOAT).build().output(0);
Output z = g.opBuilder("Add", "z").addInput(x).addInput(y).build().output(0);
System.out.println( " z= " + z);
// 屢次執行,每次使用不一樣的x和y值
float[] X = new float[] { 1, 2, 3 };
float[] Y = new float[] { 4, 5, 6 };
for (int i = 0; i < X.length; i++) {
try (Tensor tx = Tensor.create(X[i]);
Tensor ty = Tensor.create(Y[i]);
Tensor tz = s.runner().feed("x", tx).feed("y", ty).fetch("z").run().get(0)) {
System.out.println(X[i] + " + " + Y[i] + " = " + tz.floatValue());
}
}
}
Graph graph = new Graph();
Tensor tensor = Tensor.create(2);
Tensor tensor2 = tensor.create(3);
Output output = graph.opBuilder("Const", "mx").setAttr("dtype", tensor.dataType()).setAttr("value", tensor).build().output(0);
Output output2 = graph.opBuilder("Const", "my").setAttr("dtype", tensor2.dataType()).setAttr("value", tensor2).build().output(0);
Output output3 =graph.opBuilder("Sub", "mz").addInput(output).addInput(output2).build().output(0);
Session session = new Session(graph);
Tensor ttt= session.runner().fetch("mz").run().get(0);
System.out.println(ttt.intValue());
Tensor t= session.runner().feed("mx", tensor).feed("my", tensor2).fetch("mz").run().get(0);
System.out.println(t.intValue());
session.close();
tensor.close();
tensor2.close();
graph.close();
}
}
複製代碼
z= <Add 'z:0' shape=<unknown> dtype=FLOAT>
1.0 + 4.0 = 5.0
2.0 + 5.0 = 7.0
3.0 + 6.0 = 9.0
-1
-1
複製代碼