使用TensorFlow.js和MobileNet模型在React Native上進行圖像分類

原文連接heartbeat.fritz.ai/image-class…react

最近,針對React NativeExpo應用程序的TensorFlow.jsalpha版本發佈了。目前支持加載預先訓練的模型並訓練新模型,如下是公告推文:ios

圖片
TensorFlow.js提供了許多預訓練的模型,這些模型簡化了從頭開始訓練新機器學習模型的耗時任務。在本教程中,咱們將探索 TensorFlow.jsMobileNet預訓練的模型架構,以對 React Native移動應用程序中的輸入圖像進行分類。 在本教程結束時,該應用程序將以下所示:

注:之前,我確實接觸過Google的Vision API來構建圖像分類應用程序,該應用程序可斷定給定圖像是否爲熱狗。若是您有興趣閱讀該示例,請點擊如下連接:heartbeat.fritz.ai/build-a-not…

本文目錄

  • 環境準備
  • 將TF.js集成到Expo應用程序中
  • 測試TF.js集成
  • 加載MobileNet模型
  • 詢問用戶權限
  • 將原始圖像轉換爲張量
  • 加載和分類圖像
  • 容許用戶選擇圖像
  • 運行應用
  • 結論

完整代碼連接:github.com/amandeepmit…git

環境準備

  • 本地環境Nodejs >= 10.x.x
  • expo-cli
  • 適用於AndroidiOS的Expo Client應用程序,用於測試該APP

將TF.js集成到Expo應用程序中

React Native中使用TensorFlow庫,第一步是集成平臺適配器-- tfjs-react-native模塊,支持從Web加載全部主要的tfjs模型。它還使用expo-gl提供了GPU支持。github

打開終端窗口,並經過執行如下命令來建立新的Expo應用程序。react-native

expo init mobilenet-tfjs-expo
複製代碼

接下來,請確保生成一個由Expo管理的應用程序。而後在app所在目錄中安裝如下依賴項:數組

yarn add @react-native-community/async-storage @tensorflow/tfjs @tensorflow/tfjs-react-native expo-gl @tensorflow-models/mobilenet jpeg-js
複製代碼
注:若是您想使用react-native-cli生成應用程序,則能夠按照明確的說明來修改metro.config.js文件和其餘必要步驟,如此處所述。

即使您使用了Expo,也仍然須要安裝tfjs模塊依賴的async-storagebash

測試TF.js集成

咱們須要確保在呈現應用程序以前將tfjs成功加載到應用程序中。這裏有一個異步函數稱爲tf.ready()。打開App.js文件,導入必要的依賴項,並定義isTfReady初始狀態爲false網絡

import React from 'react'
import { StyleSheet, Text, View } from 'react-native'
import * as tf from '@tensorflow/tfjs'
import { fetch } from '@tensorflow/tfjs-react-native'

class App extends React.Component {
  state = {
    isTfReady: false
  }

  async componentDidMount() {
    await tf.ready()
    this.setState({
      isTfReady: true
    })

    //Output in Expo console
    console.log(this.state.isTfReady)
  }

  render() {
    return (
      <View style={styles.container}>
        <Text>TFJS ready? {this.state.isTfReady ? <Text>Yes</Text> : ''}</Text>
      </View>
    )
  }
}

const styles = StyleSheet.create({
  container: {
    flex: 1,
    backgroundColor: '#fff',
    alignItems: 'center',
    justifyContent: 'center'
  }
})

export default App
複製代碼

因爲生命週期方法是異步的,所以僅在實際加載tfjs時纔會將isTfReady的值更新爲true架構

您能夠在模擬器設備中看到輸出,以下所示。app

或在控制檯中(若是使用console語句)

加載MobileNet模塊

與上一步驟相似,在提供輸入圖像以前,您還必須加載MobileNet模型。從Web上加載通過預先訓練的TensorFlow.js模型是一個昂貴的網絡調用,將花費大量時間。修改App.js文件以加載MobileNet模型。首先導入它:

import * as mobilenet from '@tensorflow-models/mobilenet'
複製代碼

添加初始狀態其餘的屬性:

state = {
  isTfReady: false,
  isModelReady: false
}
複製代碼

修改生命週期方法:

async componentDidMount() {
    await tf.ready()
    this.setState({
      isTfReady: true
    })
    this.model = await mobilenet.load()
    this.setState({ isModelReady: true })
}

複製代碼

最後,當模型加載完成後,讓咱們在屏幕上顯示一個指示器。

