diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index a1b2d5d4..411cc13a 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -136,13 +136,15 @@ class SRGANModel(BaseModel): # DataParallel if opt['dist']: - self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()], + find_unused_parameters=True) else: self.netG = DataParallel(self.netG) if self.is_train: if opt['dist']: self.netD = DistributedDataParallel(self.netD, - device_ids=[torch.cuda.current_device()]) + device_ids=[torch.cuda.current_device()], + find_unused_parameters=True) else: self.netD = DataParallel(self.netD) self.netG.train()