Enable find_unused_parameters for DistributedDataParallel
attention_norm has some parameters that are not used when computing gradients, which causes failures in the distributed case.
This commit is contained in:
parent
dbf6147504
commit
bba283776c
|
@ -136,13 +136,15 @@ class SRGANModel(BaseModel):
|
||||||
|
|
||||||
# DataParallel
|
# DataParallel
|
||||||
if opt['dist']:
|
if opt['dist']:
|
||||||
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()],
|
||||||
|
find_unused_parameters=True)
|
||||||
else:
|
else:
|
||||||
self.netG = DataParallel(self.netG)
|
self.netG = DataParallel(self.netG)
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
if opt['dist']:
|
if opt['dist']:
|
||||||
self.netD = DistributedDataParallel(self.netD,
|
self.netD = DistributedDataParallel(self.netD,
|
||||||
device_ids=[torch.cuda.current_device()])
|
device_ids=[torch.cuda.current_device()],
|
||||||
|
find_unused_parameters=True)
|
||||||
else:
|
else:
|
||||||
self.netD = DataParallel(self.netD)
|
self.netD = DataParallel(self.netD)
|
||||||
self.netG.train()
|
self.netG.train()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user