【目標檢測算法實現系列】Keras實現Faster R-CNN算法(一)

以前,咱們介紹了Fatser R-CNN模型,在接下來的幾篇文章,將經過Keras框架來完整實現Fatser R-CNN模型。數據集咱們採用經典的VOC數據集。python

這篇文章咱們主要看下相關數據的準備工做,具體流程以下:ios

1、VOC數據集解析

VOC數據集的下載,,由於官網下載太慢,文章末尾處有提供百度網盤下載bash

下載解壓後的文件目錄以下:網絡

對於目標檢測任務,只須要用到Annotations,ImageSets,JPEGImages這三個目錄。數據結構

1. Annotations:存放相關標註信息,每一張圖片對應一個xml文件,具體xml內容以下:app

<annotation>
  <folder>VOC2012</folder>
  <filename>2007_000033.jpg</filename>
  <source>
    <database>The VOC2007 Database</database>
    <annotation>PASCAL VOC2007</annotation>
    <image>flickr</image>
  </source>
  <size>
    <width>500</width>
    <height>366</height>
    <depth>3</depth>
  </size>
  <segmented>1</segmented>
  <object>
    <name>aeroplane</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>9</xmin>
      <ymin>107</ymin>
      <xmax>499</xmax>
      <ymax>263</ymax>
    </bndbox>
  </object>
  <object>
    <name>aeroplane</name>
    <pose>Left</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>421</xmin>
      <ymin>200</ymin>
      <xmax>482</xmax>
      <ymax>226</ymax>
    </bndbox>
  </object>
</annotation>複製代碼

2. ImageSets:咱們只會用到ImageSets\Main下train.txt , val.txt, test.txt這三個文件,裏面存儲對應訓練集,驗證集,測試集的圖片名稱,文件格式以下:框架


3. JPEGImages:存儲全部的圖片數據dom


咱們須要將下載來的VOC數據集解析成以下格式
機器學習

具體代碼實現以下ide

import os
import xml.etree.ElementTree as ET
from tqdm import tqdm
import pprint


