若是您但願能有一種簡單、高效且靈活的方式把 TensorFlow 模型集成到 Flutter 應用裏,那請您必定不要錯過咱們今天介紹的這個全新插件 tflite_flutter。這個插件的開發者是 Google Summer of Code(GSoC) 的一名實習生 Amish Garg,本文來自他在 Medium 上的一篇文章《在 Flutter 中使用 TensorFlow Lite 插件實現文字分類》。java
tflite_flutter 插件的核心特性:python
本文中,咱們將使用 tflite_flutter 構建一個 文字分類 Flutter 應用 帶您體驗 tflite_flutter 插件,首先重新建一個 Flutter 項目 text_classification_app
開始。linux
將 install.sh
拷貝到您應用的根目錄,而後在根目錄執行 sh install.sh
,本例中就是目錄 text_classification_app/
。android
將 install.bat 文件拷貝到應用根目錄,並在根目錄運行批處理文件 install.bat
,本例中就是目錄 text_classification_app/
。git
它會自動從 release assets 下載最新的二進制資源,而後把它放到指定的目錄下。github
請點擊到 README 文件裏查看更多 關於初始配置的信息。windows
在 pubspec.yaml
添加 tflite_flutter: ^<latest_version>
(詳情)。api
要在移動端上運行 TensorFlow 訓練模型,咱們須要使用 .tflite
格式。若是須要了解如何將 TensorFlow 訓練的模型轉換爲 .tflite
格式,請參閱官方指南。架構
這裏咱們準備使用 TensorFlow 官方站點上預訓練的文字分類模型,可從這裏下載。app
該預訓練的模型能夠預測當前段落的情感是積極仍是消極。它是基於來自 Mass 等人的 Large Movie Review Dataset v1.0 數據集進行訓練的。數據集由基於 IMDB 電影評論所標記的積極或消極標籤組成,點擊查看更多信息。
將 text_classification.tflite
和 text_classification_vocab.txt
文件拷貝到 text_classification_app/assets/ 目錄下。
在 pubspec.yaml
文件中添加 assets/
。
assets: - assets/
如今萬事俱備,咱們能夠開始寫代碼了。 🚀
正如 文字分類模型頁面 裏所提到的。能夠按照下面的步驟使用模型對段落進行分類:
咱們首先寫一個方法對原始字符串進行分詞,其中使用 text_classification_vocab.txt
做爲詞聚集。
在 lib/
文件夾下建立一個新文件 classifier.dart
。
這裏先寫代碼加載 text_classification_vocab.txt
到字典裏。
import 'package:flutter/services.dart'; class Classifier { final _vocabFile = 'text_classification_vocab.txt'; Map<String, int> _dict; Classifier() { _loadDictionary(); } void _loadDictionary() async { final vocab = await rootBundle.loadString('assets/$_vocabFile'); var dict = <String, int>{}; final vocabList = vocab.split('\n'); for (var i = 0; i < vocabList.length; i++) { var entry = vocabList[i].trim().split(' '); dict[entry[0]] = int.parse(entry[1]); } _dict = dict; print('Dictionary loaded successfully'); } }
加載字典
如今咱們來編寫一個函數對原始字符串進行分詞。
import 'package:flutter/services.dart'; class Classifier { final _vocabFile = 'text_classification_vocab.txt'; // 單句的最大長度 final int _sentenceLen = 256; final String start = '<START>'; final String pad = '<PAD>'; final String unk = '<UNKNOWN>'; Map<String, int> _dict; List<List<double>> tokenizeInputText(String text) { // 使用空格進行分詞 final toks = text.split(' '); // 建立一個列表,它的長度等於 _sentenceLen,而且使用 <pad> 的對應的字典值來填充 var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble()); var index = 0; if (_dict.containsKey(start)) { vec[index++] = _dict[start].toDouble(); } // 對於句子裏的每一個單詞在 dict 裏找到相應的 index 值 for (var tok in toks) { if (index > _sentenceLen) { break; } vec[index++] = _dict.containsKey(tok) ? _dict[tok].toDouble() : _dict[unk].toDouble(); } // 按照咱們的解釋器輸入 tensor 所需的形狀 [1,256] 返回 List<List<double>> return [vec]; } }
這是本文的主體部分,這裏咱們會討論 tflite_flutter 插件的用途。
這裏的分析是指基於輸入數據在設備上使用 TensorFlow Lite 模型的處理過程。要使用 TensorFlow Lite 模型進行分析,須要經過 解釋器 來運行它,瞭解更多。
建立解釋器,加載模型
tflite_flutter 提供了一個方法直接經過資源建立解釋器。
static Future<Interpreter> fromAsset(String assetName, {InterpreterOptions options})
因爲咱們的模型在 assets/
文件夾下,須要使用上面的方法來建立解析器。對於 InterpreterOptions 的相關說明,請 參考這裏。
import 'package:flutter/services.dart'; // 引入 tflite_flutter import 'package:tflite_flutter/tflite_flutter.dart'; class Classifier { // 模型文件的名稱 final _modelFile = 'text_classification.tflite'; // TensorFlow Lite 解釋器對象 Interpreter _interpreter; Classifier() { // 當分類器初始化之後加載模型 _loadModel(); } void _loadModel() async { // 使用 Interpreter.fromAsset 建立解釋器 _interpreter = await Interpreter.fromAsset(_modelFile); print('Interpreter loaded successfully'); } }
建立解釋器的代碼
若是您不但願將模型放在 assets/
目錄下,tflite_flutter 還提供了工廠構造函數建立解釋器,更多信息。
咱們開始進行分析!
如今用下面方法啓動分析:
void run(Object input, Object output);
注意這裏的方法和 Java API 中的是同樣的。
Object input
和 Object output
必須是和 Input Tensor 與 Output Tensor 維度相同的列表。
要查看 input tensors 和 output tensors 的維度,可使用以下代碼:
_interpreter.allocateTensors(); // 打印 input tensor 列表 print(_interpreter.getInputTensors()); // 打印 output tensor 列表 print(_interpreter.getOutputTensors());
在本例中 text_classification 模型的輸出以下:
InputTensorList: [Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf280, name: embedding_input, type: TfLiteType.float32, shape: [1, 256], data: 1024] OutputTensorList: [Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf140, name: dense_1/Softmax, type: TfLiteType.float32, shape: [1, 2], data: 8]
如今,咱們實現分類方法,該方法返回值爲 1 表示積極,返回值爲 0 表示消極。
int classify(String rawText) { // tokenizeInputText 返回形狀爲 [1, 256] 的 List<List<double>> List<List<double>> input = tokenizeInputText(rawText); // [1,2] 形狀的輸出 var output = List<double>(2).reshape([1, 2]); // run 方法會運行分析而且存儲輸出的值 _interpreter.run(input, output); var result = 0; // 若是輸出中第一個元素的值比第二個大,那麼句子就是消極的 if ((output[0][0] as double) > (output[0][1] as double)) { result = 0; } else { result = 1; } return result; }
用於分析的代碼
在 tflite_flutter 的 extension ListShape on List 下面定義了一些使用的擴展:
// 將提供的列表進行矩陣變形,輸入參數爲元素總數 // 保持相等 // 用法:List(400).reshape([2,10,20]) // 返回 List<dynamic> List reshape(List<int> shape) // 返回列表的形狀 List<int> get shape // 返回列表任意形狀的元素數量 int get computeNumElements
最終的 classifier.dart
應該是這樣的:
import 'package:flutter/services.dart'; // 引入 tflite_flutter import 'package:tflite_flutter/tflite_flutter.dart'; class Classifier { // 模型文件的名稱 final _modelFile = 'text_classification.tflite'; final _vocabFile = 'text_classification_vocab.txt'; // 語句的最大長度 final int _sentenceLen = 256; final String start = '<START>'; final String pad = '<PAD>'; final String unk = '<UNKNOWN>'; Map<String, int> _dict; // TensorFlow Lite 解釋器對象 Interpreter _interpreter; Classifier() { // 當分類器初始化的時候加載模型 _loadModel(); _loadDictionary(); } void _loadModel() async { // 使用 Intepreter.fromAsset 建立解析器 _interpreter = await Interpreter.fromAsset(_modelFile); print('Interpreter loaded successfully'); } void _loadDictionary() async { final vocab = await rootBundle.loadString('assets/$_vocabFile'); var dict = <String, int>{}; final vocabList = vocab.split('\n'); for (var i = 0; i < vocabList.length; i++) { var entry = vocabList[i].trim().split(' '); dict[entry[0]] = int.parse(entry[1]); } _dict = dict; print('Dictionary loaded successfully'); } int classify(String rawText) { // tokenizeInputText 返回形狀爲 [1, 256] 的 List<List<double>> List<List<double>> input = tokenizeInputText(rawText); //輸出形狀爲 [1, 2] 的矩陣 var output = List<double>(2).reshape([1, 2]); // run 方法會運行分析而且將結果存儲在 output 中。 _interpreter.run(input, output); var result = 0; // 若是第一個元素的輸出比第二個大,那麼當前語句是消極的 if ((output[0][0] as double) > (output[0][1] as double)) { result = 0; } else { result = 1; } return result; } List<List<double>> tokenizeInputText(String text) { // 用空格分詞 final toks = text.split(' '); // 建立一個列表,它的長度等於 _sentenceLen,而且使用 <pad> 對應的字典值來填充 var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble()); var index = 0; if (_dict.containsKey(start)) { vec[index++] = _dict[start].toDouble(); } // 對於句子中的每一個單詞,在 dict 中找到相應的 index 值 for (var tok in toks) { if (index > _sentenceLen) { break; } vec[index++] = _dict.containsKey(tok) ? _dict[tok].toDouble() : _dict[unk].toDouble(); } // 按照咱們的解釋器輸入 tensor 所需的形狀 [1,256] 返回 List<List<double>> return [vec]; } }
如今,能夠根據您的喜愛實現 UI 的代碼,分類器的用法比較簡單。
// 建立 Classifier 對象 Classifer _classifier = Classifier(); // 將目標語句做爲參數,調用 classify 方法 _classifier.classify("I liked the movie"); // 返回 1 (積極的) _classifier.classify("I didn't liked the movie"); // 返回 0 (消極的)
請在這裏查閱完整代碼:Text Classification Example app with UI。
文字分類示例應用
瞭解更多關於 tflite_flutter 插件的信息,請訪問 GitHub repo: am15h/tflite_flutter_plugin。
tflite_flutter
和 tflite v1.0.5
有哪些區別?tflite v1.0.5
側重於爲特定用途的應用場景提供高級特性,好比圖片分類、物體檢測等等。而新的 tflite_flutter 則提供了與 Java API 相同的特性和靈活性,並且能夠用於任何 tflite 模型中,它還支持 delegate。
因爲使用 dart:ffi (dart ↔️ (ffi) ↔️ C),tflite_flutter 很是快 (擁有低延時)。而 tflite 使用平臺集成 (dart ↔️ platform-channel ↔️ (Java/Swift) ↔️ JNI ↔️ C)。
更新(07/01/2020): TFLite Flutter Helper 開發庫已發佈。
TensorFlow Lite Flutter Helper Library 爲處理和控制輸入及輸出的 TFLite 模型提供了易用的架構。它的 API 設計和文檔與 TensorFlow Lite Android Support Library 是同樣的。更多信息請 參考這裏。
以上是本文的所有內容,歡迎你們對 tflite_flutter 插件進行反饋,請在這裏 上報 bug 或提出功能需求。
謝謝關注。
感謝 Michael Thomsen。
本文聯合發佈在 TensorFlow 線上討論區、101.dev 和 Flutter 中文文檔,以及 Flutter 社區線上渠道。