Enable find_unused_parameters for DistributedDataParallel

attention_norm has some parameters which are not used to compute grad,
which is causing failures in the distributed case.
This commit is contained in:
James Betker 2020-07-23 09:08:13 -06:00
parent dbf6147504
commit bba283776c

View File

@ -136,13 +136,15 @@ class SRGANModel(BaseModel):
# DataParallel
if opt['dist']:
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()],
find_unused_parameters=True)
else:
self.netG = DataParallel(self.netG)
if self.is_train:
if opt['dist']:
self.netD = DistributedDataParallel(self.netD,
device_ids=[torch.cuda.current_device()])
device_ids=[torch.cuda.current_device()],
find_unused_parameters=True)
else:
self.netD = DataParallel(self.netD)
self.netG.train()