PyTorch 報錯:ModuleAttributeError: ‘DataParallel‘ object has no attribute ‘ xxx (已解決)

PyTorch 報錯:ModuleAttributeError: 'DataParallel' object has no attribute ' xxx (已解決)

 

這個問題中 ,‘XXX’ 通常就是代碼裏面的須要優化的模型名稱,例如,個人模型裏定義了 optimizer_G 和 optimizer_D 兩個網絡(生成器網絡和判別器網絡)。python

問題緣由:

在 train.py 中,調用它們時,直覺地寫成了 model.optimizer_G 的格式,以下:網絡

model = create_model(opt)
model = model.cuda()
visualizer = Visualizer(opt)
if opt.fp16:    
    model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.optimizer_G, model.optimizer_D], opt_level='O1')             
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
else:
    optimizer_G, optimizer_Dh = model.optimizer_G, model.optimizer_D

然而,其實這時 model 轉換成了 model.module。優化

 

解決方法:

在 ‘ model. ’ 後面加一個 ‘ module. ’ 。spa

將 model.optimizer_G 改爲 model.module.optimizer_Gcode

將 model.optimizer_D 改爲 model.module.optimizer_Dit

model = create_model(opt)
model = model.cuda()
visualizer = Visualizer(opt)
if opt.fp16:    
    model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.module.optimizer_G, model.module.optimizer_D], opt_level='O1')             
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
else:
    optimizer_G, optimizer_Dh = model.module.optimizer_G, model.module.optimizer_D
相關文章
相關標籤/搜索