From 96586d6592396e6505e0998ea1b5f0ae6eaaa6b5 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 25 Aug 2020 16:06:27 -0600 Subject: [PATCH] Fix distributed d_grad --- codes/models/SRGAN_model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 16c455cd..1738e902 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -288,8 +288,14 @@ class SRGANModel(BaseModel): self.netD = DistributedDataParallel(self.netD, device_ids=[torch.cuda.current_device()], find_unused_parameters=True) + if self.spsr_enabled: + self.netD_grad = DistributedDataParallel(self.netD_grad, + device_ids=[torch.cuda.current_device()], + find_unused_parameters=True) else: self.netD = DataParallel(self.netD) + if self.spsr_enabled: + self.netD_grad = DataParallel(self.netD_grad) self.netG.train() self.netD.train() if self.spsr_enabled: