forked from mrq/DL-Art-School
Distribute get_grad_no_padding
This commit is contained in:
parent
2f706b7d93
commit
53e67bdb9c
|
@ -265,9 +265,8 @@ class SRGANModel(BaseModel):
|
||||||
self.disc_optimizers.append(self.optimizer_D_grad)
|
self.disc_optimizers.append(self.optimizer_D_grad)
|
||||||
|
|
||||||
if self.spsr_enabled:
|
if self.spsr_enabled:
|
||||||
self.get_grad = ImageGradient().to(self.device)
|
|
||||||
self.get_grad_nopadding = ImageGradientNoPadding().to(self.device)
|
self.get_grad_nopadding = ImageGradientNoPadding().to(self.device)
|
||||||
[self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding], \
|
[self.netG, self.netD, self.netD_grad, self.get_grad_nopadding], \
|
||||||
[self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \
|
[self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \
|
||||||
amp.initialize([self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding],
|
amp.initialize([self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding],
|
||||||
[self.optimizer_G, self.optimizer_D, self.optimizer_D_grad],
|
[self.optimizer_G, self.optimizer_D, self.optimizer_D_grad],
|
||||||
|
@ -292,6 +291,9 @@ class SRGANModel(BaseModel):
|
||||||
self.netD_grad = DistributedDataParallel(self.netD_grad,
|
self.netD_grad = DistributedDataParallel(self.netD_grad,
|
||||||
device_ids=[torch.cuda.current_device()],
|
device_ids=[torch.cuda.current_device()],
|
||||||
find_unused_parameters=True)
|
find_unused_parameters=True)
|
||||||
|
self.get_grad_nopadding = DistributedDataParallel(self.get_grad_nopadding,
|
||||||
|
device_ids=[torch.cuda.current_device()],
|
||||||
|
find_unused_parameters=True)
|
||||||
else:
|
else:
|
||||||
self.netD = DataParallel(self.netD)
|
self.netD = DataParallel(self.netD)
|
||||||
if self.spsr_enabled:
|
if self.spsr_enabled:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user