以前,對SSD的論文進行了解讀,能夠回顧以前的博客:http://www.javashuo.com/article/p-rrxantwu-cm.html。html
爲了加深對SSD的理解,所以對SSD的源碼進行了復現,主要參考的github項目是ssd.pytorch。同時,我本身對該項目增長了大量註釋:https://github.com/Dengshunge/mySSD_pytorchgit
搭建SSD的項目,能夠分紅如下三個部分:github
接下來,本篇博客重點分析數據讀取。網絡
SSD的數據讀取環節,一樣適用於大部分目標檢測的環節,具備通用性。爲了方便理解,本項目以VOC2007+2012爲例。所以,數據讀取環節,一般是按照如下步驟展開進行:app
數據讀取的函數入口在train.py文件中:框架
if args.dataset == 'VOC': train_dataset = VOCDetection(root=args.dataset_root) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, num_workers=4, collate_fn=detection_collate, shuffle=True, pin_memory=True)
能夠看到,首先經過函數 VOCDetection() 來對VOC數據集進行初始化,再使用函數 DataLoader() 來實現對數據集的讀取。這一步與常見的分類網絡形式相同,但不一樣的是,多了collate_fn這一參數,後續會對此進行說明。dom
首先,咱們先看看函數VOCDetection() 的初始化函數__init__()。在__init__中包含了須要傳入的幾個參數,image_sets(表示VOC使用到的數據集),transform(數據加強的方式),target_transform(GT框的處理方式)。函數
class VOCDetection(): """VOC Detection Dataset Object input is image, target is annotation Arguments: root (string): filepath to VOCdevkit folder. image_set (string): imageset to use (eg. 'train', 'val', 'test') transform (callable, optional): transformation to perform on the input image 圖片預處理的方式,這裏使用了大量數據加強的方式 target_transform (callable, optional): transformation to perform on the target `annotation` (eg: take in caption string, return tensor of word indices) 真實框預處理的方式 """ def __init__(self, root, image_sets=[('2007', 'trainval'), ('2012', 'trainval')], transform=SSDAugmentation(size=config.voc['min_dim'], mean=config.MEANS), target_transform=VOCAnnotationTransform()): self.root = root self.image_set = image_sets self.transform = transform self.target_transform = target_transform self._annopath = os.path.join('%s', 'Annotations', '%s.xml') self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg') self.ids = [] # 使用VOC2007和VOC2012的train做爲訓練集 for (year, name) in self.image_set: rootpath = os.path.join(self.root, 'VOC' + year) for line in open(os.path.join(rootpath, 'ImageSets', 'Main', name + '.txt')): self.ids.append([rootpath, line[:-1]])
首先,爲何image_sets是這樣的形式呢?由於VOC具備固定的文件夾路徑,利用這個參數和配合路徑讀取,能夠讀取到txt文件,該txt文件用於制定哪些圖片用於訓練。此外,還須要設置參數self.ids,這個list用於存儲文件的路徑,由兩列組成,"VOC/2007"和圖片名稱。經過這兩個參數,後續能夠配合函數_annopath()和_imgpath()能夠讀取到對應圖片的路徑和xml文件。測試
在pytorch中,還須要相應的函數來對讀取圖片與返回結果,以下所示。其中,重點是pull_iterm函數。spa
def __getitem__(self, index): im, gt = self.pull_item(index) return im, gt def __len__(self): return len(self.ids) def pull_item(self, index): img_id = tuple(self.ids[index]) # img_id裏面有2個值 target = ET.parse(self._annopath % img_id).getroot() # 得到xml的內容,但這個是具備特殊格式的 img = cv2.imread(self._imgpath % img_id) height, width, _ = img.shape if self.target_transform is not None: # 真實框處理 target = self.target_transform(target, width, height) if self.transform is not None: # 圖像預處理,進行數據加強,只在訓練進行數據加強,測試的時候不須要 target = np.array(target) img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) # 轉換格式 img = img[:, :, (2, 1, 0)] # to rbg target = np.hstack((boxes, np.expand_dims(labels, axis=1))) return torch.from_numpy(img).permute(2, 0, 1), target
該函數pull_item(),首先讀取圖片和相應的xml文件;接着對使用類VOCAnnotationTransform來對GT框進行處理,即讀取GT框座標與將座標歸一化;而後經過函數SSDAugmentation()對圖片進行數據加強;最後對對圖片進行常規處理(交換通道等),返回圖片與存有GT框的list。
接着,須要講一講這個類VOCAnnotationTransform的做用,其定義以下。self.class_to_ind是一個map,其key是類別名稱,value是編號,這個對象的做用是,讀取xml時,能將對應的類別名稱轉換成label;在__call__()函數中,主要是xml讀取的一些方式,值得一提的是,GT框的最錶轉換成了[0,1]之間,當圖片尺寸變化了,GT框的座標也能進行相應的變換。最後,res的每行由5個元素組成,分別是[x_min,y_min,x_max,y_max,label]。
class VOCAnnotationTransform(): ''' 獲取xml裏面的座標值和label,並將座標值轉換成0到1 ''' def __init__(self, class_to_ind=None, keep_difficult=False): # 將類別名字轉換成數字label self.class_to_ind = class_to_ind or dict(zip(VOC_CLASSES, range(len(VOC_CLASSES)))) # 在xml裏面,有個difficult的參數,這個表示特別難識別的目標,通常是小目標或者遮擋嚴重的目標 # 所以,能夠經過這個參數,忽略這些目標 self.keep_difficult = keep_difficult def __call__(self, target, width, height): ''' 將一張圖裏麪包含若干個目標,獲取這些目標的座標值,並轉換成0到1,並獲得其label :param target: xml格式 :return: 返回List,每一個目標對應一行,每行包括5個參數[xmin, ymin, xmax, ymax, label_ind] ''' res = [] for obj in target.iter('object'): difficult = int(obj.find('difficult').text) == 1 # 判斷該目標是否爲難例 # 判斷是否跳過難例 if not self.keep_difficult and difficult: continue name = obj.find('name').text.lower().strip() # text是得到目標的名稱,lower將字符轉換成小寫,strip去除先後空格 bbox = obj.find('bndbox') # 得到真實框座標 pts = ['xmin', 'ymin', 'xmax', 'ymax'] bndbox = [] for i, pt in enumerate(pts): cur_pt = int(bbox.find(pt).text) - 1 # 得到座標值 # 將座標轉換成[0,1],這樣圖片尺寸發生變化的時候,真實框也隨之變化,即平移不變形 cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height bndbox.append(cur_pt) label_idx = self.class_to_ind[name] # 得到名字對應的label bndbox.append(label_idx) res.append(bndbox) # [xmin, ymin, xmax, ymax, label_ind] return res # [[xmin, ymin, xmax, ymax, label_ind], ... ]
還有一個重要的函數,即函數SSDAugmentation(),該函數的做用是做數據加強。論文中也說起了,數據加強對最終的結果提高有着重大做用。博客1和博客2具體講述了數據加強的源碼,講得十分詳細。在本項目中,SSDAugmentation()函數在data/augmentations.py中,以下所示。因爲opencv讀取讀片的時候,取值範圍是[0,255],是int類型,須要將其轉換爲float類型,計算其GT框的正式座標。而後對圖片進行光度變形,包含改變對比度,改變飽和度,改變色調、改變亮度和增長噪聲等。接着有對圖片進行擴張和裁剪等。在此操做中,會涉及到GT框座標的變換。最後,當上述變化處理完後,再對GT框座標歸一化,和resize圖片,減去均值等。具體細節,能夠參考兩篇博客進行解讀。
class SSDAugmentation(object): def __init__(self, size=300, mean=(104, 117, 123)): self.mean = mean self.size = size self.augment = Compose([ ConvertFromInts(), # 將圖片從int轉換成float ToAbsoluteCoords(), # 計算真實的錨點框座標 PhotometricDistort(), # 光度變形 Expand(self.mean), # 隨機擴張圖片 RandomSampleCrop(), # 隨機裁剪 RandomMirror(), # 隨機鏡像 ToPercentCoords(), Resize(self.size), SubtractMeans(self.mean) ]) def __call__(self, img, boxes, labels): return self.augment(img, boxes, labels)
在一個batch中,每張圖片的GT框數量是不等的,所以,須要定義一個函數來處理這種狀況。函數detection_collate()就是用於處理這種狀況,使得一張圖片能對應一個list,這裏list裏面有全部GT框的信息組成。
def detection_collate(batch): """Custom collate fn for dealing with batches of images that have a different number of associated object annotations (bounding boxes). 自定義處理在同一個batch,含有不一樣數量的目標框的狀況 Arguments: batch: (tuple) A tuple of tensor images and lists of annotations Return: A tuple containing: 1) (tensor) batch of images stacked on their 0 dim 2) (list of tensors) annotations for a given image are stacked on 0 dim """ targets = [] imgs = [] for sample in batch: imgs.append(sample[0]) targets.append(torch.FloatTensor(sample[1])) return torch.stack(imgs, 0), targets
至此,已經將SSD的數據讀取部分分析完。