Fix distributed d_grad

This commit is contained in:
James Betker 2020-08-25 16:06:27 -06:00
parent 09a9079e17
commit 96586d6592

View File

@ -288,8 +288,14 @@ class SRGANModel(BaseModel):
self.netD = DistributedDataParallel(self.netD,
device_ids=[torch.cuda.current_device()],
find_unused_parameters=True)
if self.spsr_enabled:
self.netD_grad = DistributedDataParallel(self.netD_grad,
device_ids=[torch.cuda.current_device()],
find_unused_parameters=True)
else:
self.netD = DataParallel(self.netD)
if self.spsr_enabled:
self.netD_grad = DataParallel(self.netD_grad)
self.netG.train()
self.netD.train()
if self.spsr_enabled: