Android中使用TensorFlow Lite實現圖像分類

前言html

TensorFlow Lite是一款專門針對移動設備的深度學習框架,移動設備深度學習框架是部署在手機或者樹莓派等小型移動設備上的深度學習框架,能夠使用訓練好的模型在手機等設備上完成推理任務。這一類框架的出現,能夠使得一些推理的任務能夠在本地執行,不須要再調用服務器的網絡接口,大大減小了預測時間。在前幾篇文章中已經介紹了百度的paddle-mobile,小米的mace,還有騰訊的ncnn。這在本章中咱們將介紹谷歌的TensorFlow Lite。java

TensorFlow Lite的GitHub地址:node

github.com/tensorflow/…python

正文android

轉換模型git

手機上執行預測,首先須要一個訓練好的模型,這個模型不能是TensorFlow原來格式的模型,TensorFlow Lite使用的模型格式是另外一種格式的模型。github

下面就介紹如何使用這個格式的模型。 獲取模型主要有兩種方法,第一種是在訓練的時候就保存tflite模型,另一種就是使用其餘格式的TensorFlow模型轉換成tflite模型。ubuntu

最方便的就是在訓練的時候保存tflite格式的模型,主要是使用到tf.contrib.lite.toco_convert()接口,下面就是一個簡單的例子:bash

import tensorflow as tf
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
 tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
 open("converteds_model.tflite", "wb").write(tflite_model)
複製代碼

最後得到的converteds_model.tflite文件就能夠直接在TensorFlow Lite上使用。服務器

第二種就是把tensorflow保存的其餘模型轉換成tflite,咱們能夠在如下的連接下載模型。tensorflow模型地址以下所示:

github.com/tensorflow/…

上面提供的模型同時也包括了tflite模型,咱們能夠直接拿來使用,可是咱們也能夠使用其餘格式的模型來轉換。好比咱們下載一個mobilenet_v1_1.0_224.tgz,解壓以後得到如下文件:

mobilenet_v1_1.0_224.ckpt.data-00000-of-00001 mobilenet_v1_1.0_224_eval.pbtxt mobilenet_v1_1.0_224.tflite
mobilenet_v1_1.0_224.ckpt.index mobilenet_v1_1.0_224_frozen.pb
mobilenet_v1_1.0_224.ckpt.meta mobilenet_v1_1.0_224_info.txt
複製代碼

首先要安裝Bazel,能夠參考:

docs.bazel.build/versions/ma…

只須要完成Installing using binary installer這一部分便可。而後克隆TensorFlow的源碼:

git clone https://github.com/tensorflow/tensorflow.git
複製代碼

接着編譯轉換工具,這個編譯時間可能比較長:

cd tensorflow/
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/contrib/lite/toco:toco
複製代碼

得到到轉換工具以後,咱們就能夠開始轉換模型了,如下操做是凍結圖。

  • input_graph對應的是.pb文件;
  • input_checkpoint對應的是mobilenet_v1_1.0_224.ckpt.data-00000-of-00001,可是在使用的使用是去掉後綴名的。
  • output_node_names這個能夠在mobilenet_v1_1.0_224_info.txt中獲取。 不過要注意的是咱們下載的模型已是凍結過來,因此不用再執行這個操做。但若是是其餘的模型,要先凍結圖,而後再執行以後的操做。
./freeze_graph --input_graph=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen.pb 
 --input_checkpoint=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt 
 --input_binary=true 
 --output_graph=/tmp/frozen_mobilenet_v1_224.pb 
 --output_node_names=MobilenetV1/Predictions/Reshape_1
複製代碼

如下操做就是把已經凍結的圖轉換成.tflite:

  • input_file是已經凍結的圖;
  • output_file是轉換後輸出的路徑;
  • output_arrays這個能夠在mobilenet_v1_1.0_224_info.txt中獲取;
  • input_shapes這個是預測數據的shape
./toco --input_file=/tmp/mobilenet_v1_1.0_224_frozen.pb 
 --input_format=TENSORFLOW_GRAPHDEF 
 --output_format=TFLITE 
 --output_file=/tmp/mobilenet_v1_1.0_224.tflite 
 --inference_type=FLOAT 
 --input_type=FLOAT 
 --input_arrays=input 
 --output_arrays=MobilenetV1/Predictions/Reshape_1 
 --input_shapes=1,224,224,3
複製代碼

通過上面的步驟就能夠獲取到mobilenet_v1_1.0_224.tflite模型了,以後咱們會在Android項目中使用它。

開發Android項目

有了上面的模型以後,咱們就使用Android Studio建立一個Android項目,一路默認就能夠了,並不須要C++的支持,由於咱們使用到的TensorFlow Lite是Java代碼的,開發起來很是方便。

一、建立完成以後,在app目錄下的build.gradle配置文件加上如下配置信息: 在dependencies下加上包的引用,第一個是圖片加載框架Glide,第二個就是咱們這個項目的核心TensorFlow Lite:

implementation 'com.github.bumptech.glide:glide:4.3.1'
implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
複製代碼

而後在android下加上如下代碼,這個主要是限制不要對tensorflow lite的模型進行壓縮,壓縮以後就沒法加載模型了:

//set no compress models
 aaptOptions {
 noCompress "tflite"
 }
複製代碼

在main目錄下建立assets文件夾,這個文件夾主要是存放tflite模型和label名稱文件。

如下是主界面的代碼MainActivity.java,這個代碼比較長,咱們來分析這段代碼,重要的方法介紹以下:

  • loadModelFile()方法是把模型文件讀取成MappedByteBuffer,以後給Interpreter類初始化模型,這個模型存放在main的assets目錄下。
  • load_model()方法是加載模型,並獲得一個對象tflite,以後就是使用這個對象來預測圖像,同時能夠使用這個對象設置一些參數,好比設置使用的線程數量tflite.setNumThreads(4);
  • showDialog()方法是顯示彈窗,經過這個彈窗的選擇不一樣的模型。 readCacheLabelFromLocalFile()方法是讀取文件種分類標籤對應的名稱,這個文件比較長,能夠參考這篇文章獲取標籤名稱,也能夠下載筆者的項目,裏面有對用的文件。這個文件cacheLabel.txt跟模型同樣存放在assets目錄下。
  • predict_image()方法是預測圖片並顯示結果的,預測的流程是:獲取圖片的路徑,而後使用對圖片進行壓縮,以後把圖片轉換成ByteBuffer格式的數據,最後調用tflite.run()方法進行預測。
  • get_max_result()方法是獲取最大機率的標籤。 ###大段代碼來襲
package com.yeyupiaoling.testtflite;
import android.Manifest;
import android.app.Activity;
import android.content.DialogInterface;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.net.Uri;
import android.os.Bundle;
import android.support.annotation.NonNull;
import android.support.annotation.Nullable;
import android.support.v4.app.ActivityCompat;
import android.support.v4.content.ContextCompat;
import android.support.v7.app.AlertDialog;
import android.support.v7.app.AppCompatActivity;
import android.text.method.ScrollingMovementMethod;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;
import com.bumptech.glide.Glide;
import com.bumptech.glide.load.engine.DiskCacheStrategy;
import com.bumptech.glide.request.RequestOptions;
import org.tensorflow.lite.Interpreter;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;
public class MainActivity extends AppCompatActivity {
 private static final String TAG = MainActivity.class.getName();
 private static final int USE_PHOTO = 1001;
 private static final int START_CAMERA = 1002;
 private String camera_image_path;
 private ImageView show_image;
 private TextView result_text;
 private String assets_path = "lite_images";
 private boolean load_result = false;
 private int[] ddims = {1, 3, 224, 224};
 private int model_index = 0;
 private List<String> resultLabel = new ArrayList<>();
 private Interpreter tflite = null;
 private static final String[] PADDLE_MODEL = {
 "mobilenet_v1",
 "mobilenet_v2"
 };
 @Override
 protected void onCreate(Bundle savedInstanceState) {
 super.onCreate(savedInstanceState);
 setContentView(R.layout.activity_main);
 init_view();
 readCacheLabelFromLocalFile();
 }
 // initialize view
 private void init_view() {
 request_permissions();
 show_image = (ImageView) findViewById(R.id.show_image);
 result_text = (TextView) findViewById(R.id.result_text);
 result_text.setMovementMethod(ScrollingMovementMethod.getInstance());
 Button load_model = (Button) findViewById(R.id.load_model);
 Button use_photo = (Button) findViewById(R.id.use_photo);
 Button start_photo = (Button) findViewById(R.id.start_camera);
 load_model.setOnClickListener(new View.OnClickListener() {
 @Override
 public void onClick(View view) {
 showDialog();
 }
 });
 // use photo click
 use_photo.setOnClickListener(new View.OnClickListener() {
 @Override
 public void onClick(View view) {
 if (!load_result) {
 Toast.makeText(MainActivity.this, "never load model", Toast.LENGTH_SHORT).show();
 return;
 }
 PhotoUtil.use_photo(MainActivity.this, USE_PHOTO);
 }
 });
 // start camera click
 start_photo.setOnClickListener(new View.OnClickListener() {
 @Override
 public void onClick(View view) {
 if (!load_result) {
 Toast.makeText(MainActivity.this, "never load model", Toast.LENGTH_SHORT).show();
 return;
 }
 camera_image_path = PhotoUtil.start_camera(MainActivity.this, START_CAMERA);
 }
 });
 }
 /**
 * Memory-map the model file in Assets.
 */
 private MappedByteBuffer loadModelFile(String model) throws IOException {
 AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
 FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
 FileChannel fileChannel = inputStream.getChannel();
 long startOffset = fileDescriptor.getStartOffset();
 long declaredLength = fileDescriptor.getDeclaredLength();
 return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
 }
 // load infer model
 private void load_model(String model) {
 try {
 tflite = new Interpreter(loadModelFile(model));
 Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
 Log.d(TAG, model + " model load success");
 tflite.setNumThreads(4);
 load_result = true;
 } catch (IOException e) {
 Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
 Log.d(TAG, model + " model load fail");
 load_result = false;
 e.printStackTrace();
 }
 }
 public void showDialog() {
 AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);
 // set dialog title
 builder.setTitle("Please select model");
 // set dialog icon
 builder.setIcon(android.R.drawable.ic_dialog_alert);
 // able click other will cancel
 builder.setCancelable(true);
 // cancel button
 builder.setNegativeButton("cancel", null);
 // set list
 builder.setSingleChoiceItems(PADDLE_MODEL, model_index, new DialogInterface.OnClickListener() {
 @Override
 public void onClick(DialogInterface dialog, int which) {
 model_index = which;
 load_model(PADDLE_MODEL[model_index]);
 dialog.dismiss();
 }
 });
 // show dialog
 builder.show();
 }
 private void readCacheLabelFromLocalFile() {
 try {
 AssetManager assetManager = getApplicationContext().getAssets();
 BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("cacheLabel.txt")));
 String readLine = null;
 while ((readLine = reader.readLine()) != null) {
 resultLabel.add(readLine);
 }
 reader.close();
 } catch (Exception e) {
 Log.e("labelCache", "error " + e);
 }
 }
 @Override
 protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
 String image_path;
 RequestOptions options = new RequestOptions().skipMemoryCache(true).diskCacheStrategy(DiskCacheStrategy.NONE);
 if (resultCode == Activity.RESULT_OK) {
 switch (requestCode) {
 case USE_PHOTO:
 if (data == null) {
 Log.w(TAG, "user photo data is null");
 return;
 }
 Uri image_uri = data.getData();
 Glide.with(MainActivity.this).load(image_uri).apply(options).into(show_image);
 // get image path from uri
 image_path = PhotoUtil.get_path_from_URI(MainActivity.this, image_uri);
 // predict image
 predict_image(image_path);
 break;
 case START_CAMERA:
 // show photo
 Glide.with(MainActivity.this).load(camera_image_path).apply(options).into(show_image);
 // predict image
 predict_image(camera_image_path);
 break;
 }
 }
 }
 // predict image
 private void predict_image(String image_path) {
 // picture to float array
 Bitmap bmp = PhotoUtil.getScaleBitmap(image_path);
 ByteBuffer inputData = PhotoUtil.getScaledMatrix(bmp, ddims);
 try {
 // Data format conversion takes too long
 // Log.d("inputData", Arrays.toString(inputData));
 float[][] labelProbArray = new float[1][1001];
 long start = System.currentTimeMillis();
 // get predict result
 tflite.run(inputData, labelProbArray);
 long end = System.currentTimeMillis();
 long time = end - start;
 float[] results = new float[labelProbArray[0].length];
 System.arraycopy(labelProbArray[0], 0, results, 0, labelProbArray[0].length);
 // show predict result and time
 int r = get_max_result(results);
 String show_text = "result:" + r + " name:" + resultLabel.get(r) + " probability:" + results[r] + " time:" + time + "ms";
 result_text.setText(show_text);
 } catch (Exception e) {
 e.printStackTrace();
 }
 // get max probability label
 private int get_max_result(float[] result) {
 float probability = result[0];
 int r = 0;
 for (int i = 0; i < result.length; i++) {
 if (probability < result[i]) {
 probability = result[i];
 r = i;
 }
 }
 return r;
 }
 // request permissions
 private void request_permissions() {
 List<String> permissionList = new ArrayList<>();
 if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
 permissionList.add(Manifest.permission.CAMERA);
 }
 if (ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
 permissionList.add(Manifest.permission.WRITE_EXTERNAL_STORAGE);
 }
 if (ContextCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
 permissionList.add(Manifest.permission.READ_EXTERNAL_STORAGE);
 }
 // if list is not empty will request permissions
 if (!permissionList.isEmpty()) {
 ActivityCompat.requestPermissions(this, permissionList.toArray(new String[permissionList.size()]), 1);
 }
 }
 @Override
 public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
 super.onRequestPermissionsResult(requestCode, permissions, grantResults);
 switch (requestCode) {
 case 1:
 if (grantResults.length > 0) {
 for (int i = 0; i < grantResults.length; i++) {
 int grantResult = grantResults[i];
 if (grantResult == PackageManager.PERMISSION_DENIED) {
 String s = permissions[i];
 Toast.makeText(this, s + " permission was denied", Toast.LENGTH_SHORT).show();
 }
 }
 }
 break;
 }
 }
}
複製代碼

AndroidManifest.xml下加上申請的權限,用到了相機和讀取外部存儲的內存:

<uses-permission android:name="android.permission.CAMERA"/>
 <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
 <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
複製代碼

而後還要在application下加上如下的配置信息,這個主要是爲了兼容Android 7.0的相機:

<!-- FileProvider配置訪問路徑,適配7.0及其以上 -->
 <provider
 android:name="android.support.v4.content.FileProvider"
 android:authorities="com.yeyupiaoling.testtflite.fileprovider"
 android:exported="false"
 android:grantUriPermissions="true">
 <meta-data
 android:name="android.support.FILE_PROVIDER_PATHS"
 android:resource="@xml/file_paths"/>
 </provider>
複製代碼

以後在res建立一個xml目錄,而後建立一個file_paths.xml文件,在這個文件中加上如下代碼,這個是咱們拍照以後圖片存放的位置:

<?xml version="1.0" encoding="utf-8"?>
<resources>
 <external-path
 name="images"
 path="lite_mobile/" />
</resources>
複製代碼

主界面佈局代碼activity_main.xml:

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
 xmlns:app="http://schemas.android.com/apk/res-auto"
 xmlns:tools="http://schemas.android.com/tools"
 android:layout_width="match_parent"
 android:layout_height="match_parent"
 tools:context=".MainActivity">
 <LinearLayout
 android:id="@+id/btn1_ll"
 android:layout_width="match_parent"
 android:layout_height="wrap_content"
 android:layout_alignParentBottom="true"
 android:orientation="horizontal">
 <Button
 android:id="@+id/use_photo"
 android:layout_width="0dp"
 android:layout_height="wrap_content"
 android:layout_weight="1"
 android:text="相冊" />
 <Button
 android:id="@+id/start_camera"
 android:layout_width="0dp"
 android:layout_height="wrap_content"
 android:layout_weight="1"
 android:text="拍照" />
 </LinearLayout>
 <LinearLayout
 android:id="@+id/btn2_ll"
 android:layout_width="match_parent"
 android:layout_height="wrap_content"
 android:layout_above="@id/btn1_ll"
 android:orientation="horizontal">
 <Button
 android:id="@+id/load_model"
 android:layout_width="0dp"
 android:layout_height="wrap_content"
 android:layout_weight="1"
 android:text="加載模型" />
 </LinearLayout>
 <TextView
 android:id="@+id/result_text"
 android:layout_width="match_parent"
 android:layout_height="150dp"
 android:layout_above="@id/btn2_ll"
 android:hint="預測結果會在這裏顯示"
 android:inputType="textMultiLine"
 android:textSize="16sp"
 tools:ignore="TextViewEdits" />
 <ImageView
 android:id="@+id/show_image"
 android:layout_width="match_parent"
 android:layout_height="match_parent"
 android:layout_above="@id/result_text"
 android:layout_alignParentTop="true" />
</RelativeLayout>
複製代碼

如下就是效果圖片:

Android中使用TensorFlow Lite實現圖像分類
有何疑問或者看法的能夠留言討論哦~
相關文章
相關標籤/搜索