SSD源碼解讀——數據讀取

以前,對SSD的論文進行了解讀,能夠回顧以前的博客:http://www.javashuo.com/article/p-rrxantwu-cm.htmlhtml

爲了加深對SSD的理解,所以對SSD的源碼進行了復現,主要參考的github項目是ssd.pytorch。同時,我本身對該項目增長了大量註釋:https://github.com/Dengshunge/mySSD_pytorchgit

搭建SSD的項目,能夠分紅如下三個部分:github

  1. 數據讀取;
  2. 網絡搭建
  3. 損失函數的構建
  4. 網絡測試

接下來,本篇博客重點分析數據讀取網絡


1、總體框架

SSD的數據讀取環節,一樣適用於大部分目標檢測的環節,具備通用性。爲了方便理解,本項目以VOC2007+2012爲例。所以,數據讀取環節,一般是按照如下步驟展開進行:app

  1. 函數入口;
  2. 圖片的讀取和xml文件的讀取;
  3. 對GT框進行處理;
  4. 數據加強;
  5. 輔助函數。

2、具體實現細節

2.1 函數入口

數據讀取的函數入口在train.py文件中:框架

if args.dataset == 'VOC':
    train_dataset = VOCDetection(root=args.dataset_root)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, num_workers=4,
        collate_fn=detection_collate, shuffle=True, pin_memory=True)

能夠看到,首先經過函數 VOCDetection() 來對VOC數據集進行初始化,再使用函數 DataLoader() 來實現對數據集的讀取。這一步與常見的分類網絡形式相同,但不一樣的是,多了collate_fn這一參數,後續會對此進行說明。dom

2.2 圖片與xml文件讀取

首先,咱們先看看函數VOCDetection() 的初始化函數__init__()。在__init__中包含了須要傳入的幾個參數,image_sets(表示VOC使用到的數據集),transform(數據加強的方式),target_transform(GT框的處理方式)。函數

class VOCDetection():
    """VOC Detection Dataset Object

    input is image, target is annotation

    Arguments:
        root (string): filepath to VOCdevkit folder.
        image_set (string): imageset to use (eg. 'train', 'val', 'test')
        transform (callable, optional): transformation to perform on the input image
            圖片預處理的方式,這裏使用了大量數據加強的方式
        target_transform (callable, optional): transformation to perform on the
            target `annotation`
            (eg: take in caption string, return tensor of word indices)
            真實框預處理的方式
    """

    def __init__(self, root,
                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
                 transform=SSDAugmentation(size=config.voc['min_dim'], mean=config.MEANS),
                 target_transform=VOCAnnotationTransform()):
        self.root = root
        self.image_set = image_sets
        self.transform = transform
        self.target_transform = target_transform
        self._annopath = os.path.join('%s', 'Annotations', '%s.xml')
        self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = []
        # 使用VOC2007和VOC2012的train做爲訓練集
        for (year, name) in self.image_set:
            rootpath = os.path.join(self.root, 'VOC' + year)
            for line in open(os.path.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append([rootpath, line[:-1]])

首先,爲何image_sets是這樣的形式呢?由於VOC具備固定的文件夾路徑,利用這個參數和配合路徑讀取,能夠讀取到txt文件,該txt文件用於制定哪些圖片用於訓練。此外,還須要設置參數self.ids,這個list用於存儲文件的路徑,由兩列組成,"VOC/2007"和圖片名稱。經過這兩個參數,後續能夠配合函數_annopath()和_imgpath()能夠讀取到對應圖片的路徑和xml文件。測試

在pytorch中,還須要相應的函數來對讀取圖片與返回結果,以下所示。其中,重點是pull_iterm函數。spa

    def __getitem__(self, index):
        im, gt = self.pull_item(index)
        return im, gt

    def __len__(self):
        return len(self.ids)

    def pull_item(self, index):
        img_id = tuple(self.ids[index])
        # img_id裏面有2個值
        target = ET.parse(self._annopath % img_id).getroot()  # 得到xml的內容,但這個是具備特殊格式的
        img = cv2.imread(self._imgpath % img_id)
        height, width, _ = img.shape

        if self.target_transform is not None:
            # 真實框處理
            target = self.target_transform(target, width, height)

        if self.transform is not None:
            # 圖像預處理,進行數據加強,只在訓練進行數據加強,測試的時候不須要
            target = np.array(target)
            img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
            # 轉換格式
            img = img[:, :, (2, 1, 0)]  # to rbg
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
        return torch.from_numpy(img).permute(2, 0, 1), target

該函數pull_item(),首先讀取圖片和相應的xml文件;接着對使用類VOCAnnotationTransform來對GT框進行處理,即讀取GT框座標與將座標歸一化;而後經過函數SSDAugmentation()對圖片進行數據加強;最後對對圖片進行常規處理(交換通道等),返回圖片與存有GT框的list。

2.3 對GT框進行處理

接着,須要講一講這個類VOCAnnotationTransform的做用,其定義以下。self.class_to_ind是一個map,其key是類別名稱,value是編號,這個對象的做用是,讀取xml時,能將對應的類別名稱轉換成label;在__call__()函數中,主要是xml讀取的一些方式,值得一提的是,GT框的最錶轉換成了[0,1]之間,當圖片尺寸變化了,GT框的座標也能進行相應的變換。最後,res的每行由5個元素組成,分別是[x_min,y_min,x_max,y_max,label]。

class VOCAnnotationTransform():
    '''
    獲取xml裏面的座標值和label,並將座標值轉換成0到1
    '''

    def __init__(self, class_to_ind=None, keep_difficult=False):
        # 將類別名字轉換成數字label
        self.class_to_ind = class_to_ind or dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        # 在xml裏面,有個difficult的參數,這個表示特別難識別的目標,通常是小目標或者遮擋嚴重的目標
        # 所以,能夠經過這個參數,忽略這些目標
        self.keep_difficult = keep_difficult

    def __call__(self, target, width, height):
        '''
        將一張圖裏麪包含若干個目標,獲取這些目標的座標值,並轉換成0到1,並獲得其label
        :param target: xml格式
        :return: 返回List,每一個目標對應一行,每行包括5個參數[xmin, ymin, xmax, ymax, label_ind]
        '''
        res = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1  # 判斷該目標是否爲難例
            # 判斷是否跳過難例
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()  # text是得到目標的名稱,lower將字符轉換成小寫,strip去除先後空格
            bbox = obj.find('bndbox')  # 得到真實框座標

            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1  # 得到座標值
                # 將座標轉換成[0,1],這樣圖片尺寸發生變化的時候,真實框也隨之變化,即平移不變形
                cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]  # 得到名字對應的label
            bndbox.append(label_idx)
            res.append(bndbox)  # [xmin, ymin, xmax, ymax, label_ind]
        return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]

2.4 數據加強

還有一個重要的函數,即函數SSDAugmentation(),該函數的做用是做數據加強。論文中也說起了,數據加強對最終的結果提高有着重大做用。博客1博客2具體講述了數據加強的源碼,講得十分詳細。在本項目中,SSDAugmentation()函數在data/augmentations.py中,以下所示。因爲opencv讀取讀片的時候,取值範圍是[0,255],是int類型,須要將其轉換爲float類型,計算其GT框的正式座標。而後對圖片進行光度變形,包含改變對比度,改變飽和度,改變色調、改變亮度和增長噪聲等。接着有對圖片進行擴張和裁剪等。在此操做中,會涉及到GT框座標的變換。最後,當上述變化處理完後,再對GT框座標歸一化,和resize圖片,減去均值等。具體細節,能夠參考兩篇博客進行解讀。

class SSDAugmentation(object):
    def __init__(self, size=300, mean=(104, 117, 123)):
        self.mean = mean
        self.size = size
        self.augment = Compose([
            ConvertFromInts(),  # 將圖片從int轉換成float
            ToAbsoluteCoords(),  # 計算真實的錨點框座標
            PhotometricDistort(),  # 光度變形
            Expand(self.mean),  # 隨機擴張圖片
            RandomSampleCrop(),  # 隨機裁剪
            RandomMirror(),  # 隨機鏡像
            ToPercentCoords(),
            Resize(self.size),
            SubtractMeans(self.mean)
        ])

    def __call__(self, img, boxes, labels):
        return self.augment(img, boxes, labels)

2.5 輔助函數

在一個batch中,每張圖片的GT框數量是不等的,所以,須要定義一個函數來處理這種狀況。函數detection_collate()就是用於處理這種狀況,使得一張圖片能對應一個list,這裏list裏面有全部GT框的信息組成。

def detection_collate(batch):
    """Custom collate fn for dealing with batches of images that have a different
    number of associated object annotations (bounding boxes).
    自定義處理在同一個batch,含有不一樣數量的目標框的狀況

    Arguments:
        batch: (tuple) A tuple of tensor images and lists of annotations

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations for a given image are stacked on
                                 0 dim
    """
    targets = []
    imgs = []
    for sample in batch:
        imgs.append(sample[0])
        targets.append(torch.FloatTensor(sample[1]))
    return torch.stack(imgs, 0), targets

 


 

至此,已經將SSD的數據讀取部分分析完。 

相關文章
相關標籤/搜索