forked from mrq/DL-Art-School
Fix AMP
This commit is contained in:
parent
f4b33b0531
commit
ebda70fcba
|
@ -26,20 +26,8 @@ class SRGANModel(BaseModel):
|
||||||
|
|
||||||
# define networks and load pretrained models
|
# define networks and load pretrained models
|
||||||
self.netG = networks.define_G(opt).to(self.device)
|
self.netG = networks.define_G(opt).to(self.device)
|
||||||
if opt['dist']:
|
|
||||||
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
|
||||||
else:
|
|
||||||
self.netG = DataParallel(self.netG)
|
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
self.netD = networks.define_D(opt).to(self.device)
|
self.netD = networks.define_D(opt).to(self.device)
|
||||||
if opt['dist']:
|
|
||||||
self.netD = DistributedDataParallel(self.netD,
|
|
||||||
device_ids=[torch.cuda.current_device()])
|
|
||||||
else:
|
|
||||||
self.netD = DataParallel(self.netD)
|
|
||||||
|
|
||||||
self.netG.train()
|
|
||||||
self.netD.train()
|
|
||||||
|
|
||||||
# define losses, optimizer and scheduler
|
# define losses, optimizer and scheduler
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
|
@ -109,6 +97,20 @@ class SRGANModel(BaseModel):
|
||||||
[self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \
|
[self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \
|
||||||
amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3)
|
amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3)
|
||||||
|
|
||||||
|
# DataParallel
|
||||||
|
if opt['dist']:
|
||||||
|
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
||||||
|
else:
|
||||||
|
self.netG = DataParallel(self.netG)
|
||||||
|
if self.is_train:
|
||||||
|
if opt['dist']:
|
||||||
|
self.netD = DistributedDataParallel(self.netD,
|
||||||
|
device_ids=[torch.cuda.current_device()])
|
||||||
|
else:
|
||||||
|
self.netD = DataParallel(self.netD)
|
||||||
|
self.netG.train()
|
||||||
|
self.netD.train()
|
||||||
|
|
||||||
# schedulers
|
# schedulers
|
||||||
if train_opt['lr_scheme'] == 'MultiStepLR':
|
if train_opt['lr_scheme'] == 'MultiStepLR':
|
||||||
for optimizer in self.optimizers:
|
for optimizer in self.optimizers:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user