利用 ImageAI 在 COCO 上學習目標檢測

ImageAI是一個python庫,旨在使開發人員可以使用簡單的幾行代碼構建具備包含深度學習和計算機視覺功能的應用程序和系統。 這個 AI Commons 項目https://commons.specpal.science 由 Moses Olafenwa 和 John Olafenwa 開發和維護。爲了更好的使用 ImageAI,我將其 Fork 到 CodeXZone/ImageAI。同時,ImageAI 也提供了中文手冊:imageai。下面我將藉助該教程一步一步的學習目標檢測。html

利用 cocoz 載入 COCO 數據集

首先,利用 cocoz 載入 COCOZ:python

import sys
# 將 cocoapi 添加進入環境變量
sys.path.append(r'D:\API\cocoapi\PythonAPI')
from pycocotools.cocoz import AnnZ, ImageZ, COCOZ
# ------------------

import numpy as np
from matplotlib import pyplot as plt
from IPython import display


def use_svg_display():
    # 用矢量圖顯示, 效果更好
    display.set_matplotlib_formats('svg')


def show_imgs(imgs, k=4):
    '''
    展現 多張圖片
    '''
    n = len(imgs)
    h, w = k, n // k
    assert n == h * w, "圖片數量不匹配"
    use_svg_display()
    _, ax = plt.subplots(h, w, figsize=(5, 5))  # 設置圖的尺寸
    K = np.arange(n).reshape((h, w))
    for i in range(h):
        for j in range(w):
            img = imgs[K[i, j]]
            ax[i][j].imshow(img)
            ax[i][j].axes.get_yaxis().set_visible(False)
            ax[i][j].set_xticks([])
    plt.show()
dataDir = r'E:\Data\coco\images'   # COCO 數據根目錄
dataType = 'train2017'
imgZ = ImageZ(dataDir, dataType)

show_imgs(imgZ[300:316])

物體檢測,提取和微調

import sys
sys.path.append('D:/API/ImageAI')

from imageai.Detection import ObjectDetection
import os

execution_path = os.getcwd()

detector = ObjectDetection()  # 建立目標檢測實例
detector.setModelTypeAsRetinaNet()
detector.setModelPath(
    os.path.join(execution_path, "resnet50_coco_best_v2.0.1.h5"))
detector.loadModel()  # 載入預訓練模型

因爲 detector.detectObjectsFromImage 比較容易支持解壓後的圖片,因此咱們能夠提取出一張圖片來作測試:git

input_image = imgZ.Z.extract(imgZ.names[0]) # 輸入文件的路徑
output_image = os.path.join(execution_path, "image2new.jpg")  # 輸出文件的路徑

detections = detector.detectObjectsFromImage(
    input_image=input_image, output_image_path=output_image)

for eachObject in detections:
    print(eachObject["name"] + " : ", eachObject["percentage_probability"])
    print("--------------------------------")
motorcycle :  99.99607801437378
--------------------------------

detectObjectsFromImage() 函數返回一個字典列表,每一個字典包含圖像中檢測到的對象信息,字典中的對象信息有 name(對象類名)和 percentage_probability(機率)以及 box_points(圖片的左上角與右下角的座標)。github

detections
[{'name': 'motorcycle',
  'percentage_probability': 99.99607801437378,
  'box_points': array([ 34,  92, 546, 427])}]

下面咱們看看其標註框:api

img = plt.imread(output_image)
plt.imshow(img)
plt.show()

爲了直接使用壓縮文件,咱們能夠修改 detectObjectsFromImage 的默認參數 input_type='file'input_type='array'數組

input_image = imgZ[202]  # 輸入文件的路徑
output_image = os.path.join(execution_path, "image2.jpg")  # 輸出文件的路徑

detections = detector.detectObjectsFromImage(
    input_image=input_image, output_image_path=output_image, input_type='array')

for eachObject in detections:
    print(eachObject["name"] + " : ", eachObject["percentage_probability"])
    print("--------------------------------")

img = plt.imread(output_image)
plt.imshow(img)
plt.show()
tennis racket :  54.25310730934143
--------------------------------
person :  99.85058307647705
--------------------------------

detections, objects_path = detector.detectObjectsFromImage(
    input_image=imgZ[900], input_type = 'array',
    output_image_path=os.path.join(execution_path, "image3new.jpg"),
    extract_detected_objects=True)

for eachObject, eachObjectPath in zip(detections, objects_path):
    print(eachObject["name"] + " : ", eachObject["percentage_probability"])
    print("Object's image saved in ", eachObjectPath)
    print("--------------------------------")
person :  56.35678172111511
Object's image saved in  D:\API\CVX\draft\image3new.jpg-objects\person-1.jpg
--------------------------------
person :  75.83483457565308
Object's image saved in  D:\API\CVX\draft\image3new.jpg-objects\person-2.jpg
--------------------------------
person :  60.49004793167114
Object's image saved in  D:\API\CVX\draft\image3new.jpg-objects\person-3.jpg
--------------------------------
person :  85.2730393409729
Object's image saved in  D:\API\CVX\draft\image3new.jpg-objects\person-4.jpg
--------------------------------
person :  83.12703967094421
Object's image saved in  D:\API\CVX\draft\image3new.jpg-objects\person-5.jpg
--------------------------------
bus :  99.7751772403717
Object's image saved in  D:\API\CVX\draft\image3new.jpg-objects\bus-6.jpg
--------------------------------

extract_detected_objects=True 將會把檢測到的對象提取並保存爲單獨的圖像;這將使函數返回 2 個值,第一個是字典數組,每一個字典對應一個檢測到的對象信息,第二個是全部提取出對象的圖像保存路徑,而且它們按照對象在第一個數組中的順序排列。咱們先看看原圖:app

plt.imshow(imgZ[900])
plt.show()

顯示識別出來的對象:svg

show_imgs([plt.imread(fname) for fname in objects_path], 2)

還有一個十分重要的參數 minimum_percentage_probability 用於設定預測機率的閾值,其默認值爲 50(範圍在 \(0-100\)之間)。若是保持默認值,這意味着只有當百分比機率大於等於 50 時,該函數纔會返回檢測到的對象。使用默認值能夠確保檢測結果的完整性,可是在檢測過程當中可能會跳過許多對象。下面咱們看看修改後的效果:函數

detections = detector.detectObjectsFromImage(
    input_image=imgZ[900],
    input_type='array',
    output_image_path=os.path.join(execution_path, "image3new.jpg"),
    minimum_percentage_probability=70)

for eachObject in detections:
    print(eachObject["name"] + " : ", eachObject["percentage_probability"])
    print("--------------------------------")
person :  75.83483457565308
--------------------------------
person :  85.2730393409729
--------------------------------
person :  83.12703967094421
--------------------------------
bus :  99.7751772403717
--------------------------------

咱們將 minimum_percentage_probability 設置爲 70,此時僅僅只能檢測到 4 個。學習

相關文章
相關標籤/搜索