def get_data(input_path):
    ''' :param input_path: voc數據目錄 :return: image_data:解析後的數據集 list列表 classes_count:一個字典數據結構,key爲對應類別名稱,value對應爲類別所對應的樣本(標註框)個數 classes_mapping:一個字典數據結構,key爲對應類別名稱,value爲對應類別的一個標識index '''
    image_data = []
    classes_count = {}  #一個字典,key爲對應類別名稱,value對應爲類別所對應的樣本(標註框)個數
    classes_mapping = {} #一個字典數據結構,key爲對應類別名稱,value爲對應類別的一個標識index

    data_paths = os.path.join(input_path, "VOC2012")
    print(data_paths)

    annota_path = os.path.join(data_paths, "Annotations")  # 數據標註目錄
    imgs_path = os.path.join(data_paths, "JPEGImages")  # 圖片目錄

    imgsets_path_train = os.path.join(data_paths, 'ImageSets', 'Main', 'train.txt')
    imgsets_path_val = os.path.join(data_paths, 'ImageSets', 'Main', 'val.txt')
    imgsets_path_test = os.path.join(data_paths, 'ImageSets', 'Main', 'test.txt')
    train_files = []  # 訓練集圖片名稱集合
    val_files = []  # 驗證集圖片名稱集合
    test_files = []  # 測試集圖片名稱集合

    with open(imgsets_path_train) as f:
        for line in f:
            # strip() 默認去掉字符串頭尾的空格和換行符
            train_files.append(line.strip() + '.jpg')

    with open(imgsets_path_val) as f:
        for line in f:
            val_files.append(line.strip() + '.jpg')

    # test-set not included in pascal VOC 2012
    if os.path.isfile(imgsets_path_test):
        with open(imgsets_path_test) as f:
            for line in f:
                test_files.append(line.strip() + '.jpg')

    # 得到全部的標註文件路徑,保存到annota_path_list列表中
    annota_path_list = [os.path.join(annota_path, s) for s in os.listdir(annota_path)]
    index = 0

    # Tqdm 是一個快速,可擴展的Python進度條,
    # 能夠在 Python 長循環中添加一個進度提示信息,用戶只須要封裝任意的迭代器 tqdm(iterator)
    annota_path_list = tqdm(annota_path_list)

    for annota_path in annota_path_list:
        exist_flag = False
        index += 1
        annota_path_list.set_description("Processing %s" % annota_path.split(os.sep)[-1])

        # 開始解析對應xml數據標註文件
        et = ET.parse(annota_path)
        element = et.getroot()
        element_objs = element.findall("object")  # 獲取全部的object子元素
        element_filename = element.find("filename").text  # 對應圖片名稱
        element_width = int(element.find("size").find("width").text)  # 對應圖片尺寸
        element_height = int(element.find("size").find("height").text)  # 對應圖片尺寸

        if (len(element_objs) > 0):
            annotation_data = {"filepath": os.path.join(imgs_path, element_filename),
                               "width": element_width,
                               "height": element_height,
                               "image_id": index,
                               "bboxes": []}  # bboxes 用來存放對應標註框的相關位置
        if element_filename in train_files:
            annotation_data["imageset"] = "train"
            exist_flag = True
        if element_filename in val_files:
            annotation_data["imageset"] = "val"
            exist_flag = True
        if len(test_files) > 0:
            if element_filename in test_files:
                annotation_data["imageset"] = "test"
                exist_flag = True

        if not exist_flag:
            continue

        for element_obj in element_objs:  # 遍歷一個xml標註文件中的全部標註框
            classes_name = element_obj.find("name").text  # 獲取當前標註框的類別名稱
            if classes_name in classes_count:  # classes_count 存儲類別以及對應類別的標註框個數
                classes_count[classes_name] += 1
            else:
                classes_count[classes_name] = 1

            if classes_name not in classes_mapping:
                classes_mapping[classes_name] = len(classes_mapping)

            obj_bbox = element_obj.find("bndbox")
            x1 = int(round(float(obj_bbox.find("xmin").text)))
            y1 = int(round(float(obj_bbox.find("ymin").text)))
            x2 = int(round(float(obj_bbox.find("xmax").text)))
            y2 = int(round(float(obj_bbox.find("ymax").text)))

            difficulty = int(element_obj.find("difficult").text) == 1
            annotation_data["bboxes"].append({"class": classes_name,
                                              "x1": x1, "x2": x2, "y1": y1, "y2": y2,
                                              "difficult": difficulty})
        image_data.append(annotation_data)

    return image_data, classes_count, classes_mapping複製代碼

複製代碼

2、數據加強

隨機將數據進行翻轉,旋轉,代碼以下:

def augment(img_data, config, augment = True):
    ''' 用來進行數據加強 :param img_data: 原始數據 :param config: 相關配置參數 :param augment: :return: 加強後的數據集 '''
    assert 'filepath' in img_data
    assert 'bboxes' in img_data
    assert 'width' in img_data
    assert 'height' in img_data

    img_data_aug = copy.deepcopy(img_data)

    img = cv2.imread(img_data_aug["filepath"])  #讀取原始圖片
    if augment:
        rows, cols = img.shape[:2]  #獲取圖像尺寸

        if config.use_horizontal_flips and np.random.randint(0,2) == 0:
            img = cv2.flip(img, 1)  #水平翻轉

            for bbox in img_data_aug["bboxes"]:  #從新更新每一個標註框橫座標的值
                x1 = bbox["x1"]
                x2 = bbox["x2"]
                bbox["x1"] = cols - x2
                bbox["x2"] = cols - x1

        if config.use_vertical_flips and np.random.randint(0,2) == 0:
            img = cv2.flip(img, 0)  #豎直翻轉

            for bbox in img_data_aug["bboxes"]:  #從新更新每一個標註框橫座標的值
                y1 = bbox["y1"]
                y2 = bbox["y2"]
                bbox["y1"] = rows - y2
                bbox["y2"] = rows - y1


        if config.rot_90:
            angle = np.random.choice([0,90,180,270],1)[0]
            print("angle==",angle)
            if angle == 270:   #旋轉270度
                img = np.transpose(img, (1,0,2))
                img = cv2.flip(img, 0)
            elif angle == 180:  #旋轉180度
                img = cv2.flip(img, -1)
            elif angle == 90:   #旋轉90度
                img = np.transpose(img, (1,0,2))
                img = cv2.flip(img, 1)
            elif angle == 0:
                pass

            # 從新更新每一個標註框橫座標的值
            for bbox in img_data_aug['bboxes']:
                x1 = bbox['x1']
                x2 = bbox['x2']
                y1 = bbox['y1']
                y2 = bbox['y2']
                if angle == 270:
                    bbox['x1'] = y1
                    bbox['x2'] = y2
                    bbox['y1'] = cols - x2
                    bbox['y2'] = cols - x1
                elif angle == 180:
                    bbox['x2'] = cols - x1
                    bbox['x1'] = cols - x2
                    bbox['y2'] = rows - y1
                    bbox['y1'] = rows - y2
                elif angle == 90:
                    bbox['x1'] = rows - y2
                    bbox['x2'] = rows - y1
                    bbox['y1'] = x1
                    bbox['y2'] = x2
                elif angle == 0:
                    pass
    img_data_aug['width'] = img.shape[1]
    img_data_aug['height'] = img.shape[0]
    return img_data_aug, img複製代碼

