做者:LogMnode
本文原載於 https://segmentfault.com/u/logm/articles ,不容許轉載~android
Tensorflow Lite 是 Tensorflow 移動端的版本。git
有關於 Tensorflow 怎麼添加自定義 op,網上有不少博客都講到了,我就不介紹了。而 Tensorflow Lite 由於相對小衆一些,因此網上關於添加自定義 op 的教程不多。github
恰好最近由於項目須要,我在 Tensorflow Lite 中添加了幾個自定義 op。我把個人思考過程以及修改步驟記錄下來,方便有相同需求的同窗參考。segmentfault
我花了大篇幅記錄思考過程和源碼閱讀過程,是但願給其餘小夥伴一些啓發,之後遇到相似的深度學習框架魔改的問題,能夠不依賴網上教程。api
不關心思考過程和源碼閱讀的小夥伴,能夠直接跳到文章的最後,我把修改的步驟作了總結。app
我使用源碼是 Tensorflow v1.13.2框架
Tensorflow Lite 位於 tensorflow/lite
目錄下。ide
官網也有關於 Tensorflow Lite 怎麼添加自定義 op 的教程,詳見官方地址。函數
官方教程把"怎麼寫自定義 op 的代碼"講得很清楚,遺憾的是沒有詳細說明怎麼把這些新寫的代碼放入到工程中編譯。
首先咱們要找到源碼中放置自定義 op 的文件夾位置。有多種尋找的方式:
Prepare
和 Eval
這兩個函數,那麼咱們使用 grep 命令查找有哪些代碼文件中帶有這兩個函數。最終,咱們找到的位置是 tensorflow/lite/kernels
。
找到目標文件夾位置之後,把新增代碼放入該文件夾就能夠了嗎?顯然,沒有這麼簡單。有幾個方面須要考慮:
有過相似深度學習框架閱讀經驗的同窗應該很快能想到,對於"添加自定義op"這個操做,就是個"op註冊"的過程,因此立刻想到去尋找帶"register"字樣的文件。
而沒有深度學習框架閱讀經驗的同窗也不用慌,官方教程告訴咱們,自定義op在使用前須要調用 AddCustom
函數。那麼很明顯,這個函數就起到了將自定義op的邏輯與源碼邏輯鏈接起來的任務。因此使用 grep 命令查找有哪些代碼文件中帶有這個函數。
兩種方式異曲同工,找到關鍵文件 tensorflow/lite/kernels/register.cc
。
// 文件:tensorflow/lite/kernels/register.cc // 行數:22 namespace custom { TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); TfLiteRegistration* Register_LAYER_NORM_LSTM(); TfLiteRegistration* Register_MFCC(); TfLiteRegistration* Register_DETECTION_POSTPROCESS(); TfLiteRegistration* Register_RELU_1(); }
// 文件:tensorflow/lite/kernels/register.cc // 行數:278 // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. AddCustom("Mfcc", tflite::ops::custom::Register_MFCC()); AddCustom("AudioSpectrogram", tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM()); AddCustom("Relu1", tflite::ops::custom::Register_RELU_1()); AddCustom("TFLite_Detection_PostProcess", tflite::ops::custom::Register_DETECTION_POSTPROCESS());
嘿嘿嘿,咱們發現官方源碼中也放了5個自定義op,並且官方偷懶把自定義op與內置op的註冊過程寫在了一塊兒,那麼咱們來看看官方是怎麼寫自定義op的吧,好比 Relu1
這個。
// 文件:tensorflow/lite/kernels/relu1.cc #include "tensorflow/lite/context.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { namespace custom { namespace relu1 { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* input = GetInput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TfLiteTensor* output = GetOutput(context, node, 0); output->type = input->type; return context->ResizeTensor(context, output, TfLiteIntArrayCopy(input->dims)); } // This is derived from lite/kernels/activations.cc. TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); const int elements = NumElements(input); const float* in = input->data.f; const float* in_end = in + elements; float* out = output->data.f; for (; in < in_end; ++in, ++out) { *out = std::min(std::max(0.f, *in), 1.f); } return kTfLiteOk; } } // namespace relu1 TfLiteRegistration* Register_RELU_1() { static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, relu1::Prepare, relu1::Eval}; return &r; } } // namespace custom } // namespace ops } // namespace tflite
能夠看到,與官方給出的教程同樣,關鍵點是實現 Prepare
和 Eval
這兩個函數。咱們本身在自定義op的代碼時,能夠把這個文件當作參考模板。
這塊須要一些 C++ 大工程開發的知識,Tensorflow 是用 Bazel 做工程編譯的,因此關鍵點在目標文件夾下的 BUILD
文件。
而 BUILD
文件裏面這麼多的 library,咱們的新代碼應該編譯到哪一個 library 中呢?還記得 官方留的自定義op "Relu1" 嗎?咱們來看看 "Relu1" 是編譯到哪一個 library。
// 文件:tensorflow/lite/kernels/BUILD // 行數:278 cc_library( name = "builtin_op_kernels", srcs = [ ... // 這裏有不少其餘的源文件 "mfcc.cc", "relu1.cc", ... // 把新寫的代碼文件加到這邊就能夠了 ], hdrs = [ ], copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS, visibility = ["//visibility:private"], deps = [ ":activation_functor", ":eigen_support", ":kernel_util", ":lstm_eval", ":op_macros", ":padding", "//tensorflow/lite:framework", "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/kernels:gemm_support", "//tensorflow/lite/kernels/internal:audio_utils", "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/kernels/internal:optimized", "//tensorflow/lite/kernels/internal:optimized_base", "//tensorflow/lite/kernels/internal:quantization_util", "//tensorflow/lite/kernels/internal:reference_base", "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal:tensor_utils", "@farmhash_archive//:farmhash", "@flatbuffers", ], )
嘿嘿嘿,官方可真會偷懶,自定義 op 和內置 op 一塊兒編譯到 builtin_op_kernels
庫。因此,咱們只要把新的代碼文件添加到 srcs=[]
裏,新的代碼就能參與到編譯過程當中了。
Tensorflow Lite v1.13.2 中,官方偷了個懶,自定義 op 與內置 op 寫在同一個位置,且都是編譯到 builtin_op_kernels
庫。
Tensorflow Lite 的自定義 op 添加方式以下: