TensorFlow Lite Android 人臉檢測 demo

https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md

TensorFlow 及 Object Detection API 相關環境的搭建與安裝

https://www.jianshu.com/p/286b8163da29

Bazel 安裝

以上環境用於 TensorFlow 物體檢測模型的訓練,以及從 TensorFlow 到 TensorFlow Lite 的模型轉換,具體步驟後面再講。

 

下載 Android Studio,導入 TensorFlow 源碼目錄下的 TFLite Android demo,具體目錄在 tensorflow/contrib/lite/examples/android 下。

下載相關 Jar 包並進行編譯,看看是否可以編譯成功。一般來說,只要網絡良好,自動下載好各種依賴包後就能編譯成功。測試生成的 apk,其默認功能是物體檢測;要改爲人臉檢測,這裏需要做的就是把相關模型替換掉。

 

Demo 中的物體檢測模型是基於 TensorFlow 的 ssd-mobilenet-quantized 模型,此模型是在 COCO 數據集上訓練的。我們可以以此模型爲基礎做遷移學習,得到針對人臉檢測的模型。

預訓練模型可以在 TensorFlow Object Detection API 的 model zoo 中下載。

人臉數據集可以採用 WIDER FACE 數據集,下載好後利用腳本將圖像及標註信息轉換爲 tfrecord 格式供訓練使用。

從 research 目錄下找到對應模型的 config 文件:object_detection/samples/configs/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config

修改其中 tfrecord、label 以及預訓練 checkpoint 的路徑後,開始訓練模型:

python train.py \
    --logtostderr \
    --train_dir=/home/kai/tensorflow/face/ \
    --pipeline_config_path=/home/kai/tensorflow/face/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config

 

待 loss 足夠小時終止訓練,得到 checkpoint。至此我們得到了可以做人臉檢測的模型,但想要在移動端以 TensorFlow Lite 模型的形式使用它,還需要一些額外的工作。

TensorFlow Lite 是 Google 設計的一種針對移動端的輕量級深度學習推理框架,它使用量化內核(quantized kernel)等一系列技術使模型更輕便、更快速,從而更適合在移動端上使用。

首先需要將 checkpoint 轉換爲 TensorFlow Lite 可用的 pb 文件:

python object_detection/export_tflite_ssd_graph.py \
    --pipeline_config_path=$CONFIG_FILE \
    --trained_checkpoint_prefix=$CHECKPOINT_PATH \
    --output_directory=$OUTPUT_DIR \
    --add_postprocessing_op=true

確保用的是 export_tflite_ssd_graph 而不是 export_inference_graph,否則得到的 pb 後面無法轉換。

得到 tflite_graph.pb 後,需要利用 TOCO 將 pb 模型轉換爲 .tflite 模型。

在 TensorFlow 源碼目錄下執行:

bazel run -c opt tensorflow/contrib/lite/toco:toco -- \
    --input_file=$OUTPUT_DIR/tflite_graph.pb \
    --output_file=$OUTPUT_DIR/detect.tflite \
    --input_shapes=1,300,300,3 \
    --input_arrays=normalized_input_image_tensor \
    --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
    --inference_type=QUANTIZED_UINT8 \
    --mean_values=128 \
    --std_values=128 \
    --change_concat_input_ranges=false \
    --allow_custom_ops

如沒有報錯,則會在 OUTPUT_DIR 目錄下生成一個 detect.tflite 文件,即爲 tflite 模型。

 

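在 TensorFlow Lite demo 中添加模型

將轉換得到的人臉模型重命名爲 facedetect.tflite,與其標籤文件 face.txt 一起放入 demo 的 assets 目錄,然後修改 DetectorActivity:同時加載人臉模型和原有的 COCO 物體檢測模型 detect.tflite,優先使用人臉檢測結果,未檢測到人臉時才顯示物體檢測結果(代碼中只保留標籤爲 oven 的結果)。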

/* * Copyright 2018 The TensorFlow Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */

package org.tensorflow.demo; import android.graphics.Bitmap; import android.graphics.Bitmap.Config; import android.graphics.Canvas; import android.graphics.Color; import android.graphics.Matrix; import android.graphics.Paint; import android.graphics.Paint.Style; import android.graphics.RectF; import android.graphics.Typeface; import android.media.ImageReader.OnImageAvailableListener; import android.os.SystemClock; import android.util.Size; import android.util.TypedValue; import android.widget.Toast; import java.io.IOException; import java.util.LinkedList; import java.util.List; import java.util.Vector; import org.tensorflow.demo.OverlayView.DrawCallback; import org.tensorflow.demo.env.BorderedText; import org.tensorflow.demo.env.ImageUtils; import org.tensorflow.demo.env.Logger; import org.tensorflow.demo.tracking.MultiBoxTracker; import org.tensorflow.lite.demo.R; // Explicit import needed for internal Google builds.

/** * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track * objects. */
public class DetectorActivity extends CameraActivity implements OnImageAvailableListener { private static final Logger LOGGER = new Logger(); // Configuration values for the prepackaged SSD face model.
  private static final int TF_OD_API_INPUT_SIZE = 300; private static final boolean TF_OD_API_IS_QUANTIZED = true; private static final String TF_OD_API_MODEL_FILE = "facedetect.tflite"; private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/face.txt"; // Configuration values for the prepackaged SSD Normal model.
  private static final int TF_OD_API_INPUT_SIZE_N = 300; private static final boolean TF_OD_API_IS_QUANTIZED_N = true; private static final String TF_OD_API_MODEL_FILE_N = "detect.tflite"; private static final String TF_OD_API_LABELS_FILE_N = "file:///android_asset/coco_labels_list.txt"; // Which detection model to use: by default uses Tensorflow Object Detection API frozen // checkpoints.
  private enum DetectorMode { TF_OD_API; } private static final DetectorMode MODE = DetectorMode.TF_OD_API; // Minimum detection confidence to track a detection.
  private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.4f; private static final boolean MAINTAIN_ASPECT = false; private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); private static final boolean SAVE_PREVIEW_BITMAP = false; private static final float TEXT_SIZE_DIP = 10; private Integer sensorOrientation; // face detector
  private Classifier detector; // object detector
  private Classifier detector_n; private long lastProcessingTimeMs; private Bitmap rgbFrameBitmap = null; private Bitmap croppedBitmap = null; private Bitmap cropCopyBitmap = null; private boolean computingDetection = false; private long timestamp = 0; private Matrix frameToCropTransform; private Matrix cropToFrameTransform; //tracker
  private MultiBoxTracker tracker; private byte[] luminanceCopy; private BorderedText borderedText; @Override public void onPreviewSizeChosen(final Size size, final int rotation) { final float textSizePx = TypedValue.applyDimension( TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); borderedText = new BorderedText(textSizePx); borderedText.setTypeface(Typeface.MONOSPACE); tracker = new MultiBoxTracker(this); int cropSize = TF_OD_API_INPUT_SIZE; // face detector
    try { detector = TFLiteObjectDetectionAPIModel.create( getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE, TF_OD_API_IS_QUANTIZED); cropSize = TF_OD_API_INPUT_SIZE; } catch (final IOException e) { LOGGER.e("Exception initializing classifier!", e); Toast toast = Toast.makeText( getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT); toast.show(); finish(); } // Normal object detector
    try { detector_n = TFLiteObjectDetectionAPIModel.create( getAssets(), TF_OD_API_MODEL_FILE_N, TF_OD_API_LABELS_FILE_N, TF_OD_API_INPUT_SIZE_N, TF_OD_API_IS_QUANTIZED_N); cropSize = TF_OD_API_INPUT_SIZE; } catch (final IOException e) { LOGGER.e("Exception initializing classifier!", e); Toast toast = Toast.makeText( getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT); toast.show(); finish(); } previewWidth = size.getWidth(); previewHeight = size.getHeight(); sensorOrientation = rotation - getScreenOrientation(); LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation); LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888); frameToCropTransform = ImageUtils.getTransformationMatrix( previewWidth, previewHeight, cropSize, cropSize, sensorOrientation, MAINTAIN_ASPECT); cropToFrameTransform = new Matrix(); frameToCropTransform.invert(cropToFrameTransform); trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay); trackingOverlay.addCallback( new DrawCallback() { @Override public void drawCallback(final Canvas canvas) { tracker.draw(canvas); if (isDebug()) { //tracker.drawDebug(canvas);
 } } }); addCallback( new DrawCallback() { @Override public void drawCallback(final Canvas canvas) { if (!isDebug()) { return; } final Bitmap copy = cropCopyBitmap; if (copy == null) { return; } final int backgroundColor = Color.argb(100, 0, 0, 0); canvas.drawColor(backgroundColor); final Matrix matrix = new Matrix(); final float scaleFactor = 2; matrix.postScale(scaleFactor, scaleFactor); matrix.postTranslate( canvas.getWidth() - copy.getWidth() * scaleFactor, canvas.getHeight() - copy.getHeight() * scaleFactor); canvas.drawBitmap(copy, matrix, new Paint()); final Vector<String> lines = new Vector<String>(); if (detector_n != null) { final String statString = detector_n.getStatString(); final String[] statLines = statString.split("\n"); for (final String line : statLines) { lines.add(line); } } lines.add(""); lines.add("Frame: " + previewWidth + "x" + previewHeight); lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight()); lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight()); lines.add("Rotation: " + sensorOrientation); lines.add("Inference time: " + lastProcessingTimeMs + "ms"); borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines); } }); } OverlayView trackingOverlay; @Override protected void processImage() { ++timestamp; final long currTimestamp = timestamp; byte[] originalLuminance = getLuminance(); tracker.onFrame( previewWidth, previewHeight, getLuminanceStride(), sensorOrientation, originalLuminance, timestamp); trackingOverlay.postInvalidate(); // No mutex needed as this method is not reentrant.
    if (computingDetection) { readyForNextImage(); return; } computingDetection = true; LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread."); rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); if (luminanceCopy == null) { luminanceCopy = new byte[originalLuminance.length]; } System.arraycopy(originalLuminance, 0, luminanceCopy, 0, originalLuminance.length); readyForNextImage(); final Canvas canvas = new Canvas(croppedBitmap); canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null); // For examining the actual TF input.
    if (SAVE_PREVIEW_BITMAP) { ImageUtils.saveBitmap(croppedBitmap); } runInBackground( new Runnable() { @Override public void run() { LOGGER.i("Running detection on image " + currTimestamp); final long startTime = SystemClock.uptimeMillis(); final List<Classifier.Recognition> results_n = detector_n.recognizeImage(croppedBitmap); final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap); lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; cropCopyBitmap = Bitmap.createBitmap(croppedBitmap); final Canvas canvas = new Canvas(cropCopyBitmap); final Paint paint = new Paint(); paint.setColor(Color.RED); paint.setStyle(Style.STROKE); paint.setStrokeWidth(2.0f); float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; switch (MODE) { case TF_OD_API: minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; break; } final List<Classifier.Recognition> mappedRecognitions =
                new LinkedList<Classifier.Recognition>(); final List<Classifier.Recognition> mappedRecognitions_n =
                new LinkedList<Classifier.Recognition>(); boolean faceflag = true; for (final Classifier.Recognition result : results) { final RectF location = result.getLocation(); if (location != null && result.getConfidence() >= minimumConfidence) { //canvas.drawRect(location, paint);
                faceflag = false; cropToFrameTransform.mapRect(location); result.setLocation(location); mappedRecognitions.add(result); } } if(faceflag) { for (final Classifier.Recognition result_n : results_n) { final RectF location = result_n.getLocation(); String temp = result_n.getTitle(); if (location != null && result_n.getConfidence() >= minimumConfidence && result_n.getTitle().equals("oven") ) { //canvas.drawRect(location, paint);
 cropToFrameTransform.mapRect(location); result_n.setLocation(location); mappedRecognitions_n.add(result_n); } } } tracker.trackResults(mappedRecognitions, luminanceCopy, currTimestamp); tracker.trackResults(mappedRecognitions_n, luminanceCopy, currTimestamp); trackingOverlay.postInvalidate(); requestRender(); computingDetection = false; } }); } @Override protected int getLayoutId() { return R.layout.camera_connection_fragment_tracking; } @Override protected Size getDesiredPreviewFrameSize() { return DESIRED_PREVIEW_SIZE; } @Override public void onSetDebug(final boolean debug) { detector.enableStatLogging(debug); detector_n.enableStatLogging(debug); } }
相關文章
相關標籤/搜索