Tensorflow Lite tflite模型的生成與導入

假如想要在ARM板上用 tensorflow lite,那麼意味着必需要把PC上的模型生成 tflite文件,而後在ARM上導入這個 tflite文件,經過解析這個文件來進行計算。
根據前面所說, tensorflow的全部計算都會在內部生成一個圖,包括變量的初始化,輸入定義等,那麼即使不是通過訓練的神經網絡模型,只是簡單的三角函數計算,也能夠生成一個 tflite模型用於在 tensorflow lite上導入。因此,這裏我就只作了簡單的 sin()計算來跑一編這個流程。

生成tflite模型

這部分主要是調用TFLiteConverter函數,直接生成tflite文件,再也不經過pb文件轉化。
先上代碼:python

import numpy as np import time import math import tensorflow as tf SIZE = 1000 X = np.random.rand(SIZE, 1) X = X*(math.pi/2.0) start = time.time() x1 = tf.placeholder(tf.float32, [SIZE, 1], name='x1-input') x2 = tf.placeholder(tf.float32, [SIZE, 1], name='x2-input') y1 = tf.sin(x1) y2 = tf.sin(x2) y = y1*y2 with tf.Session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) converter = tf.lite.TFLiteConverter.from_session(sess, [x1, x2], [y]) tflite_model = converter.convert() open("/home/alcht0/share/project/tensorflow-v1.12.0/converted_model.tflite", "wb").write(tflite_model) end = time.time() print("2nd ", str(end - start))
轉化函數
主要遇到的問題是 tensorflow的變化實在太快,這些個轉化函數一直在變。位置也一直在變,如今參考 官方文檔,是按上面代碼中調用,不然就會報找不到 lite之類的錯誤。我如今PC上的 tensorflow Python版本是1.13,因此 lite已經在 contrib外面了,若是是之前的版本,要按文檔中下面這樣調用。
 
TensorFlow Version Python API
1.12 tf.contrib.lite.TFLiteConverter
1.9-1.11 tf.contrib.lite.TocoConverter
1.7-1.8 tf.contrib.lite.toco_convert

 

 

 

 

 

輸入參數shapegit

原本在本文件中爲了給定的輸入數據大小自由,x1,x2shape會寫成[None, 1],可是若是這樣寫,轉化成tflite模型後會默認爲[1,1],並不能自由接收數據大小,因此在這裏要指定大小SIZEgithub

x1 = tf.placeholder(tf.float32, [SIZE, 1], name='x1-input')

導入tflite模型

原本這部分應該是在ARM板子上作的,可是爲了驗證tflite文件的可用性,我先在PC的Python上試驗。先上代碼:api

import tensorflow as tf import numpy as np import math import time SIZE = 1000 X = np.random.rand(SIZE, 1, ).astype(np.float32) X = X*(math.pi/2.0) start = time.time() interpreter = tf.lite.Interpreter(model_path="/home/alcht0/share/project/tensorflow-v1.12.0/converted_model.tflite") interpreter.allocate_tensors() input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() interpreter.set_tensor(input_details[0]['index'], X) interpreter.set_tensor(input_details[1]['index'], X) interpreter.invoke() output_data = interpreter.get_tensor(output_details[0]['index']) end = time.time() print("1st ", str(end - start))
首先根據 tflite文件生成解析器,而後用 allocate_tensors()分配內存。將輸入經過 set_tensor傳入,而後調用 invoke()來真正運行。最後獲得輸出。
Python跑的時候能夠很清楚的看到 input_details的數據結構。官方的例子是隻傳入一個數據,因此只須要取 input_details[0],而我傳入了2個輸入,因此須要設置2個。同時能夠看到 input_details的2個數據的名字都是我在以前設置的 x1-inputx2-input,這樣很是好理解。
輸入參數類型
這裏有個坑是輸入參數的類型必定要注意。我在生成模型的時候定義的輸入參數類型是 tf.float32,而在導入的時候若是直接是 X = np.random.rand(SIZE, 1, )的話,會報錯:
ValueError: Cannot set tensor: Got tensor of type 0 but expected type 1 for input 3

這裏把經過astype(np.float32)把輸入參數指定爲float32就OK了。網絡

  • 操做不支持的坑
    能夠從前面的代碼裏看到我寫了兩個sin(),其實一開始是一個sin()一個cos()的,可是好像默認的tflite模型不支持cos()操做,沒法生成,因此我只好暫時先只寫sin(),後面再研究怎麼把cos()加上。
相關文章
相關標籤/搜索