機器學習進階-目標追蹤-SSD多進程執行 1.cv2.dnn.readnetFromCaffe(用於讀取已經訓練好的caffe模型) 2.delib.correlation_tracker(生成追蹤器

1. cv2.dnn.readNetFromCaffe(prototxt, model)  用於進行SSD網絡的caffe框架的加載算法

參數說明:prototxt表示caffe網絡的結構文本,model表示已經訓練好的參數結果網絡

2.t=delib.correlation_tracker() 使用delib生成單目標的追蹤器app

3.delib.rectangle(int(box[0]), int(box[1]), int(box[2]), int(box[3])) 用於生成追蹤器所須要的矩形框[(startX, startY), (endX, endY)]框架

4.t.start_track(rgb, rect) # 初始化生成器的開始狀態ide

5.cv2.Writer(name, fourcc, (frame.shape[1], frame.shape[0]), True)進行圖片寫入到視頻裏面函數

參數說明: name表示視頻的名字,fourcc表示視頻格式,frame.shape[1] 表示視頻的長和寬, spa

6.cv2.dnn.blobFromImage(frame, 0.007843, (w, h), 127.5)  對圖像進行歸一化操做(-1, 1),線程

參數說明:frame表示輸入圖片,0.007843表示須要乘的數,即1/127.5,(w, h)表示圖像大小,127.5表示須要減去的數code

7. net.SetInput(rgb) 表示將圖片輸入到caffe網絡中orm

參數說明: rgb表示已經通過歸一化的圖片

8. net.forward() 輸出前向傳播的預測結果

9. oq = multiprocessing.Queue() 生成用於多進行傳輸過程當中的線程

10.p = multiprocessing.Process(target=start_track, args=(bb, label, rgb, iq, oq))  # 用於對函數建立進程

參數說明:target表示須要轉換爲進程的函數,args表示傳入到進程裏函數的參數

 

 

 

 

SSD是一種目標檢測的算法,其使用多個卷積層進行預測,原理在後續的博客中進行補充

對於目標追蹤的視頻,咱們先使用SSD找出圖片中人物的位置,而後使用dlib中的跟蹤器對物體進行跟蹤

因爲每個人物框對應一個跟蹤器,所以咱們能夠對每個跟蹤器起一個進程,使用輸入和輸出線程,用於構造多進程

使用的數據,須要一個訓練好的SSD權重參數,還須要caffe關於SSD的prototxt文件
代碼說明:

下面的代碼能夠近似認爲是由兩部分構成

 第一部分:使用SSD網絡進行預測,得到box的位置

第二部分:使用dlib構造tracker跟蹤器,帶入box構造帶有矩形框的追蹤器,而後使用dlib的追蹤器對圖像每一幀的位置進行追蹤

代碼:

第一步:構造進程函數,使用iq.get 和oq.put進行追蹤器的位置更新

第二步:構造輸入的參數, 使用cv2.dnn.readNetFromCaffe()構造SSD網絡模型

第三步:使用cv2.Videocapture視頻讀入,fps=FPS().start() 用於計算FPS

第四步:進入循環,使用.read()讀取圖片

第五步:使用cv2.resize()對圖片大小進行放縮變化,使用cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #將讀入的BGR轉換爲RGB,用於模型的預測

第六步:若是須要進行輸出,使用cv2.VideoWriter實例化視頻存儲器

第七步:若是尚未使用SSD得到矩形框,使用cv2.dnn.blobFromImage對圖像進行歸一化操做

第八步:使用net.setInput將圖片傳入,使用net.forward得到前向傳播輸出的結果

第九步:若是置信度大於給定的置信度,得到SSD的標籤,以及前向傳播的位置信息

第十步:使用multiprocessing.Queue構造線程iq和oq,將線程添加到列表中,使用multiprocessing.process構造多進程,用於分別創建單個跟蹤器

第十一步:若是已經生成了通道,使用iq.put(rgb)傳入圖像,使用oq.get()得到追蹤器更新的位置

第十二步:進行畫圖操做,若是存在writer就進行寫入

第十三步:更新fps.update

第十四步:統計運行的時間和FPS,並對vs進行釋放內存

import cv2
import numpy as np
import argparse
import dlib
import multiprocessing
from utils import FPS

# 第一步:構造追蹤器並進行結果的更新
def start_tracker(box, label, rgb, inputQueue, outputQueue):

    # 構造追蹤器
    t = dlib.correlation_tracker()
    # rect爲SSD得到的矩形框的位置
    rect = dlib.rectangle(int(box[0]), int(box[1]), int(box[2]), int(box[3]))
    # 設置追蹤器的初始位置
    t.start_track(rgb, rect)
    # 得到下一幀圖片
    while True:
        # 傳入的圖片
        rgb = inputQueue.get()
        if rgb is not None:
            # 更新追蹤器
            t.update(rgb)
            # 得到追蹤器的當前位置
            pos = t.get_position()

            startX = int(pos.left())
            startY = int(pos.top())
            endX = int(pos.right())
            endY = int(pos.bottom())

            # 把結果輸出放入到output裏面, 返回標籤和位置
            outputQueue.put((label, (startX, startY, endX, endY)))

# 第二步:設置參數,並使用cv2.dnn.readFrameCaffe構造SSD的網絡模型
ap = argparse.ArgumentParser()
ap.add_argument('-p', '--prototxt', default='mobilenet_ssd/MobileNetSSD_deploy.prototxt',
                help='path to caffe "deploy" prototxt file')
ap.add_argument('-m', '--model', default='mobilenet_ssd/MobileNetSSD_deploy.caffemodel',
                help='path to Caffe pre-trained model')
ap.add_argument('-v', '--video', default='race.mp4',
                help='path to input video file')
ap.add_argument('-o', '--output', type=str,
                help='path to optional output video file')
ap.add_argument('-c', '--confidence', type=float, default=0.2,
                help='minimu probability to filter weak detections')

args = vars(ap.parse_args())

# 用於存放輸入線程和輸出線程
inputQueues = []
outputQueues = []
# 21種分類的結果
CLASSES = ["background", "aeroplane", "bicycle", "bird", "boat",
    "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
    "dog", "horse", "motorbike", "person", "pottedplant", "sheep",
    "sofa", "train", "tvmonitor"]

print('[INFO] loading model...')
# 構造SSD網絡模型
net =cv2.dnn.readNetFromCaffe(args['prototxt'], args['model'])

print('[INFO] starting video stream...')
# 第三步:使用cv2.VideoCapture讀取視頻
vs = cv2.VideoCapture(args['video'])
writer = None

fps = FPS().start()

if __name__ == '__main__':
   # 第四步:進入循環,使用.read() 讀取圖片
    while True:
        ret, frame = vs.read()

        if frame is None:
            break

        # 第五步:進行圖像的維度變化, 而且將BGR轉換爲RGB格式
        h, w = frame.shape[:2]
        width = 600
        r = width / float(w)
        dim = (width, int(r*h))
        frame = cv2.resize(frame, dim, cv2.INTER_AREA)
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # 第六步:進行視頻的保存
        if args['output'] is not None and writer is None:
            fourcc = cv2.VideoWriter_fourcc(*'MJPG')
            writer = cv2.VideoWriter(args['output'], fourcc,
                                     (frame.shape[1], frame.shape[0]), True)

        # 若是輸入進程的維度爲0,進入循環首先檢測位置
        if len(inputQueues) == 0:
            # 第七步:使用cv2.dnn.blobFromImage()對圖片進行歸一化操做
            (h, w) = frame.shape[:2]
            blob = cv2.dnn.blobFromImage(frame, 0.007843, (w, h), 127.5)
            # 第八步:使用net.setInput輸入圖片,net.forward()得到前向傳播的結果
            net.setInput(blob)
            detections = net.forward()
            # 第九步:對結果進行循環,若是置信度大於閾值,則得到其標籤和box位置信息
            for i in np.arange(0, detections.shape[2]):
                confidence = detections[0, 0, i, 2]
                if confidence > args['confidence']:
                    idx = int(detections[0, 0, i, 1])
                    label = CLASSES[idx]

                    if label != 'person':
                        continue

                    box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
                    (startX, startY, endX, endY) = box.astype('int')
                    bb = (startX, startY, endX, endY)

                    # 第十步:建立輸入q和輸出q,建立process線程,使用process.start()啓動線程
                    iq = multiprocessing.Queue()
                    oq = multiprocessing.Queue()
                    inputQueues.append(iq)
                    outputQueues.append(oq)

                    # 在多個核上運行, 建立多核
                    p = multiprocessing.Process(
                        target=start_tracker,
                        args=(bb, label, rgb, iq, oq)
                    )
                    p.daemon = True
                    p.start()

                    cv2.rectangle(frame, (startX, startY), (endX, endY),
                                  (0, 255, 0), 2)
                    cv2.putText(frame, label, (startX, startY-15),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 2)


        else:
            # 第十一步:若是生成了進程,循環輸入線程,傳入rgb圖片,得到輸出線程的label和更新位置的輸出
            for iq in inputQueues:
                iq.put(rgb)

            for oq in outputQueues:
                (label, (startX, startY, endX, endY)) = oq.get()
                # 在frame圖像上繪製矩形框和text
                cv2.rectangle(frame, (startX, startY), (endX, endY),
                              (0, 255, 0), 2)
                cv2.putText(frame, label, (startX, startY - 15),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 2)
        # 第十二步:進行繪圖,並進行視頻的寫入操做
        if writer is not None:
            writer.write(frame)

        cv2.imshow('Frame', frame)
        key = cv2.waitKey(1) & 0xFF
        if key == 27:
            break
        # 第十三步 fps更新
        fps.update()
    # 第十四步:統計運行時間和FPS,釋放內存
    fps.stop()
    print('[INFO] elapsed time {:.2f}'.format(fps.elapsed()))
    print('[INFO] approx. FPS:{:.2f}'.format(fps.fps()))

    if writer is not None:
        writer.release()

    cv2.destroyAllWindows()
    vs.release()

效果展現:

FPS副代碼

import cv2
import numpy as np
import datetime


class FPS:

    def __init__(self):

        self._start = None
        self._end = None
        self._numFrames = 0

    def start(self):
        # start the timer
        self._start = datetime.datetime.now()
        return self

    def stop(self):
        # stop the timer
        self._end = datetime.datetime.now()

    def update(self):
        self._numFrames += 1

    def elapsed(self):
        return (self._end - self._start).total_seconds()

    def fps(self):
        return self._numFrames / self.elapsed()
相關文章
相關標籤/搜索