# Patch summary (text before the first "diff --git" header is ignored by git-apply/patch):
#   * Stop creating self.get_grad and stop routing it through amp.initialize.
#     The model list passed to amp.initialize must have the same length as the
#     unpacking target on the left-hand side, so self.get_grad is removed from
#     BOTH places; previously it was left in the call, yielding 5 models for
#     4 targets.
#   * Wrap self.get_grad_nopadding in DistributedDataParallel in the branch
#     that already wraps self.netD_grad.
# NOTE(review): DistributedDataParallel refuses modules that have no parameter
#   requiring gradients -- confirm ImageGradientNoPadding has trainable
#   parameters before relying on the second hunk.
# NOTE(review): leading whitespace of context lines was reconstructed from a
#   whitespace-collapsed copy; apply with --ignore-whitespace if it is rejected.
#   The post-image blob id on the "index" line predates the amp.initialize fix.
diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 7016c257..47226a52 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -265,9 +265,8 @@ class SRGANModel(BaseModel):
                 self.disc_optimizers.append(self.optimizer_D_grad)

             if self.spsr_enabled:
-                self.get_grad = ImageGradient().to(self.device)
                 self.get_grad_nopadding = ImageGradientNoPadding().to(self.device)
-                [self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding], \
+                [self.netG, self.netD, self.netD_grad, self.get_grad_nopadding], \
                 [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \
-                    amp.initialize([self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding],
+                    amp.initialize([self.netG, self.netD, self.netD_grad, self.get_grad_nopadding],
                                    [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad],
@@ -292,6 +291,9 @@ class SRGANModel(BaseModel):
                         self.netD_grad = DistributedDataParallel(self.netD_grad,
                                                                  device_ids=[torch.cuda.current_device()],
                                                                  find_unused_parameters=True)
+                        self.get_grad_nopadding = DistributedDataParallel(self.get_grad_nopadding,
+                                                                          device_ids=[torch.cuda.current_device()],
+                                                                          find_unused_parameters=True)
                 else:
                     self.netD = DataParallel(self.netD)
                     if self.spsr_enabled: