pytorch :: Dataloader中的迭代器和生成器應用

在使用pytorch訓練模型,常常須要加載大量圖片數據,所以pytorch提供了好用的數據加載工具Dataloader。
爲了實現小批量循環讀取大型數據集,在Dataloader類具體實現中,使用了迭代器和生成器。
這一應用場景正是python中迭代器模式的意義所在,所以本文對Dataloader中代碼進行解讀,能夠更好的理解python中迭代器和生成器的概念。html

本文的內容主要有:python

  1. 解釋python中的迭代器和生成器概念
  2. 解讀pytorch中Dataloader代碼,如何使用迭代器和生成器實現數據加載

python迭代基礎

python中圍繞着迭代有如下概念:git

  1. 可迭代對象 iterables
  2. 迭代器 iterator
  3. 生成器 generator

這三個概念互相關聯,並非孤立的。在可迭代對象的基礎上發展了迭代器,在迭代器的基礎上又發展了生成器。
學習這些概念的名詞解釋沒有多大意義。編程中不少的抽象概念都是爲了更好的實現某些功能,纔去人爲創造的協議和模式。
所以,要理解它們,須要探究概念背後的邏輯,爲何這樣設計?要解決的真正問題是什麼?在哪些場景下應用是最好的?github

迭代模式首先要解決的基礎問題是,須要按必定順序獲取集合內部數據,好比循環某個list。
當數據很小時,不會有問題。但當讀取大量數據時,一次性讀取會超出內存限制,所以想出如下方法:編程

  • 把大的數據分紅幾個小塊,分批處理
  • 惰性的取值方式,按需取值

循環讀數據可分爲下面三種應用場景,對應着容器(可迭代對象),迭代器和生成器:app

  1. for x in container: 爲了遍歷python內部序列容器(如list), 這些類型內部實現了__getitem__() 方法,能夠從0開始按順序遍歷序列容器中的元素。
  2. for x in iterator: 爲了循環用戶自定義的迭代器,須要實現__iter__和__next__方法,__iter__是迭代協議,具體每次迭代的執行邏輯在 __next__或next方法裏
  3. for x in generator: 爲了節省循環的內存和加速,使用生成器來實現惰性加載,在迭代器的基礎上加入了yield語句,最簡單的例子是 range(5)

代碼示例:dom

# 普通循環 for x in list
numbers = [1, 2, 3,]
for n in numbers:
    print(n) # 1,2,3

# for循環實際乾的事情
# iter輸入一個可迭代對象list,返回迭代器
# next方法取數據
my_iterator = iter(numbers)
next(my_iterator) # 1
next(my_iterator) # 2
next(my_iterator) # 3
next(my_iterator) # StopIteration exception

# 迭代器循環 for x in iterator
for i,n in enumerate(numbers):
    print(i,n) # 0,1 / 1,3 / 2,3

# 生成器循環 for x in generator
for i in range(3):
    print(i) # 0,1,2

上面示例代碼中python內置函數iter和next的用法:ide

  • iter函數,調用__iter__,返回一個迭代器
  • next函數,輸入迭代器,調用__next__,取出數據

比較容易混淆的是__iter__和__next__兩個方法。它們的區別是:函數

  1. __iter__是爲了能夠迭代,真正執行取數據的邏輯是__next__方法實現的,實際調用是經過next(iterator)完成
  2. __iter__能夠返回自身(return self),實際讀取數據的實現放在__next__方法
  3. __iter__能夠和yield搭配,返回生成器對象

__iter__返回自身的作法有點相似 python中的類型系統。爲了保持一致性,python中一切皆對象。
每一個對象建立後,都有類型指針,而類型對象的指針指向元對象,元對象的指針指向自身。工具

生成器,是在__iter__方法中加入yield語句,好處有:

  1. 減小循環判斷邏輯的複雜度
  2. 惰性取值,節省內存和時間

yield做用:

  1. 代替函數中的return語句
  2. 記住上一次循環迭代器內部元素的位置

三種循環模式經常使用函數

for x in container方法:

  • list, deque, …
  • set, frozensets, …
  • dict, defaultdict, OrderedDict, Counter, …
  • tuple, namedtuple, …
  • str

for x in iterator方法:

  • enumerate() # 加上list的index
  • sorted() # 排序list
  • reversed() # 倒序list
  • zip() # 合併list

for x in generator方法:

  • range()
  • map()
  • filter()
  • reduce()
  • [x for x in list(...)]

Dataloder源碼分析

pytorch採用for x in iterator模式,從Dataloader類中讀取數據。

  1. 爲了實現該迭代模式,在Dataloader內部實現__iter__方法,實際返回的是_DataLoaderIter類。
  2. _DataLoaderIter類裏面,實現了 __iter__方法,返回自身,具體執行讀數據的邏輯,在__next__方法中。

如下代碼只截取了單線程下的數據讀取。

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.
    """
    def __init__(self, dataset, batch_size=1, shuffle=False, ...):
        self.dataset = dataset
        self.batch_sampler = batch_sampler
        ...
    
    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)

class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
    def __init__(self, loader):
        self.sample_iter = iter(self.batch_sampler)
        ...

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch
        ...

    def __iter__(self):
        return self

Dataloader類中讀取數據Index的方法,採用了 for x in generator方式,可是調用採用iter和next函數

  1. 構建隨機採樣類RandomSampler,內部實現了 __iter__方法
  2. __iter__方法內部使用了 yield,循環遍歷數據集,當數量達到batch_size大小時,就返回
  3. 實例化隨機採樣類,傳入iter函數,返回一個迭代器
  4. next會調用隨機採樣類中生成器,返回相應的index數據
class RandomSampler(object):
    """random sampler to yield a mini-batch of indices."""
    def __init__(self, batch_size, dataset, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_imgs = len(dataset)
        self.drop_last = drop_last

    def __iter__(self):
        indices = np.random.permutation(self.num_imgs)
        batch = []
        for i in indices:
            batch.append(i)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        ## if images not to yield a batch
        if len(batch)>0 and not self.drop_last:
            yield batch


    def __len__(self):
        if self.drop_last:
            return self.num_imgs // self.batch_size
        else:
            return (self.num_imgs + self.batch_size - 1) // self.batch_size

batch_sampler = RandomSampler(batch_size. dataset)
sample_iter = iter(batch_sampler)
indices = next(sample_iter)

總結

本文總結了python中循環的三種模式:

  1. for x in container 可迭代對象
  2. for x in iterator 迭代器
  3. for x in generator 生成器

pytorch中的數據加載模塊 Dataloader,使用生成器來返回數據的索引,使用迭代器來返回須要的張量數據,能夠在大量數據狀況下,實現小批量循環迭代式的讀取,避免了內存不足問題。

參考文章

相關文章
相關標籤/搜索