<Text>
  Model ready?{' '}
  {this.state.isModelReady ? <Text>Yes</Text> : <Text>Loading Model...</Text>}
</Text>
複製代碼

當模塊加載時,會展現如下的信息:

模塊加載結束,將出現

詢問用戶權限

如今,平臺適配器和模型都已集成在React Native應用程序中,咱們須要添加一個異步功能,以請求用戶的許能夠訪問相機。使用Expo的圖像選擇器組件構建iOS應用程序時,這是必不可少的步驟。 在繼續以前,請運行如下命令以安裝Expo SDK提供的全部軟件包。

expo install expo-permissions expo-constants expo-image-picker
複製代碼

APP.js中添加import聲明

import Constants from 'expo-constants'
import * as Permissions from 'expo-permissions'
複製代碼

APP類中添加方法:

getPermissionAsync = async () => {
  if (Constants.platform.ios) {
    const { status } = await Permissions.askAsync(Permissions.CAMERA_ROLL)
    if (status !== 'granted') {
      alert('Sorry, we need camera roll permissions to make this work!')
    }
  }
}
複製代碼

componentDidMount()內部調用此異步方法:

async componentDidMount() {
    await tf.ready()
    this.setState({
      isTfReady: true
    })
    this.model = await mobilenet.load()
    this.setState({ isModelReady: true })

    // add this
    this.getPermissionAsync()
  }
複製代碼

將原始圖像轉換爲張量

該應用將要求用戶從手機的相機或圖庫中上傳圖像。您必須添加一個方法來加載圖像,並容許TensorFlow解碼圖像中的數據。 TensorFlow支持JPEGPNG格式。

App.js文件中,首先導入jpeg-js程序包,該程序包將用於解碼圖像中的數據。

import * as jpeg from 'jpeg-js'
複製代碼

方法imageToTensor解碼圖片的寬度,高度和二進制數據,該方法接受原始圖像數據的參數。

imageToTensor(rawImageData) {
    const TO_UINT8ARRAY = true
    const { width, height, data } = jpeg.decode(rawImageData, TO_UINT8ARRAY)
    // Drop the alpha channel info for mobilenet
    const buffer = new Uint8Array(width * height * 3)
    let offset = 0 // offset into original data
    for (let i = 0; i < buffer.length; i += 3) {
      buffer[i] = data[offset]
      buffer[i + 1] = data[offset + 1]
      buffer[i + 2] = data[offset + 2]

      offset += 4
    }

    return tf.tensor3d(buffer, [height, width, 3])
  }
複製代碼

TO_UINT8ARRAY數組表示8位無符號整數的數組。構造方法Uint8Array()是新的ES2017語法。對於不一樣的類型化數組,每種類型的數組在內存中都有其本身的字節範圍。

加載和分類圖像

接下來,咱們添加另外一個稱爲classifyImage的方法,該方法將從圖像中讀取原始數據,並在分類後以預測形式產生結果。

必須在應用程序組件的state中保存該圖像源的路徑,以便從源中讀取圖像。一樣,也必須包括上述異步方法產生的結果。 這是最後一次修改App.js文件中的現有狀態。

state = {
  isTfReady: false,
  isModelReady: false,
  predictions: null,
  image: null
}
複製代碼

添加異步方法:

classifyImage = async () => {
  try {
    const imageAssetPath = Image.resolveAssetSource(this.state.image)
    const response = await fetch(imageAssetPath.uri, {}, { isBinary: true })
    const rawImageData = await response.arrayBuffer()
    const imageTensor = this.imageToTensor(rawImageData)
    const predictions = await this.model.classify(imageTensor)
    this.setState({ predictions })
    console.log(predictions)
  } catch (error) {
    console.log(error)
  }
}
複製代碼

預訓練模型的結果以數組形式產生。舉例以下:

容許用戶選擇圖像

從系統設備的相機中選擇圖像,須要使用expo-image-picker包提供的異步方法ImagePicker.launchImageLibraryAsync。導入包:

import * as Permissions from 'expo-permissions'
複製代碼

添加selectImage方法用於:

  • 讓用戶選擇圖片
  • 選擇圖像,在state.image中填充源URI對象
  • 最後,調用classifyImage()方法根據給定的輸入進行預測
selectImage = async () => {
  try {
    let response = await ImagePicker.launchImageLibraryAsync({
      mediaTypes: ImagePicker.MediaTypeOptions.All,
      allowsEditing: true,
      aspect: [4, 3]
    })

    if (!response.cancelled) {
      const source = { uri: response.uri }
      this.setState({ image: source })
      this.classifyImage()
    }
  } catch (error) {
    console.log(error)
  }
}
複製代碼

