This commit is contained in:
James Betker 2020-04-22 00:39:31 -06:00
parent f4b33b0531
commit ebda70fcba

View File

@ -26,20 +26,8 @@ class SRGANModel(BaseModel):
# define networks and load pretrained models
self.netG = networks.define_G(opt).to(self.device)
if opt['dist']:
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
else:
self.netG = DataParallel(self.netG)
if self.is_train:
self.netD = networks.define_D(opt).to(self.device)
if opt['dist']:
self.netD = DistributedDataParallel(self.netD,
device_ids=[torch.cuda.current_device()])
else:
self.netD = DataParallel(self.netD)
self.netG.train()
self.netD.train()
# define losses, optimizer and scheduler
if self.is_train:
@ -109,6 +97,20 @@ class SRGANModel(BaseModel):
[self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \
amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3)
# DataParallel
if opt['dist']:
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
else:
self.netG = DataParallel(self.netG)
if self.is_train:
if opt['dist']:
self.netD = DistributedDataParallel(self.netD,
device_ids=[torch.cuda.current_device()])
else:
self.netD = DataParallel(self.netD)
self.netG.train()
self.netD.train()
# schedulers
if train_opt['lr_scheme'] == 'MultiStepLR':
for optimizer in self.optimizers: