Distribute get_grad_no_padding

This commit is contained in:
James Betker 2020-08-25 17:03:18 -06:00
parent 2f706b7d93
commit 53e67bdb9c

View File

@ -265,9 +265,8 @@ class SRGANModel(BaseModel):
self.disc_optimizers.append(self.optimizer_D_grad) self.disc_optimizers.append(self.optimizer_D_grad)
if self.spsr_enabled: if self.spsr_enabled:
self.get_grad = ImageGradient().to(self.device)
self.get_grad_nopadding = ImageGradientNoPadding().to(self.device) self.get_grad_nopadding = ImageGradientNoPadding().to(self.device)
[self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding], \ [self.netG, self.netD, self.netD_grad, self.get_grad_nopadding], \
[self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \ [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \
amp.initialize([self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding], amp.initialize([self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding],
[self.optimizer_G, self.optimizer_D, self.optimizer_D_grad], [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad],
@ -292,6 +291,9 @@ class SRGANModel(BaseModel):
self.netD_grad = DistributedDataParallel(self.netD_grad, self.netD_grad = DistributedDataParallel(self.netD_grad,
device_ids=[torch.cuda.current_device()], device_ids=[torch.cuda.current_device()],
find_unused_parameters=True) find_unused_parameters=True)
self.get_grad_nopadding = DistributedDataParallel(self.get_grad_nopadding,
device_ids=[torch.cuda.current_device()],
find_unused_parameters=True)
else: else:
self.netD = DataParallel(self.netD) self.netD = DataParallel(self.netD)
if self.spsr_enabled: if self.spsr_enabled: