forked from mrq/DL-Art-School
Fix distributed d_grad
This commit is contained in:
parent
09a9079e17
commit
96586d6592
|
@ -288,8 +288,14 @@ class SRGANModel(BaseModel):
|
|||
self.netD = DistributedDataParallel(self.netD,
|
||||
device_ids=[torch.cuda.current_device()],
|
||||
find_unused_parameters=True)
|
||||
if self.spsr_enabled:
|
||||
self.netD_grad = DistributedDataParallel(self.netD_grad,
|
||||
device_ids=[torch.cuda.current_device()],
|
||||
find_unused_parameters=True)
|
||||
else:
|
||||
self.netD = DataParallel(self.netD)
|
||||
if self.spsr_enabled:
|
||||
self.netD_grad = DataParallel(self.netD_grad)
|
||||
self.netG.train()
|
||||
self.netD.train()
|
||||
if self.spsr_enabled:
|
||||
|
|
Loading…
Reference in New Issue
Block a user