一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關係

如下內容都是針對Pytorch 1.0-1.1介紹。
不少文章都是從Dataset等對象自下往上進行介紹,可是對於初學者而言,其實這並很差理解,由於有的時候會不自覺地陷入到一些細枝末節中去,而不能把握重點,因此本文將會自上而下地對Pytorch數據讀取方法進行介紹。python

自上而下理解三者關係

首先咱們看一下DataLoader.__next__的源代碼長什麼樣,爲方便理解我只選取了num_works爲0的狀況(num_works簡單理解就是可以並行化地讀取數據)。git

class DataLoader(object):
    ...
    
    def __next__(self):
        if self.num_workers == 0:  
            indices = next(self.sample_iter)  # Sampler
            batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

在閱讀上面代碼前,咱們能夠假設咱們的數據是一組圖像,每一張圖像對應一個index,那麼若是咱們要讀取數據就只須要對應的index便可,即上面代碼中的indices,而選取index的方式有多種,有按順序的,也有亂序的,因此這個工做須要Sampler完成,如今你不須要具體的細節,後面會介紹,你只須要知道DataLoader和Sampler在這裏產生關係。github

那麼Dataset和DataLoader在何時產生關係呢?沒錯就是下面一行。咱們已經拿到了indices,那麼下一步咱們只須要根據index對數據進行讀取便可了。dom

再下面的if語句的做用簡單理解就是,若是pin_memory=True,那麼Pytorch會採起一系列操做把數據拷貝到GPU,總之就是爲了加速。ide

綜上能夠知道DataLoader,Sampler和Dataset三者關係以下:
函數

在閱讀後文的過程當中,你始終須要將上面的關係記在內心,這樣能幫助你更好地理解。ui

Sampler

參數傳遞

要更加細緻地理解Sampler原理,咱們須要先閱讀一下DataLoader 的源代碼,以下:this

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None)

能夠看到初始化參數裏有兩種sampler:samplerbatch_sampler,都默認爲None。前者的做用是生成一系列的index,而batch_sampler則是將sampler生成的indices打包分組,獲得一個又一個batch的index。例以下面示例中,BatchSamplerSequentialSampler生成的index按照指定的batch size分組。spa

>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

Pytorch中已經實現的Sampler有以下幾種:3d

  • SequentialSampler
  • RandomSampler
  • WeightedSampler
  • SubsetRandomSampler

須要注意的是DataLoader的部分初始化參數之間存在互斥關係,這個你能夠經過閱讀源碼更深地理解,這裏只作總結:

  • 若是你自定義了batch_sampler,那麼這些參數都必須使用默認值:batch_size, shuffle,sampler,drop_last.
  • 若是你自定義了sampler,那麼shuffle須要設置爲False
  • 若是samplerbatch_sampler都爲None,那麼batch_sampler使用Pytorch已經實現好的BatchSampler,而sampler分兩種狀況:
    • shuffle=True,則sampler=RandomSampler(dataset)
    • shuffle=False,則sampler=SequentialSampler(dataset)

如何自定義Sampler和BatchSampler?

仔細查看源代碼其實能夠發現,全部採樣器其實都繼承自同一個父類,即Sampler,其代碼定義以下:

class Sampler(object):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.
    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError
        
    def __len__(self):
        return len(self.data_source)

因此你要作的就是定義好__iter__(self)函數,不過要注意的是該函數的返回值須要是可迭代的。例如SequentialSampler返回的是iter(range(len(self.data_source)))

另外BatchSampler與其餘Sampler的主要區別是它須要將Sampler做爲參數進行打包,進而每次迭代返回以batch size爲大小的index列表。也就是說在後面的讀取數據過程當中使用的都是batch sampler。

Dataset

Dataset定義方式以下:

class Dataset(object):
    def __init__(self):
        ...
        
    def __getitem__(self, index):
        return ...
    
    def __len__(self):
        return ...

上面三個方法是最基本的,其中__getitem__是最主要的方法,它規定了如何讀取數據。可是它又不一樣於通常的方法,由於它是python built-in方法,其主要做用是能讓該類能夠像list同樣經過索引值對數據進行訪問。假如你定義好了一個dataset,那麼你能夠直接經過dataset[0]來訪問第一個數據。在此以前我一直沒弄清楚__getitem__是什麼做用,因此一直不知道該怎麼進入到這個函數進行調試。如今若是你想對__getitem__方法進行調試,你能夠寫一個for循環遍歷dataset來進行調試了,而不用構建dataloader等一大堆東西了,建議學會使用ipdb這個庫,很是實用!!!之後有時間再寫一篇ipdb的使用教程。另外,其實咱們經過最前面的Dataloader的__next__函數能夠看到DataLoader對數據的讀取其實就是用了for循環來遍歷數據,不用往上翻了,我直接複製了一遍,以下:

class DataLoader(object): 
    ... 
     
    def __next__(self): 
        if self.num_workers == 0:   
            indices = next(self.sample_iter)  
            batch = self.collate_fn([self.dataset[i] for i in indices]) # this line 
            if self.pin_memory: 
                batch = _utils.pin_memory.pin_memory_batch(batch) 
            return batch

咱們仔細看能夠發現,前面還有一個self.collate_fn方法,這個是幹嗎用的呢?在介紹前咱們須要知道每一個參數的意義:

  • indices: 表示每個iteration,sampler返回的indices,即一個batch size大小的索引列表
  • self.dataset[i]: 前面已經介紹了,這裏就是對第i個數據進行讀取操做,通常來講self.dataset[i]=(img, label)

看到這不難猜出collate_fn的做用就是將一個batch的數據進行合併操做。默認的collate_fn是將img和label分別合併成imgs和labels,因此若是你的__getitem__方法只是返回 img, label,那麼你可使用默認的collate_fn方法,可是若是你每次讀取的數據有img, box, label等等,那麼你就須要自定義collate_fn來將對應的數據合併成一個batch數據,這樣方便後續的訓練步驟。



MARSGGBO原創




2019-8-6

相關文章
相關標籤/搜索