forked from mrq/DL-Art-School
Distribute get_grad_no_padding
parent 2f706b7d93
commit 53e67bdb9c
@@ -265,9 +265,8 @@ class SRGANModel(BaseModel):
             self.disc_optimizers.append(self.optimizer_D_grad)
 
         if self.spsr_enabled:
             self.get_grad = ImageGradient().to(self.device)
             self.get_grad_nopadding = ImageGradientNoPadding().to(self.device)
-            [self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding], \
+            [self.netG, self.netD, self.netD_grad, self.get_grad_nopadding], \
             [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \
                 amp.initialize([self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding],
                                [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad],
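For context: Apex's amp.initialize accepts lists of models and optimizers and hands them back in the same order, so the unpacking target on the left-hand side has to mirror the list of modules passed into the call. Below is a minimal sketch of that pattern, assuming NVIDIA Apex and a CUDA device are available; the toy modules, learning rates, and opt_level "O1" are illustrative stand-ins, not taken from the repo.

import torch
import torch.nn as nn
from apex import amp

# Two stand-in networks and their optimizers (illustrative only).
netG = nn.Conv2d(3, 3, kernel_size=3, padding=1).cuda()
netD = nn.Conv2d(3, 1, kernel_size=3, padding=1).cuda()
optimizer_G = torch.optim.Adam(netG.parameters(), lr=1e-4)
optimizer_D = torch.optim.Adam(netD.parameters(), lr=1e-4)

# amp.initialize returns the (possibly patched) models and optimizers in the
# order they were given, so both sides of the assignment must list the same
# objects in the same order.
[netG, netD], [optimizer_G, optimizer_D] = \
    amp.initialize([netG, netD], [optimizer_G, optimizer_D], opt_level="O1")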
@@ -292,6 +291,9 @@ class SRGANModel(BaseModel):
                 self.netD_grad = DistributedDataParallel(self.netD_grad,
                                                          device_ids=[torch.cuda.current_device()],
                                                          find_unused_parameters=True)
+                self.get_grad_nopadding = DistributedDataParallel(self.get_grad_nopadding,
+                                                                  device_ids=[torch.cuda.current_device()],
+                                                                  find_unused_parameters=True)
             else:
                 self.netD = DataParallel(self.netD)
                 if self.spsr_enabled:
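For reference, the added lines follow the standard torch.nn.parallel.DistributedDataParallel pattern: each process holds a replica of the module pinned to its own GPU, and find_unused_parameters=True tolerates parameters that do not receive gradients on every step. A minimal sketch under the assumption of a single-node NCCL process group launched with torchrun; the stand-in conv module is illustrative, not from the repo.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# torchrun sets RANK, WORLD_SIZE and LOCAL_RANK in the environment.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Stand-in for a small helper module such as get_grad_nopadding.
module = nn.Conv2d(3, 3, kernel_size=3, padding=1).cuda()

# Same arguments as the diff: pin the replica to this process's GPU and allow
# parameters that are unused in a given forward/backward pass.
module = DistributedDataParallel(module,
                                 device_ids=[torch.cuda.current_device()],
                                 find_unused_parameters=True)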