在實時音視頻中,基於 TensorFlow 實現圖像識別(內附 Demo)

近兩年來,Python在衆多編程語言中的熱度一直穩居前五,熱門程度可見一斑。 Python 擁有很活躍的社區和豐富的第三方庫,Web 框架、爬蟲框架、數據分析框架、機器學習框架等,開發者無需重複造輪子,能夠用 Python 進行 Web 編程、網絡編程,開發多媒體應用,進行數據分析,或實現圖像識別等應用。其中圖像識別是最熱門的應用場景之一,也是與實時音視頻契合度最高的應用場景之一。python

Agora 現已支持 Python 語言。咱們也寫了一份 Python demo,並已分享至 Github。本文將分享TensorFlow 圖像識別的實現,以及在基於 Agora Python SDK的實時音視頻場景中如何集成圖像識別。git

先分享一下 Demo 的識別效果。github

Tensorflow圖片/物體識別

TensorFlow是Google的開源深度學習庫,你可使用這個框架以及Python編程語言,構建大量基於機器學習的應用程序。並且還有不少人把TensorFlow構建的應用程序或者其餘框架,開源發佈到GitHub上。因此咱們今天主要基於Tensorflow學習下物體識別。算法

TensorFlow提供了用於檢測圖片或視頻中所包含物體的API,詳情可參考如下連接:
github.com/tensorflow/…
編程

物體檢測是檢測圖片中所出現的所有物體而且用矩形(Anchor Box)進行標註,物體的類別能夠包括多種,例如人、車、動物、路標等。舉個例子瞭解TensorFlow物體檢測API的使用方法,這裏使用預訓練好的ssd_mobilenet_v1_coco模型(Single Shot MultiBox Detector),更多可用的物體檢測模型能夠參考這裏:github.com/tensorflow/…數組

加載庫網絡

# -*- coding:
utf-8 -*-
 
import numpy as
np
import
tensorflow as tf
import
matplotlib.pyplot as plt
from PIL import
Image
 
from utils
import label_map_util
from utils
import visualization_utils as vis_util
複製代碼

定義一些常量app

PATH_TO_CKPT = 'ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb'
複製代碼
PATH_TO_LABELS = 'ssd_mobilenet_v1_coco_2017_11_17/mscoco_label_map.pbtxt'
複製代碼
NUM_CLASSES = 90
複製代碼

加載預訓練好的模型框架

detection_graph = tf.Graph()
with detection_graph.as_default():
	od_graph_def = tf.GraphDef()
	with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
		od_graph_def.ParseFromString(fid.read())
		tf.import_graph_def(od_graph_def, name='')

複製代碼

加載分類標籤數據機器學習

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
複製代碼
categories = label_map_util.convert_label_map_to_categories(label_map,max_num_classes=NUM_CLASSES, use_display_name=True)
複製代碼
category_index = label_map_util.create_category_index(categories)
複製代碼

一個將圖片轉爲數組的輔助函數,以及測試圖片路徑

def load_image_into_numpy_array(image):
	(im_width, im_height) = image.size
	return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
	
TEST_IMAGE_PATHS = ['test_images/image1.jpg', 'test_images/image2.jpg']

複製代碼

使用模型進行物體檢測

with detection_graph.as_default():
	with tf.Session(graph=detection_graph) as sess:
	    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
	    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
	    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
	    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
	    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
	    for image_path in TEST_IMAGE_PATHS:
	    	image = Image.open(image_path)
	    	image_np = load_image_into_numpy_array(image)
	    	image_np_expanded = np.expand_dims(image_np, axis=0)
	    	(boxes, scores, classes, num) = sess.run(
	    		[detection_boxes, detection_scores, detection_classes, num_detections], 
	    		feed_dict={image_tensor: image_np_expanded})
	    	
	    	vis_util.visualize_boxes_and_labels_on_image_array(image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=8)
	    	plt.figure(figsize=[12, 8])
	    	plt.imshow(image_np)
	    	plt.show()

複製代碼

檢測結果以下,第一張圖片檢測出了兩隻狗狗

python1.png

實時音視頻場景下Tensorflow物體識別

既然Tensorflow在靜態圖片的物體識別已經相對成熟,那在現實場景中,大量的實時音視頻互動場景中,如何來作物體識別?咱們如今基於聲網實時視頻的SDK,闡述如何作物體識別。

首先咱們瞭解視頻其實就是由一幀一幀的圖像組合而成,因此從這個層面來講,視頻中的目標識別就是從每一幀圖像中作目標識別,從這個層面上講,兩者沒有本質區別。在理解這個前提的基礎上,咱們就能夠相對簡單地作實時音視頻場景下Tensorflow物體識別。

