在使用pytorch訓練模型,常常須要加載大量圖片數據,所以pytorch提供了好用的數據加載工具Dataloader。
爲了實現小批量循環讀取大型數據集,在Dataloader類具體實現中,使用了迭代器和生成器。
這一應用場景正是python中迭代器模式的意義所在,所以本文對Dataloader中代碼進行解讀,能夠更好的理解python中迭代器和生成器的概念。html
本文的內容主要有:python
python中圍繞着迭代有如下概念:git
這三個概念互相關聯,並非孤立的。在可迭代對象的基礎上發展了迭代器,在迭代器的基礎上又發展了生成器。
學習這些概念的名詞解釋沒有多大意義。編程中不少的抽象概念都是爲了更好的實現某些功能,纔去人爲創造的協議和模式。
所以,要理解它們,須要探究概念背後的邏輯,爲何這樣設計?要解決的真正問題是什麼?在哪些場景下應用是最好的?github
迭代模式首先要解決的基礎問題是,須要按必定順序獲取集合內部數據,好比循環某個list。
當數據很小時,不會有問題。但當讀取大量數據時,一次性讀取會超出內存限制,所以想出如下方法:編程
循環讀數據可分爲下面三種應用場景,對應着容器(可迭代對象),迭代器和生成器:app
for x in container
: 爲了遍歷python內部序列容器(如list), 這些類型內部實現了__getitem__() 方法,能夠從0開始按順序遍歷序列容器中的元素。for x in iterator
: 爲了循環用戶自定義的迭代器,須要實現__iter__和__next__方法,__iter__是迭代協議,具體每次迭代的執行邏輯在 __next__或next方法裏for x in generator
: 爲了節省循環的內存和加速,使用生成器來實現惰性加載,在迭代器的基礎上加入了yield語句,最簡單的例子是 range(5)代碼示例:dom
# 普通循環 for x in list numbers = [1, 2, 3,] for n in numbers: print(n) # 1,2,3 # for循環實際乾的事情 # iter輸入一個可迭代對象list,返回迭代器 # next方法取數據 my_iterator = iter(numbers) next(my_iterator) # 1 next(my_iterator) # 2 next(my_iterator) # 3 next(my_iterator) # StopIteration exception # 迭代器循環 for x in iterator for i,n in enumerate(numbers): print(i,n) # 0,1 / 1,3 / 2,3 # 生成器循環 for x in generator for i in range(3): print(i) # 0,1,2
上面示例代碼中python內置函數iter和next的用法:ide
比較容易混淆的是__iter__和__next__兩個方法。它們的區別是:函數
__iter__返回自身的作法有點相似 python中的類型系統。爲了保持一致性,python中一切皆對象。
每一個對象建立後,都有類型指針,而類型對象的指針指向元對象,元對象的指針指向自身。工具
生成器,是在__iter__方法中加入yield語句,好處有:
yield做用:
for x in container
方法:
list, deque, …
set, frozensets, …
dict, defaultdict, OrderedDict, Counter, …
tuple, namedtuple, …
str
for x in iterator
方法:
enumerate()
# 加上list的indexsorted()
# 排序listreversed()
# 倒序listzip()
# 合併listfor x in generator
方法:
range()
map()
filter()
reduce()
[x for x in list(...)]
pytorch採用for x in iterator
模式,從Dataloader類中讀取數據。
如下代碼只截取了單線程下的數據讀取。
class DataLoader(object): r""" Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset. """ def __init__(self, dataset, batch_size=1, shuffle=False, ...): self.dataset = dataset self.batch_sampler = batch_sampler ... def __iter__(self): return _DataLoaderIter(self) def __len__(self): return len(self.batch_sampler) class _DataLoaderIter(object): r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" def __init__(self, loader): self.sample_iter = iter(self.batch_sampler) ... def __next__(self): if self.num_workers == 0: # same-process loading indices = next(self.sample_iter) # may raise StopIteration batch = self.collate_fn([self.dataset[i] for i in indices]) if self.pin_memory: batch = pin_memory_batch(batch) return batch ... def __iter__(self): return self
Dataloader類中讀取數據Index的方法,採用了 for x in generator
方式,可是調用採用iter和next函數
class RandomSampler(object): """random sampler to yield a mini-batch of indices.""" def __init__(self, batch_size, dataset, drop_last=False): self.dataset = dataset self.batch_size = batch_size self.num_imgs = len(dataset) self.drop_last = drop_last def __iter__(self): indices = np.random.permutation(self.num_imgs) batch = [] for i in indices: batch.append(i) if len(batch) == self.batch_size: yield batch batch = [] ## if images not to yield a batch if len(batch)>0 and not self.drop_last: yield batch def __len__(self): if self.drop_last: return self.num_imgs // self.batch_size else: return (self.num_imgs + self.batch_size - 1) // self.batch_size batch_sampler = RandomSampler(batch_size. dataset) sample_iter = iter(batch_sampler) indices = next(sample_iter)
本文總結了python中循環的三種模式:
for x in container
可迭代對象for x in iterator
迭代器for x in generator
生成器pytorch中的數據加載模塊 Dataloader,使用生成器來返回數據的索引,使用迭代器來返回須要的張量數據,能夠在大量數據狀況下,實現小批量循環迭代式的讀取,避免了內存不足問題。