Detectron2 Source Code Reading Notes (3): Dataset Pipeline

Steps for building the data_loader

# engine/default.py
from detectron2.data import (
    MetadataCatalog,
    build_detection_test_loader,
    build_detection_train_loader,
)
class DefaultTrainer(SimpleTrainer):
    def __init__(self, cfg):
        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        ...    
    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable
        """
        return build_detection_train_loader(cfg)

The function call relationship is shown in the following figure:

Combining this with the previous two articles, we can see that detectron2 builds the model, the optimizer, and the data_loader in their corresponding build.py files. Let's look at how build_detection_train_loader is defined (the part inside the purple box in the figure above, read from bottom to top):

def build_detection_train_loader(cfg, mapper=None):
    """
    A data loader is created by the following steps:

    1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
    2. Start workers to work on the dicts. Each worker will:
      * Map each metadata dict into another format to be consumed by the model.
      * Batch them by simply putting dicts into a list.
    The batched ``list[mapped_dict]`` is what this dataloader will return.

    Args:
        cfg (CfgNode): the config
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            By default it will be `DatasetMapper(cfg, True)`.

    Returns:
        a torch DataLoader object
    """
    # Step 1: obtain dataset_dicts
    dataset_dicts = get_detection_dataset_dicts(
        cfg.DATASETS.TRAIN,
        filter_empty=True,
        min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
        if cfg.MODEL.KEYPOINT_ON
        else 0,
        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
    )
    
    # Step 2: wrap dataset_dicts into a torch.utils.data.Dataset
    dataset = DatasetFromList(dataset_dicts, copy=False)

    # Step 3: further wrap into a MapDataset; the mapper is called to parse each dict whenever data is read
    if mapper is None:
        mapper = DatasetMapper(cfg, True)
    dataset = MapDataset(dataset, mapper)
    
    # Step 4: sampler
    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    if sampler_name == "TrainingSampler":
        sampler = samplers.TrainingSampler(len(dataset))
        ...
    batch_sampler = build_batch_data_sampler(
        sampler, images_per_worker, group_bin_edges, aspect_ratios
    )
    
    # Step 5: the data_loader iterator
    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
        worker_init_fn=worker_init_reset_seed,
    )
    return data_loader

From the source code above we can see that there are five steps in total. We will only cover the first three in detail; for the sampler and the final data_loader, refer to the article 一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關係 (a primer on the relationship between PyTorch's DataLoader, Dataset, and Sampler).
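
For completeness, here is a minimal sketch of how this loader is typically consumed (assuming a dataset that is already registered, e.g. the builtin coco_2017_train, and that the corresponding data is available locally):

from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader

cfg = get_cfg()
cfg.DATASETS.TRAIN = ("coco_2017_train",)  # any registered dataset name works here

data_loader = build_detection_train_loader(cfg)
for batch in data_loader:
    # thanks to trivial_batch_collator, a batch is simply a list[dict],
    # one mapped dict (image tensor, Instances, ...) per image
    print(len(batch), sorted(batch[0].keys()))
    break  # the training sampler yields an infinite stream, so stop manually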

Obtaining dataset_dicts

The key argument passed to get_detection_dataset_dicts(dataset_names) is dataset_names, a string (or list of strings) specifying the dataset name(s). Using these names, the function queries the DatasetCatalog class in data/catalog.py and obtains a list of dicts describing the dataset.

The way this lookup works: DatasetCatalog maintains a dict _REGISTERED, in which built-in datasets such as COCO and VOC are already registered. If you want to use your own dataset, you first need to choose a dataset name and define a function (this function takes no arguments and returns a list of dicts containing your dataset information). For example:

from detectron2.data import DatasetCatalog
my_dataset_name = 'apple'
def get_dicts():
    ...
    return list_of_dicts  # illustrative name: a list[dict] in Detectron2's Dataset format

DatasetCatalog.register(my_dataset_name, get_dicts)

Of course, if your dataset is already in COCO format, you can register it even more simply:

from detectron2.data.datasets import register_coco_instances
my_dataset_name = 'apple'
register_coco_instances(my_dataset_name, {}, "json_annotation.json", "path/to/image/dir")
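
Either way, once a dataset is registered, its list of dicts can be pulled back out by name with DatasetCatalog.get, which is essentially what get_detection_dataset_dicts does internally for each name:

from detectron2.data import DatasetCatalog

dataset_dicts = DatasetCatalog.get(my_dataset_name)  # calls the registered function and returns its list[dict]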

Note also that a dataset is actually described by two classes: one is the DatasetCatalog introduced above, and the other is MetadataCatalog.

MetadataCatalog records the metadata of a dataset, so that this information can be accessed conveniently throughout the codebase. After registering with DatasetCatalog, we can register the metadata and define whatever attributes we may need later, for example:

from detectron2.data import MetadataCatalog
MetadataCatalog.get("my_dataset").thing_classes = ["person", "dog"]

# equivalently:
MetadataCatalog.get("my_dataset").set("thing_classes",["person", "dog"])

Note: if the dataset name has not been registered yet, MetadataCatalog.get will register it automatically and then set the attribute you assign.
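
Reading the metadata back later works the same way, for example:

from detectron2.data import MetadataCatalog

metadata = MetadataCatalog.get("my_dataset")  # returns (and if necessary auto-creates) the Metadata object
print(metadata.thing_classes)                 # ["person", "dog"]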

MetadataCatalog has other attributes that can be set as well, such as stuff_classes, stuff_colors, and so on. You may wonder what the difference between thing_classes and stuff_classes is:

  • Abstract explanation: thing_classes is used for instance-level tasks, while stuff_classes is used for semantic segmentation tasks.
  • Concrete explanation: countable objects such as chairs and books are "things", hence instance-level; uncountable ones such as snow or sky are "stuff", hence semantic segmentation. See On Seeing Stuff: The Perception of Materials by Humans and Machines.

Finally, get_detection_dataset_dicts returns a list of dicts, one dict per image. Since dataset_names can itself be a list, several names can be specified and the records from multiple datasets are merged into one list.
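
To make the "list of dicts" concrete, each entry follows Detectron2's standard Dataset format. A minimal sketch of a single record for a detection task (all values below are made up) could look like this:

from detectron2.structures import BoxMode

one_record = {
    "file_name": "path/to/image/dir/000001.jpg",  # hypothetical image path
    "image_id": 1,
    "height": 480,
    "width": 640,
    "annotations": [
        {
            "bbox": [100.0, 120.0, 50.0, 80.0],
            "bbox_mode": BoxMode.XYWH_ABS,
            "category_id": 0,
        },
    ],
}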

Converting to DatasetFromList

DatasetFromList(dataset_dicts) is defined in detectron2/data/common.py. It is simply a subclass of torch.utils.data.Dataset; its source is as follows:

import copy
import torch.utils.data as data


class DatasetFromList(data.Dataset):
    """
    Wrap a list to a torch Dataset. It produces elements of the list as data.
    """

    def __init__(self, lst: list, copy: bool = True):
        """
        Args:
            lst (list): a list which contains elements to produce.
            copy (bool): whether to deepcopy the element when producing it,
                so that the result can be modified in place without affecting the
                source in the list.
        """
        self._lst = lst
        self._copy = copy

    def __len__(self):
        return len(self._lst)

    def __getitem__(self, idx):
        if self._copy:
            return copy.deepcopy(self._lst[idx])
        else:
            return self._lst[idx]

This class is straightforward, so there is not much to add.
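
One small thing worth illustrating is the copy flag. A toy sketch (the tiny list of dicts below is made up):

from detectron2.data.common import DatasetFromList

records = [{"file_name": "a.jpg"}, {"file_name": "b.jpg"}]
ds = DatasetFromList(records, copy=True)

item = ds[0]
item["file_name"] = "changed.jpg"
print(records[0]["file_name"])  # still "a.jpg": a deepcopy was returned, the source list is untouched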

Converting DatasetFromList into MapDataset

Both DatasetFromList and MapDataset are subclasses of torch.utils.data.Dataset, so what is the difference between them? Simple: the latter applies a mapper.

Before explaining what the mapper is, we first need to know that in detectron2 one image corresponds to one dict, so the whole dataset is a list[dict]. Looking back at DatasetFromList, its __getitem__ is trivial: it just returns the element at the given idx. That alone is not enough, because the data has to be processed before it is fed to the model for training, and that processing is exactly the mapper's job. By default the DatasetMapper defined in detectron2/data/dataset_mapper.py is used; if you need a custom mapper, it is a good template to follow (see the sketch below).
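
If you do want a custom mapper, it only needs to be a callable that takes one dataset dict and returns whatever the model consumes. A bare-bones sketch (it only loads the image and ignores annotations, so it is far from a full replacement for DatasetMapper; my_mapper is a made-up name):

import copy
import torch
from detectron2.data import detection_utils as utils

def my_mapper(dataset_dict):
    dataset_dict = copy.deepcopy(dataset_dict)  # never modify the cached dict in place
    image = utils.read_image(dataset_dict["file_name"], format="BGR")
    dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
    return dataset_dict

# then: data_loader = build_detection_train_loader(cfg, mapper=my_mapper)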

DatasetMapper(cfg, is_train=True)

Let's now look at how DatasetMapper is implemented. First, the official description:

A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by the model.

In short, this class is callable, which is why a __call__ method is defined in the source code below.

The class mainly does the following three things:

The callable currently does the following:

  1. Read the image from "file_name"
  2. Applies cropping/geometric transforms to the image and annotations
  3. Prepare data and annotations to Tensor and :class:Instances

Its source (abridged) is as follows:

class DatasetMapper:
    def __init__(self, cfg, is_train=True):
        # read the relevant parameters from cfg
        ...

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        
        # 1. read the image data
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        
        # 2. apply transformations to the image, boxes, etc.
        if "annotations" not in dataset_dict:
            image, transforms = T.apply_transform_gens(
                ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
            )
        else:
            ...
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            if self.crop_gen:
                transforms = crop_tfm + transforms

        image_shape = image.shape[:2]  # h, w
        
        # 3. convert the data to tensors
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
        ...

        return dataset_dict
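
In other words, the mapper is applied to one dict at a time. Roughly (assuming the cfg and the dataset_dicts obtained earlier):

from detectron2.data.dataset_mapper import DatasetMapper

mapper = DatasetMapper(cfg, is_train=True)
example = mapper(dataset_dicts[0])
print(example["image"].shape)  # a CHW float32 tensor, ready for the model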

MapDataset

import logging
import random

import torch.utils.data as data
from detectron2.utils.serialize import PicklableWrapper


class MapDataset(data.Dataset):
    def __init__(self, dataset, map_func):
        self._dataset = dataset
        self._map_func = PicklableWrapper(map_func)  # wrap so that a lambda will work

        self._rng = random.Random(42)
        self._fallback_candidates = set(range(len(dataset)))

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        retry_count = 0
        cur_idx = int(idx)

        while True:
            data = self._map_func(self._dataset[cur_idx])
            if data is not None:
                self._fallback_candidates.add(cur_idx)
                return data

            # _map_func fails for this idx, use a random new index from the pool
            retry_count += 1
            self._fallback_candidates.discard(cur_idx)
            cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0]

            if retry_count >= 3:
                logger = logging.getLogger(__name__)
                logger.warning(
                    "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
                        idx, retry_count
                    )
                )

  • self._fallback_candidates is a set, so its elements are unique. It keeps track of the indices of samples that can be read successfully: when a sample cannot be read, its index is removed from _fallback_candidates and a new index is sampled at random to read data from instead.
  • The logic in __getitem__ is: first apply the mapper to the data at the given index; if that succeeds, add the index to _fallback_candidates and return the result. Otherwise, remove the index, randomly sample a replacement, and try again. After 3 failed attempts a warning is logged, but the loop does not exit; it keeps drawing new indices, presumably on the assumption that some readable sample will eventually be found (see the toy sketch after this list).
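
To see this fallback behavior in isolation, here is a toy sketch (flaky_mapper and the tiny list are made up; failure is signalled by the mapper returning None, exactly as in the code shown above):

from detectron2.data.common import DatasetFromList, MapDataset

def flaky_mapper(d):
    # pretend the sample with id 1 is corrupted and cannot be loaded
    return None if d["id"] == 1 else d

ds = MapDataset(DatasetFromList([{"id": 0}, {"id": 1}, {"id": 2}]), flaky_mapper)
print(ds[1])  # falls back to a randomly chosen readable sample instead of raising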


Original post by MARSGGBO.

If you are interested in collaborating, feel free to reach out.

Email: marsggbo@foxmail.com

2019-10-23 13:37:13