(1)讀取Agora實時音視頻,截取遠端視頻流的圖片

def onRenderVideoFrame(uid, width, height, yStride, uStride, vStride, yBuffer, uBuffer, vBuffer, rotation, renderTimeMs, avsync_type):
         # 用 isImageDetect 字段判斷前一幀圖像是否已完成識別,若完成置爲True,執行如下代碼,執行完置爲false
        if EventHandlerData.isImageDetect:
            y_array = (ctypes.c_uint8 * (width * height)).from_address(yBuffer)
            u_array = (ctypes.c_uint8 * ((width // 2) * (height // 2))).from_address(uBuffer)
            v_array = (ctypes.c_uint8 * ((width // 2) * (height // 2))).from_address(vBuffer)

            Y = np.frombuffer(y_array, dtype=np.uint8).reshape(height, width)
            U = np.frombuffer(u_array, dtype=np.uint8).reshape((height // 2, width // 2)).repeat(2, axis=0).repeat(2, axis=1)
            V = np.frombuffer(v_array, dtype=np.uint8).reshape((height // 2, width // 2)).repeat(2, axis=0).repeat(2, axis=1)
            YUV = np.dstack((Y, U, V))[:height, :width, :]
            # AI模型中大多數模型都是RGB格式訓練,聲網提供的視頻回調數據源是YUV格式,咱們作下格式轉換
            RGB = cv2.cvtColor(YUV, cv2.COLOR_YUV2RGB, 3)
            EventHandlerData.image = Image.fromarray(RGB)
            EventHandlerData.isImageDetect = False
複製代碼

(2)Tensorflow對截取圖片進行物體識別

class objectDetectThread(QThread):
    objectSignal = pyqtSignal(str)
    def __init__(self):
        super().__init__()
    def run(self):
        detection_graph = EventHandlerData.detection_graph
        with detection_graph.as_default():
            with tf.Session(graph=detection_graph) as sess:
                (im_width, im_height) = EventHandlerData.image.size
                image_np = np.array(EventHandlerData.image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
                image_np_expanded = np.expand_dims(image_np, axis=0)
                image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
                boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
                scores = detection_graph.get_tensor_by_name('detection_scores:0')
                classes = detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = detection_graph.get_tensor_by_name('num_detections:0')
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
                objectText = []
                # 若是識別機率大於百分之四十,咱們就在文本框內顯示所識別物體
                for i, c in enumerate(classes[0]):
                    if scores[0][i] > 0.4
                        object = EventHandlerData.category_index[int(c)]['name']
                        if object not in objectText:
                            objectText.append(object)
                    else:
                        break
                self.objectSignal.emit(', '.join(objectText))
                EventHandlerData.detectReady = True
                # 本幀圖片識別完,isImageDetect 字段置爲True,再次開始讀取並轉換Agora遠端實時音視頻
                EventHandlerData.isImageDetect = True
複製代碼

咱們已經將這個 Demo 以及 Agora Python SDK 上傳至 Github,你們能夠直接下載使用。

Agora Python TensorFlow Demo編譯指南:

  • 下載Agora Python SDK ,下載地址如上。
  • 如果 Windows,複製.pyd and .dll文件到本項目文件夾根目錄;如果IOS,複製.so文件到本文件夾根目錄
  • 下載 Tensorflow模型,而後把 object_detection 文件複製.到本文件夾根目錄
  • 安裝 Protobuf。而後運行: protoc object_detection/protos/*.proto --python_out=.
  • 從這裏下載預先訓練的模型(下載連接)
  • 推薦使用 ssd_mobilenet_v1_coco 和 ssdlite_mobilenet_v2_coco,由於他們相對運行較快
  • 提取 frozen graph,命令行運行:python extractGraph.py --model_file='FILE_NAME_OF_YOUR_MODEL'
  • 最後,在 callBack.py 中修改 model name,在 demo.py 中修改Appid,而後運行便可

請注意,這個 Demo,咱們僅做爲演示,從獲取到遠端實時視頻畫面,到TensorFlow 進行識別處理,再到顯示出識別效果,期間須要2至4 秒(視網絡狀況而定)。不一樣性能的機器、算法模型,其識別的效率也不一樣。感興趣的開發者能夠嘗試本身更換算法模型,來優化識別的延時。

若是 Demo 運行中遇到問題,請在 RTC 開發者社區反饋、交流,或在 Github 提 issue。

相關文章
相關標籤/搜索