對於某個樣本,原始圖片以及標註框以下:

通過加強後的圖片以及標註框以下所示:



3、 爲RPN網絡準備訓練數據

咱們還須要將數據格式轉化爲RPN網絡能夠直接訓練的格式,生成全部的anchors。

首先,實現一個計算IOU的方法,代碼以下:

#計算兩個框以前的並集
def union(au, bu, area_intersection):
    # au和bu的格式爲: (x1,y1,x2,y2)
    # area_intersection爲 au 和 bu 兩個框的交集
    area_a = (au[2] - au[0]) * (au[3] - au[1])
    area_b = (bu[2] - bu[0]) * (bu[3] - bu[1])
    area_union = area_a + area_b - area_intersection
    return area_union
​
#計算兩個框以前的交集
def intersection(ai, bi):
    # ai和bi的格式爲: (x1,y1,x2,y2)
    x = max(ai[0], bi[0])
    y = max(ai[1], bi[1])
    w = min(ai[2], bi[2]) - x
    h = min(ai[3], bi[3]) - y
    if w < 0 or h < 0:
        return 0
    return w*h
​
​
# 計算兩個框的iou值
def iou(a, b):
    # a和b的格式爲: (x1,y1,x2,y2)if a[0] >= a[2] or a[1] >= a[3] or b[0] >= b[2] or b[1] >= b[3]:
        return 0.0
​
    area_i = intersection(a, b)  #計算交集
    area_u = union(a, b, area_i)  #計算並集return float(area_i) / float(area_u + 1e-6)   #交併比複製代碼

接下來,咱們看下若是針對每一張圖片構造全部anchor,併爲RPN網絡準備訓練數據。

經過以前的Faster RCNN介紹,咱們知道RPN網絡有兩個輸出,一個是檢測框分類層輸出,輸出的通道個數爲2*k(若是使用sigmoid分類的話,那就是k個),另外一個爲檢測框迴歸層輸出,輸出的通道個數爲4*k。因此,咱們也須要先對咱們的數據標籤作對應處理。具體代碼以下:

