1. 安裝 TensorFlow 以及 Go 語言版 TensorFlow 綁定(參考 https://tensorflow.google.cn/install/install_go)。
2. 用 Python 訓練模型,這裏以 Keras 示例中的 imdb_cnn.py 爲例:
# coding:utf-8
"""Train a small text-CNN on IMDB with Keras and export it as a TensorFlow
SavedModel that can be loaded from Go.

Adapted from the Keras example imdb_cnn.py. Every layer gets an explicit
name so the graph's input/output operations can be located from Go.
"""
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Embedding, Dropout, Conv1D, Dense, GlobalMaxPooling1D
from keras.preprocessing import sequence
from keras.datasets import imdb
from keras import backend as K

# Hyper-parameters. These are the values the model below actually uses
# (the original example's unused 50/250/3/250 settings were misleading).
max_features = 5000   # vocabulary size kept by imdb.load_data
maxlen = 20           # padded sequence length; the Go side must use the same length
batch_size = 32
embedding_dims = 100
filters = 256
kernel_size = 5
hidden_dims = 256
dropout_rate = 0.2
epochs = 2

# Create the session first and register it with Keras so the model is built,
# trained and exported within the same session/graph.
sess = tf.Session()
K.set_session(sess)

# Load the data.
print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')

print('Pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)

# Define the model. Each layer is named explicitly so the Go program can find it.
model = Sequential()
model.add(Embedding(max_features + 1, embedding_dims, name="input_layer"))
model.add(Dropout(dropout_rate, name="dropout_layer1"))
model.add(Conv1D(filters=filters, kernel_size=kernel_size, padding="valid",
                 activation="relu", strides=1, name="cnn_layer"))
model.add(Dropout(dropout_rate, name="dropout_layer2"))
model.add(GlobalMaxPooling1D(name="maxpooling_layer"))
model.add(Dense(hidden_dims, activation='relu', name="dense_layer1"))
model.add(Dropout(dropout_rate, name="dropout_layer3"))
model.add(Dense(2, activation='softmax', name="output_layer"))

# Print the full names of the input/output graph nodes; the Go program needs
# them to feed and fetch tensors. NOTE(review): Keras' Sequential model is
# expected to name the input placeholder "input_layer_input" and the output
# "output_layer/Softmax" -- confirm against this printout.
for n in sess.graph.as_graph_def().node:
    if 'input_layer' in n.name or 'output_layer' in n.name:
        print(n.name)

# Labels are integer class ids (0/1) and the output is a 2-way softmax,
# so sparse categorical cross-entropy is the matching loss.
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test),
          shuffle=True)

# Key step: save with TensorFlow's SavedModelBuilder instead of Keras'
# model.save, because Go can only load the TensorFlow SavedModel format.
builder = tf.saved_model.builder.SavedModelBuilder("cnnModel")
# The tag is required when loading the model from Go.
builder.add_meta_graph_and_variables(sess, ["myTag"])
builder.save()
sess.close()
運行後會生成 cnnModel 文件夾,其中包含一個 saved_model.pb 文件和一個 variables 文件夾。
3. 在 Go 中使用訓練好的模型:
// Command main loads the SavedModel exported by the Python training script
// and runs a single prediction on a pre-tokenized sentence.
package main

import (
	"fmt"
	"log"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

const (
	// modelDir and modelTag must match the SavedModelBuilder path and tag
	// used by the Python export script.
	modelDir = "cnnModel"
	modelTag = "myTag"

	// inputOpName is the graph placeholder that feeds the layer named
	// "input_layer". Keras' Sequential model names it "<layer>_input".
	// NOTE(review): confirm against the node names printed by the Python script.
	inputOpName = "input_layer_input"
	// outputOpName is the softmax op of the layer named "output_layer".
	outputOpName = "output_layer/Softmax"

	// maxLen is the padded sentence length; it must equal maxlen at training time.
	maxLen = 20
)

func main() {
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

// run loads the SavedModel, feeds one padded token-ID sequence through it and
// prints the output probabilities. Any failure is returned rather than printed,
// so main can exit with a non-zero status.
func run() (err error) {
	// Text would normally be converted to an ID sequence here; for the demo a
	// pre-converted, left-padded sequence is used directly.
	inputData := [1][maxLen]float32{{
		0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 208.0, 659.0, 180.0,
		408.0, 42.0, 547.0, 829.0, 285.0, 334.0, 42.0, 642.0, 81.0, 800.0,
	}}

	tensor, err := tf.NewTensor(inputData)
	if err != nil {
		return fmt.Errorf("creating input tensor: %w", err)
	}

	model, err := tf.LoadSavedModel(modelDir, []string{modelTag}, nil)
	if err != nil {
		return fmt.Errorf("loading saved model %q: %w", modelDir, err)
	}
	// Release the session's native resources; report a close failure only if
	// nothing else went wrong first.
	defer func() {
		if cerr := model.Session.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("closing session: %w", cerr)
		}
	}()

	// Graph.Operation returns nil for an unknown name; check explicitly instead
	// of panicking on a nil operation.
	inputOp := model.Graph.Operation(inputOpName)
	if inputOp == nil {
		return fmt.Errorf("input operation %q not found in graph", inputOpName)
	}
	outputOp := model.Graph.Operation(outputOpName)
	if outputOp == nil {
		return fmt.Errorf("output operation %q not found in graph", outputOpName)
	}

	result, err := model.Session.Run(
		map[tf.Output]*tf.Tensor{inputOp.Output(0): tensor},
		[]tf.Output{outputOp.Output(0)},
		nil,
	)
	if err != nil {
		return fmt.Errorf("running session: %w", err)
	}

	// Value() returns the tensor contents as a Go value (an untyped interface).
	fmt.Printf("Result value: %v\n", result[0].Value())
	return nil
}
4. 預測性能對比:
任務:文本二分類
測試樣本數:6 萬
padding 長度:200(性能測試所用長度,與上面示例代碼中的 maxlen=20 不同)
平臺:僅使用 CPU
Python 用時 110 秒,Go 用時 30 秒;據說在圖像識別任務上兩者的速度差距可達 10 倍以上。