expo-image-picker返回一個對象。若是用戶取消了選擇圖像的過程,則圖像選擇器模塊將返回單個屬性:canceled:true。若是成功,則圖像選擇器模塊將返回屬性,例如圖像自己的uri。所以,上述片斷中的if語句具備重要的意義。

運行應用

要完成此程序,須要在用戶單擊添加圖像的位置添加不透明度。

這是App.js文件中render方法的完整代碼段:

render() {
    const { isTfReady, isModelReady, predictions, image } = this.state

    return (
      <View style={styles.container}>
        <StatusBar barStyle='light-content' />
        <View style={styles.loadingContainer}>
          <Text style={styles.commonTextStyles}>
            TFJS ready? {isTfReady ? <Text>✅</Text> : ''}
          </Text>

          <View style={styles.loadingModelContainer}>
            <Text style={styles.text}>Model ready? </Text>
            {isModelReady ? (
              <Text style={styles.text}>✅</Text>
            ) : (
              <ActivityIndicator size='small' />
            )}
          </View>
        </View>
        <TouchableOpacity
          style={styles.imageWrapper}
          onPress={isModelReady ? this.selectImage : undefined}>
          {image && <Image source={image} style={styles.imageContainer} />}

          {isModelReady && !image && (
            <Text style={styles.transparentText}>Tap to choose image</Text>
          )}
        </TouchableOpacity>
        <View style={styles.predictionWrapper}>
          {isModelReady && image && (
            <Text style={styles.text}>
              Predictions: {predictions ? '' : 'Predicting...'}
            </Text>
          )}
          {isModelReady &&
            predictions &&
            predictions.map(p => this.renderPrediction(p))}
        </View>
        <View style={styles.footer}>
          <Text style={styles.poweredBy}>Powered by:</Text>
          <Image source={require('./assets/tfjs.jpg')} style={styles.tfLogo} />
        </View>
      </View>
    )
  }
}
複製代碼

完整的styles對象:

const styles = StyleSheet.create({
  container: {
    flex: 1,
    backgroundColor: '#171f24',
    alignItems: 'center'
  },
  loadingContainer: {
    marginTop: 80,
    justifyContent: 'center'
  },
  text: {
    color: '#ffffff',
    fontSize: 16
  },
  loadingModelContainer: {
    flexDirection: 'row',
    marginTop: 10
  },
  imageWrapper: {
    width: 280,
    height: 280,
    padding: 10,
    borderColor: '#cf667f',
    borderWidth: 5,
    borderStyle: 'dashed',
    marginTop: 40,
    marginBottom: 10,
    position: 'relative',
    justifyContent: 'center',
    alignItems: 'center'
  },
  imageContainer: {
    width: 250,
    height: 250,
    position: 'absolute',
    top: 10,
    left: 10,
    bottom: 10,
    right: 10
  },
  predictionWrapper: {
    height: 100,
    width: '100%',
    flexDirection: 'column',
    alignItems: 'center'
  },
  transparentText: {
    color: '#ffffff',
    opacity: 0.7
  },
  footer: {
    marginTop: 40
  },
  poweredBy: {
    fontSize: 20,
    color: '#e69e34',
    marginBottom: 6
  },
  tfLogo: {
    width: 125,
    height: 70
  }
})
複製代碼

從終端窗口執行expo start命令來運行此程序。您會注意到的第一件事是,在Expo客戶端中引導應用程序後,它將要求權限。

而後,一旦模型準備就緒,框中便顯示文本「Tap to choose image」。選擇圖像以查看結果。

預測結果可能須要一些時間。這是先前選擇的圖像的結果。

結論

這篇文章的目的是讓您搶先了解如何在React Native應用中實現TesnorFlow.js模型,以及更好地理解圖像分類,這是基於計算機視覺的機器學習的核心用例。

因爲在撰寫本文時,用於React NativeTF.js處於alpha版本,所以咱們但願未來能看到更多更高級的示例來構建實時應用程序。 這裏有一些我以爲很是有用的資源。 tfjs-react-native GitHub存儲庫,其中包含更多使用不一樣預訓練模型的示例 Infinite RedNSFW JSReact Native示例清晰明瞭,很是有幫助 Tensorflow.js簡介

相關文章
相關標籤/搜索