def getdata_for_rpn(config, img_data, width, heigth, resized_width, resized_height):
    ''' 用於提取RPN網絡訓練集,也就是產生各類anchors以及anchors對應與ground truth的修正參數 :param C: 配置信息 :param img_data: 原始數據 :param width: 縮放前圖片的寬 :param heigth: 縮放前圖片的高 :param resized_width: 縮放後圖片的寬 :param resized_height: 縮放後圖片的高 :param img_length_calc_function: 獲取通過base Net後提取出來的featur map圖像尺寸, 對於VGG16來講,就是在原始圖像尺寸上除以16 :return: '''
    downscale = float(config.rpn_stride)   #原始圖像到feature map之間的縮放映射關係
    anchor_sizes = config.anchor_box_scales   #anchor 三種尺寸
    anchor_ratios = config.anchor_box_ratios  # anchor 三種寬高比
    num_anchors = len(anchor_sizes) * len(anchor_ratios)  # 每個滑動窗口所對應的anchor個數,也就是論文中的k值

    #計算出通過base Net後提取出來的featurmap圖像尺寸
    output_width = int(resized_width / 16)
    output_height = int(resized_height / 16)

    # (36,36,9),用來存放RPN網絡,訓練樣本最後分類層輸出時的y值,
    # 最後一維9表明對於每一個像素點對應9個anchor,值爲0或1(正樣本或負樣本)
    y_rpn_overlap = np.zeros((output_height, output_width, num_anchors))

    #(36,36,9),用來存放對於每一個anchor,是不是有效的樣本,值爲0或者1(無效樣本,有效樣本)
    # 由於對於iou在0.3到0.7之間的樣本,是直接丟棄 ,不參與訓練的
    # 另外,只是在一張圖片中隨機採樣256個anchor,其餘的也不參與訓練
    y_is_box_valid = np.zeros((output_height, output_width, num_anchors))

    # (36,36,9*4),用來存放RPN網絡,針對一張圖片,最後迴歸層的標籤Y值
    y_rpn_regr = np.zeros((output_height, output_width, num_anchors * 4))

    #獲取一張訓練圖片的真實標註框個數,也就是含有的待檢測的目標個數
    num_bboxes = len(img_data['bboxes'])

    # 用來存儲每一個bbox(真實標註框)所對應的anchor個數
    num_anchors_for_bbox = np.zeros(num_bboxes).astype(int)

    # 用來存儲每一個bbox(真實標註框)所對應的最優anchor在feature map中的位置信息,以及大小信息
    # [jy, ix, anchor_ratio_idx, anchor_size_idx]
    best_anchor_for_bbox = -1 * np.ones((num_bboxes, 4)).astype(int)

    # 每一個bbox(真實標註框)與全部anchor 的最優IOU值
    best_iou_for_bbox = np.zeros(num_bboxes).astype(np.float32)

    # 用來存儲每一個bbox(真實標註框)所對應的最優anchor的座標值
    best_x_for_bbox = np.zeros((num_bboxes, 4)).astype(int)

    # 用來存儲每一個bbox(真實標註框)與所對應的最優anchor之間的4個平移縮放參數,用於迴歸預測
    best_dx_for_bbox = np.zeros((num_bboxes, 4)).astype(np.float32)

    gta = np.zeros((num_bboxes, 4))  # 用來存放通過縮放後的標註框
    # 由於以前圖片進行了縮放,因此須要將對應的標註框作對應調整
    for bbox_num, bbox in enumerate(img_data["bboxes"]):
        gta[bbox_num, 0] = bbox["x1"] * (resized_width / float(width))
        gta[bbox_num, 1] = bbox["x2"] * (resized_width / float(width))
        gta[bbox_num, 2] = bbox["y1"] * (resized_height / float(heigth))
        gta[bbox_num, 3] = bbox["y2"] * (resized_height / float(heigth))


    #遍歷feature map上的每個像素點
    for ix in range(output_width):
        for iy in range(output_height):
            #在feature map的每個像素點上,遍歷對應不一樣大小,不一樣長寬比的k(9)個anchor
            for anchor_size_index in range(len(anchor_sizes)):
                for anchor_ratio_index in range(len(anchor_ratios)):
                    anchor_x = anchor_sizes[anchor_size_index] * anchor_ratios[anchor_ratio_index][0]
                    anchor_y = anchor_sizes[anchor_size_index] * anchor_ratios[anchor_ratio_index][1]

                    # 得到當前anchor在原圖上的X座標位置
                    # downscale * (ix + 0.5)即爲當前anchor在原始圖片上的中心點X座標
                    # downscale * (ix + 0.5) - anchor_x / 2 即爲當前anchor左上點X座標
                    # downscale * (ix + 0.5) + anchor_x / 2 即爲當前anchor右下點X座標
                    x1_anc = downscale * (ix + 0.5) - anchor_x / 2
                    x2_anc = downscale * (ix + 0.5) + anchor_x / 2
                    # 去掉那些跨過圖像邊界的框
                    if x1_anc<0 or x2_anc > resized_width:
                        continue

                    # 得到當前anchor在原圖上的Y座標位置
                    # downscale * (jy + 0.5)即爲當前anchor在原始圖片上的中心點Y座標
                    # downscale * (jy + 0.5) - anchor_y / 2 即爲當前anchor左上點Y座標
                    # downscale * (jy + 0.5) + anchor_y / 2 即爲當前anchor右下點Y座標
                    y1_anc = downscale * (iy + 0.5) - anchor_y / 2
                    y2_anc = downscale * (iy + 0.5) + anchor_y / 2
                    # 去掉那些跨過圖像邊界的框
                    if y1_anc < 0 or y2_anc > resized_height:
                        continue

                    # 用來存放當前anchor類別是前景(正樣本)仍是背景(負樣本)
                    bbox_type = "neg"
                    # best_iou_for_loc 是用來存儲當前anchor針對於全部真實標註框的一個最優iou
                    # 須要與前面的best_iou_for_bbox 每一個真實標註框 針對於全部 anchor 的最優iou是不同的。
                    best_iou_for_loc = 0.0

                    #遍歷全部真實標註框,也就是全部ground truth
                    for bbox_num in range(num_bboxes):
                        # 計算當前anchor與當前真實標註框的iou值
                        curr_iou = iou([gta[bbox_num, 0], gta[bbox_num, 2], gta[bbox_num, 1], gta[bbox_num, 3]],
                                       [x1_anc, y1_anc, x2_anc, y2_anc])

                        #根據iou值,判斷當前anchor是否爲正樣本。
                        # 若是是,則計算此anchor(正樣本)到ground - truth(真實檢測框)的對應4個平移縮放參數。
                        # 判斷一個anchor是否爲正樣本的兩個條件爲:
                        # 1.與ground - truth(真實檢測框)IOU最高的anchor
                        # 2.與任意ground - truth(真實檢測框)的IOU大於0.7 的anchor
                        if curr_iou > best_iou_for_bbox[bbox_num] or curr_iou > config.rpn_max_overlap:
                            # cx,cy: ground-truth(真實檢測框)的中心點座標
                            cx = (gta[bbox_num, 0] + gta[bbox_num, 1]) / 2.0
                            cy = (gta[bbox_num, 2] + gta[bbox_num, 3]) / 2.0
                            # cxa,cya: 當前anchor的中心點座標
                            cxa = (x1_anc + x2_anc) / 2.0
                            cya = (y1_anc + y2_anc) / 2.0

                            # (tx, ty, tw, th)即爲此anchor(正樣本)到ground-truth(真實檢測框)的對應4個平移縮放參數
                            tx = (cx - cxa) / (x2_anc - x1_anc)
                            ty = (cy - cya) / (y2_anc - y1_anc)
                            tw = np.log((gta[bbox_num, 1] - gta[bbox_num, 0]) / (x2_anc - x1_anc))
                            th = np.log((gta[bbox_num, 3] - gta[bbox_num, 2]) / (y2_anc - y1_anc))

                        if img_data["bboxes"][bbox_num]["class"] != "bg":
                            #針對於當前ground - truth(真實檢測框),若是當前anchor與之的iou最大,則從新更新相關存儲的best值
                            if curr_iou > best_iou_for_bbox[bbox_num]:
                                best_anchor_for_bbox[bbox_num] = [iy, ix, anchor_ratio_index, anchor_size_index]
                                best_iou_for_bbox[bbox_num] = curr_iou
                                best_x_for_bbox[bbox_num,:] = [x1_anc, x2_anc, y1_anc, y2_anc]
                                best_dx_for_bbox[bbox_num,:] = [tx, ty, tw, th]

                            #對於iou大於0.7的,則,不管是不是最優的,直接認爲是正樣本
                            if curr_iou > config.rpn_max_overlap:
                                bbox_type = "pos"
                                num_anchors_for_bbox[bbox_num] +=1
                                if curr_iou > best_iou_for_loc:
                                    best_iou_for_loc = curr_iou
                                    best_regr = (tx, ty, tw, th)#當前anchor與和它有最優iou的那個ground-truth(真實檢測框)之間的對應4個平移參數
                            # iou值大於0.3,小於0.7的的,即不是正樣本,也不是負樣本
                            if config.rpn_min_overlap < curr_iou < config.rpn_max_overlap:
                                if bbox_type != 'pos':
                                    bbox_type = 'neutral'
                    if bbox_type == "neg":
                        test_index = anchor_size_index * len(anchor_ratios) + anchor_ratio_index
                        y_is_box_valid[iy, ix, anchor_size_index * len(anchor_ratios) + anchor_ratio_index] = 1
                        y_rpn_overlap[iy, ix, anchor_size_index * len(anchor_ratios) + anchor_ratio_index] = 0
                    elif bbox_type == "neutral":
                        y_is_box_valid[iy, ix, anchor_size_index * len(anchor_ratios) + anchor_ratio_index] = 0
                        y_rpn_overlap[iy, ix, anchor_size_index * len(anchor_ratios) + anchor_ratio_index] = 0
                    elif bbox_type == "pos":
                        y_is_box_valid[iy, ix, anchor_size_index * len(anchor_ratios) + anchor_ratio_index] = 1
                        y_rpn_overlap[iy, ix, anchor_size_index * len(anchor_ratios) + anchor_ratio_index] = 1
                        start = 4 * (anchor_size_index * len(anchor_ratios) + anchor_ratio_index)
                        y_rpn_regr[iy, ix, start:start+4] = best_regr

    # 通過上面,咱們只是挑選出了 與任意ground - truth(真實檢測框)的IOU大於0.7 的anchor爲正樣本。
    # 可是若是某個ground - truth(真實檢測框) 沒有與它iou值大於0.7的anchor呢?
    # 這個時候須要用到第一個條件 與ground - truth(真實檢測框)IOU最高的anchor
    # 咱們須要確保每個真實標註框都有至少一個對應的正樣本(anchor)
    for idx in range(num_anchors_for_bbox.shape[0]):#遍歷全部真實標註框,也就是全部ground truth
        if num_anchors_for_bbox[idx] == 0:  #若是當前真實標註框沒有所對應的anchor
            if best_anchor_for_bbox[idx, 0] == -1: #若是當前真實標註框沒有與任何anchor都無交集,也就是說iou都等於0,則直接忽略掉
                continue
            y_is_box_valid[best_anchor_for_bbox[idx, 0], best_anchor_for_bbox[idx, 1],
                           best_anchor_for_bbox[idx, 3] * len(anchor_ratios) + best_anchor_for_bbox[idx,2]] = 1
            y_rpn_overlap[best_anchor_for_bbox[idx, 0], best_anchor_for_bbox[idx, 1],
                           best_anchor_for_bbox[idx, 3] * len(anchor_ratios) + best_anchor_for_bbox[idx, 2]] = 1
            start = 4 * (best_anchor_for_bbox[idx, 3] * len(anchor_ratios) + best_anchor_for_bbox[idx, 2])
            y_rpn_regr[best_anchor_for_bbox[idx, 0], best_anchor_for_bbox[idx, 1], start:start+4] \
                = best_dx_for_bbox[idx, :]

    #增長一維,(樣本個數)
    y_rpn_overlap = np.expand_dims(y_rpn_overlap, axis=0)
    y_is_box_valid = np.expand_dims(y_is_box_valid, axis=0)
    y_rpn_regr = np.expand_dims(y_rpn_regr, axis=0)


    ''' a = np.array([[0,1,0], [1,0,1]]) b = np.array([[1,1,0], [0,1,1]]) print(np.logical_and(a, b)) # [[False True False] # [False False True]] print(np.where(np.logical_and(a, b))) #(array([0, 1], dtype=int64), array([1, 2], dtype=int64)) '''
    #np.asarray(condition).nonzero()
    #pos_locs 正樣本對應的三個維度的下標
    pos_locs = np.where(np.logical_and(y_rpn_overlap[0, :, :, :] == 1, y_is_box_valid[0, :, :, :] == 1))
    #neg_locs 負樣本對應的三個維度的下標
    neg_locs = np.where(np.logical_and(y_rpn_overlap[0, :, :, :] == 0, y_is_box_valid[0, :, :, :] == 1))

    num_pos = len(pos_locs[0])  #正樣本個數

    # 隨機採樣256個樣本做爲一個mini-batch,而且最多保持正負樣本比例1:1,若是正樣本個數不夠,用負樣本填充
    mini_batch = 256
    if len(pos_locs[0]) > mini_batch / 2:  #判斷正樣本個數是否多於128,若是是,則從全部正樣本中隨機採用128個
        # 注意這塊,是從全部正例的下標中留下128個,選取出其餘剩餘的,將對應的y_is_box_valid設置爲0,
        # 也就是說選取出來的正例樣本就是丟棄的, 不進行訓練的樣本,剩餘的128個即爲實際的正例樣本
        val_locs = random.sample(range(num_pos), num_pos - mini_batch / 2)
        y_is_box_valid[0, pos_locs[0][val_locs], pos_locs[1][val_locs], pos_locs[2][val_locs]] = 0

        num_pos = mini_batch / 2

    # 正樣本選取完畢後,開始選取負例樣本,一樣的思路,隨機選取出不須要的負樣本,將對應的y_is_box_valid設置爲0,
    # 剩餘的正好和正樣本組合成 256個樣本
    if len(neg_locs[0]) + num_pos > mini_batch:
        #(mini_batch-num_pos) : 須要的負例樣本數
        #len(neg_locs[0]) - (mini_batch-num_pos):不須要的負例樣本數
        val_locs = random.sample(range(len(neg_locs[0])), len(neg_locs[0]) - (mini_batch-num_pos))
        y_is_box_valid[0, neg_locs[0][val_locs], neg_locs[1][val_locs], neg_locs[2][val_locs]] = 0

    #將y_is_box_valid 與 y_rpn_overlap鏈接到一塊
    y_rpn_cls = np.concatenate([y_is_box_valid, y_rpn_overlap], axis=3)

    #對於迴歸損失來講,只是針對正樣本進行計算的,負樣本和不參與訓練的其餘樣本都須要過濾掉,不參與訓練
    #因此這塊須要將y_rpn_overlap 和 y_rpn_regr拼接起來做爲RPN網絡迴歸層的真實Y值,方便後續計算損失函數
    y_rpn_regr = np.concatenate([np.repeat(y_rpn_overlap, 4, axis=3), y_rpn_regr], axis=3)


    return np.copy(y_rpn_cls), np.copy(y_rpn_regr)複製代碼

