I've never fully understood how to use DataParallel in PyTorch 1.0 and kept running into all sorts of baffling problems, so today I'll work through it from the source to better understand how it works, step by step.
Source code: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py
First, let's go through, line by line, how DataParallel is initialized.
`super` is just the call to the `torch.nn.Module` parent class constructor, which needs no explanation here. `output_device` specifies which GPU the output is gathered onto; by default it is the first GPU. Note that "first" means the first entry of the `device_ids` list, so if you have three GPUs but copy the model to CUDA with `model.cuda(1)` or `model.cuda(2)`, you will get an error, because `device_ids` is `[0, 1, 2]` and its first element is 0. You can see this check later in the `forward` function.
函數中看到。def __init__(self, module, device_ids=None, output_device=None, dim=0): super(DataParallel, self).__init__() if not torch.cuda.is_available(): self.module = module self.device_ids = [] return if device_ids is None: device_ids = list(range(torch.cuda.device_count())) if output_device is None: output_device = device_ids[0] self.dim = dim self.module = module self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) self.output_device = _get_device_index(output_device, True) self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0])) _check_balance(self.device_ids) if len(self.device_ids) == 1: self.module.cuda(device_ids[0])
Now for the main event: DataParallel's forward function.
```python
def forward(self, *inputs, **kwargs):
    if not self.device_ids:
        return self.module(*inputs, **kwargs)

    for t in chain(self.module.parameters(), self.module.buffers()):
        if t.device != self.src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(self.src_device_obj, t.device))

    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    if len(self.device_ids) == 1:
        return self.module(*inputs[0], **kwargs[0])
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
    outputs = self.parallel_apply(replicas, inputs, kwargs)
    return self.gather(outputs, self.output_device)
```
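The last four steps of `forward` (scatter → replicate → parallel_apply → gather) can be reproduced by hand with the functional helpers exposed in `torch.nn.parallel`. Below is a rough sketch of that flow, assuming two visible GPUs and an arbitrary toy model:

```python
import torch
import torch.nn as nn
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

device_ids = [0, 1]                  # assumes two visible GPUs
output_device = device_ids[0]
model = nn.Linear(10, 5).cuda(device_ids[0])
x = torch.randn(64, 10).cuda(device_ids[0])

inputs = scatter(x, device_ids)                                  # two chunks of shape (32, 10), one per GPU
replicas = replicate(model, device_ids[:len(inputs)])            # one copy of the model per GPU
outputs = parallel_apply(replicas, [(inp,) for inp in inputs])   # run each replica on its own chunk
result = gather(outputs, output_device)                          # collect the results back on device_ids[0]
print(result.shape)                                              # torch.Size([64, 5])
```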
The `scatter` function:

```python
def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None
    return res
```
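A small illustration of what `scatter_map` does with different argument types (assuming two GPUs; the batch shape and the dict are made up): tensors are chunked along `dim` 0 and moved to the target GPUs, while non-tensor values are simply duplicated once per device.

```python
import torch
from torch.nn.parallel import scatter

batch = torch.randn(64, 10)
chunks = scatter(batch, [0, 1])
print([tuple(c.shape) for c in chunks])   # [(32, 10), (32, 10)]
print([str(c.device) for c in chunks])    # ['cuda:0', 'cuda:1']

# A dict (e.g. kwargs) is scattered element-wise; non-tensor values are duplicated.
kwargs = scatter({"temperature": 0.5}, [0, 1])
print(kwargs)                             # [{'temperature': 0.5}, {'temperature': 0.5}]
```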
The `replicate` function is fairly involved, so I won't go through it here; if you are interested, you can read the source: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/replicate.py. Its main job, though, is to copy the model onto each of the GPUs.
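As a rough sketch of what that copying looks like from the caller's side (assuming GPUs 0 and 1 are available; the `nn.Linear` module is an arbitrary example):

```python
import torch.nn as nn
from torch.nn.parallel import replicate

model = nn.Linear(10, 5).cuda(0)
replicas = replicate(model, [0, 1])   # one copy of the module per device
print(replicas[0].weight.device)      # cuda:0
print(replicas[1].weight.device)      # cuda:1
```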
`parallel_apply` then runs the model on the GPUs in parallel; every replica is identical, only the input data differs, because the data was split evenly beforehand. For example, if you have two GPUs and a batch size of 64, each GPU processes a batch of size 32. Finally, `gather` collects the outputs and sends them to `output_device`, i.e. the first GPU.
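To make the last step concrete, here is a minimal `gather` sketch (assuming two GPUs and made-up per-replica outputs of shape `(32, 5)`, matching the batch-of-64 example above): the outputs are concatenated along `dim` 0 and placed on the target device.

```python
import torch
from torch.nn.parallel import gather

out0 = torch.randn(32, 5, device="cuda:0")   # output of the replica on GPU 0
out1 = torch.randn(32, 5, device="cuda:1")   # output of the replica on GPU 1

result = gather([out0, out1], 0)             # target_device=0, i.e. output_device
print(result.shape, result.device)           # torch.Size([64, 5]) cuda:0
```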