【手撕 - 深度學習】TF Lite 魔改:添加自定義 op

做者:LogMnode

本文原載於 https://segmentfault.com/u/logm/articles ,不容許轉載~android


1. 前言

Tensorflow Lite 是 Tensorflow 移動端的版本。git

有關於 Tensorflow 怎麼添加自定義 op,網上有不少博客都講到了,我就不介紹了。而 Tensorflow Lite 由於相對小衆一些,因此網上關於添加自定義 op 的教程不多。github

恰好最近由於項目須要,我在 Tensorflow Lite 中添加了幾個自定義 op。我把個人思考過程以及修改步驟記錄下來,方便有相同需求的同窗參考。segmentfault

我花了大篇幅記錄思考過程和源碼閱讀過程,是但願給其餘小夥伴一些啓發,之後遇到相似的深度學習框架魔改的問題,能夠不依賴網上教程。api

不關心思考過程和源碼閱讀的小夥伴,能夠直接跳到文章的最後,我把修改的步驟作了總結。app

2. 源碼來源

我使用源碼是 Tensorflow v1.13.2框架

Tensorflow Lite 位於 tensorflow/lite 目錄下。ide

3. 官方教程

官網也有關於 Tensorflow Lite 怎麼添加自定義 op 的教程,詳見官方地址函數

官方教程把"怎麼寫自定義 op 的代碼"講得很清楚,遺憾的是沒有詳細說明怎麼把這些新寫的代碼放入到工程中編譯。

4. 進入正題

第1步,找到目標文件夾位置

首先咱們要找到源碼中放置自定義 op 的文件夾位置。有多種尋找的方式:

  1. tensorflow 源碼的目錄結構很是清楚,有過相似框架閱讀經驗的同窗應該立刻能猜出位置;
  2. 官方教程告訴咱們,自定義 op 的代碼要實現 PrepareEval 這兩個函數,那麼咱們使用 grep 命令查找有哪些代碼文件中帶有這兩個函數。

最終,咱們找到的位置是 tensorflow/lite/kernels

找到目標文件夾位置之後,把新增代碼放入該文件夾就能夠了嗎?顯然,沒有這麼簡單。有幾個方面須要考慮:

  1. 代碼邏輯層面,新增代碼的邏輯怎麼與源碼的邏輯鏈接起來;
  2. 編譯層面,新增代碼怎麼參與編譯。

第2步,新增代碼的邏輯怎麼與源碼的邏輯鏈接起來?

有過相似深度學習框架閱讀經驗的同窗應該很快能想到,對於"添加自定義op"這個操做,就是個"op註冊"的過程,因此立刻想到去尋找帶"register"字樣的文件。

而沒有深度學習框架閱讀經驗的同窗也不用慌,官方教程告訴咱們,自定義op在使用前須要調用 AddCustom 函數。那麼很明顯,這個函數就起到了將自定義op的邏輯與源碼邏輯鏈接起來的任務。因此使用 grep 命令查找有哪些代碼文件中帶有這個函數。

兩種方式異曲同工,找到關鍵文件 tensorflow/lite/kernels/register.cc

// 文件:tensorflow/lite/kernels/register.cc
// 行數:22

namespace custom {

TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_LAYER_NORM_LSTM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
TfLiteRegistration* Register_RELU_1();

}
// 文件:tensorflow/lite/kernels/register.cc
// 行數:278

  // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
  // custom ops aren't always included by default.
  AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
  AddCustom("AudioSpectrogram",
            tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
  AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM());
  AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
  AddCustom("TFLite_Detection_PostProcess",
            tflite::ops::custom::Register_DETECTION_POSTPROCESS());

嘿嘿嘿,咱們發現官方源碼中也放了5個自定義op,並且官方偷懶把自定義op與內置op的註冊過程寫在了一塊兒,那麼咱們來看看官方是怎麼寫自定義op的吧,好比 Relu1 這個。

// 文件:tensorflow/lite/kernels/relu1.cc

#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace relu1 {

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input = GetInput(context, node, 0);
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TfLiteTensor* output = GetOutput(context, node, 0);
  output->type = input->type;
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

// This is derived from lite/kernels/activations.cc.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);
  const int elements = NumElements(input);
  const float* in = input->data.f;
  const float* in_end = in + elements;
  float* out = output->data.f;
  for (; in < in_end; ++in, ++out) {
    *out = std::min(std::max(0.f, *in), 1.f);
  }
  return kTfLiteOk;
}

}  // namespace relu1

TfLiteRegistration* Register_RELU_1() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 relu1::Prepare, relu1::Eval};
  return &r;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite

能夠看到,與官方給出的教程同樣,關鍵點是實現 PrepareEval 這兩個函數。咱們本身在自定義op的代碼時,能夠把這個文件當作參考模板。

第3步,新增代碼怎麼參與編譯?

這塊須要一些 C++ 大工程開發的知識,Tensorflow 是用 Bazel 做工程編譯的,因此關鍵點在目標文件夾下的 BUILD 文件。

BUILD 文件裏面這麼多的 library,咱們的新代碼應該編譯到哪一個 library 中呢?還記得 官方留的自定義op "Relu1" 嗎?咱們來看看 "Relu1" 是編譯到哪一個 library。

// 文件:tensorflow/lite/kernels/BUILD
// 行數:278

cc_library(
    name = "builtin_op_kernels",
    srcs = [
        ...     // 這裏有不少其餘的源文件
        "mfcc.cc",
        "relu1.cc",
        ...     // 把新寫的代碼文件加到這邊就能夠了
    ],
    hdrs = [
    ],
    copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
    visibility = ["//visibility:private"],
    deps = [
        ":activation_functor",
        ":eigen_support",
        ":kernel_util",
        ":lstm_eval",
        ":op_macros",
        ":padding",
        "//tensorflow/lite:framework",
        "//tensorflow/lite:string_util",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/kernels:gemm_support",
        "//tensorflow/lite/kernels/internal:audio_utils",
        "//tensorflow/lite/kernels/internal:kernel_utils",
        "//tensorflow/lite/kernels/internal:optimized",
        "//tensorflow/lite/kernels/internal:optimized_base",
        "//tensorflow/lite/kernels/internal:quantization_util",
        "//tensorflow/lite/kernels/internal:reference_base",
        "//tensorflow/lite/kernels/internal:tensor",
        "//tensorflow/lite/kernels/internal:tensor_utils",
        "@farmhash_archive//:farmhash",
        "@flatbuffers",
    ],
)

嘿嘿嘿,官方可真會偷懶,自定義 op 和內置 op 一塊兒編譯到 builtin_op_kernels 庫。因此,咱們只要把新的代碼文件添加到 srcs=[] 裏,新的代碼就能參與到編譯過程當中了。

5. 總結

Tensorflow Lite v1.13.2 中,官方偷了個懶,自定義 op 與內置 op 寫在同一個位置,且都是編譯到 builtin_op_kernels 庫。

Tensorflow Lite 的自定義 op 添加方式以下:

  1. 參照 官方教程 以及 tensorflow/lite/kernels/relu1.cc 編寫 op 代碼;
  2. 將 op 代碼放入 tensorflow/lite/kernels 文件夾下;
  3. 修改 tensorflow/lite/kernels/register.cc,完成新增 op 在代碼邏輯上的"註冊";
  4. 修改 tensorflow/lite/kernels/BUILD,將新代碼文件加入到 builtin_op_kernels 庫的編譯過程當中;
  5. 參照 官方教程 從新編譯整個項目。
相關文章
相關標籤/搜索