咱們構建一個迭代器來方便在訓練時,實時生成對應訓練數據

def get_anchor_data_gt(img_datas, class_count, C, mode="train"):
    ''' 生成用於RPN網絡訓練數據集的迭代器 :param img_data: 原始數據,list,每一個元素都是一個字典類型,存放着每張圖片的相關信息 all_img_data[0] = {'width': 500, 'height': 500, 'bboxes': [{'y2': 500, 'y1': 27, 'x2': 183, 'x1': 20, 'class': 'person', 'difficult': False}, {'y2': 500, 'y1': 2, 'x2': 249, 'x1': 112, 'class': 'person', 'difficult': False}, {'y2': 490, 'y1': 233, 'x2': 376, 'x1': 246, 'class': 'person', 'difficult': False}, {'y2': 468, 'y1': 319, 'x2': 356, 'x1': 231, 'class': 'chair', 'difficult': False}, {'y2': 450, 'y1': 314, 'x2': 58, 'x1': 1, 'class': 'chair', 'difficult': True}], 'imageset': 'test', 'filepath': './datasets/VOC2007/JPEGImages/000910.jpg'} :param class_count: 數據集中各個類別的樣本個數,字典型 :param C: 相關配置參數 :param mode: :return: 返回一個數據迭代器 '''
    while True:
        if mode == "train":
            #打亂數據集
            random.shuffle(img_datas)

        for img_data in img_datas:
            try:
                #數據加強
                if mode == "train":
                    img_data_aug, x_img = data_augment.augment(img_data, C, augment=True)
                else:
                    img_data_aug, x_img = data_augment.augment(img_data, C, augment=False)

                #確保圖像尺寸不發生改變
                (width, height) = (img_data_aug['width'], img_data_aug['height'])
                (rows, cols, _) = x_img.shape
                assert cols == width
                assert rows == height

                #將圖像的短邊縮放到600尺寸
                (resized_width, resized_height) = get_new_img_size(width, height, C.im_size)
                x_img = cv2.resize(x_img, (resized_width, resized_height), interpolation=cv2.INTER_CUBIC)

                x_img = cv2.cvtColor(x_img, cv2.COLOR_BGR2RGB)
                x_img = x_img.astype(np.float32)
                x_img[:, :, 0] -= C.img_channel_mean[0]
                x_img[:, :, 1] -= C.img_channel_mean[1]
                x_img[:, :, 2] -= C.img_channel_mean[2]
                x_img /= C.img_scaling_factor
                x_img = np.expand_dims(x_img, axis=0)

                y_rpn_cls, y_rpn_regr = getdata_for_rpn(C, img_data_aug, width, height, resized_width, resized_height)

                y_rpn_regr[:,:, :, y_rpn_regr.shape[1] // 2:] *= C.std_scaling

                yield  np.copy(x_img), [np.copy(y_rpn_cls), np.copy(y_rpn_regr)], img_data_aug
            except Exception as e:
                print(e)
                continue複製代碼

測試下看看:

image_data, classes_count, classes_mapping = voc_data_parser.get_data("data")
train_imgs = [s for s in image_data if s['imageset'] == 'train']   #訓練集
val_imgs = [s for s in image_data if s['imageset'] == 'val']  #驗證集
test_imgs = [s for s in image_data if s['imageset'] == 'test'] #測試集
data_gen_train = data_generators.get_anchor_data_gt(train_imgs[:3], classes_count, config)
for i in range(3):
    X, Y, img_data = next(data_gen_train)
    print("通過預處理後的圖像X:",X.shape)
    print("RPN網絡分類層對應Y值:",Y[0].shape)
    print("RPN網絡迴歸層層對應Y值:",Y[1].shape)複製代碼

結果以下:

這塊須要注意,RPN網絡的分類層,由於是二分類,後面咱們直接用sigmoid分類,對於sigmoid分類來講,由於每一個feature map像素點中對應9個anchor, 最終的輸出通道數應該爲1*k也就是9,爲何這塊是18呢?

由於咱們並非全部anchor都參與訓練的,首先,對於一部分anchor,根據Fatster R-CNN中根據IOU的判斷標準,一部分anchor既不屬於正樣本,也不屬於負樣本,直接捨棄,另外,對應選出來的正負樣本中,最終也只是選取出256個anchor做爲實際參與訓練的樣本。因此,對於每個anchor,咱們都須要一個標記來判斷當前anchor是否參與訓練。

對於迴歸層的Y值最後的通道數,正常狀況下,應該爲36,這塊爲什麼是72呢,由於,以前咱們在說faster RCNN原理時,關於RPN網絡的損失函數,對於迴歸損失來講,只是針對正樣本進行計算的,負樣本和不參與訓練的其餘樣本都須要過濾掉,不參與訓練,全部這塊須要在本來36的基礎上對應加上每個是不是正樣本的標記,全部最終爲72.

數據準備好以後,下次咱們來構造Faster RCNN網絡模型。

未完待續....


相關本章完整代碼以及VOC2102數據集百度網盤下載,請關注我本身的公衆號 AI計算機視覺工坊,回覆【代碼】和【數據集】獲取。本公衆號不按期推送機器學習,深度學習,計算機視覺等相關文章,歡迎你們和我一塊兒學習,交流。


                                

相關文章
相關標籤/搜索