From ebda70fcbaf70a9cae45f612fd25105f607863f7 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 22 Apr 2020 00:39:31 -0600
Subject: [PATCH] Fix AMP

---
 codes/models/SRGAN_model.py | 26 ++++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 40241f8b..69dfb04c 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -26,20 +26,8 @@ class SRGANModel(BaseModel):
 
         # define networks and load pretrained models
         self.netG = networks.define_G(opt).to(self.device)
-        if opt['dist']:
-            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
-        else:
-            self.netG = DataParallel(self.netG)
         if self.is_train:
             self.netD = networks.define_D(opt).to(self.device)
-            if opt['dist']:
-                self.netD = DistributedDataParallel(self.netD,
-                                                    device_ids=[torch.cuda.current_device()])
-            else:
-                self.netD = DataParallel(self.netD)
-
-            self.netG.train()
-            self.netD.train()
 
         # define losses, optimizer and scheduler
         if self.is_train:
@@ -109,6 +97,20 @@ class SRGANModel(BaseModel):
             [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \
                 amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D],
                                opt_level=self.amp_level, num_losses=3)
+            # DataParallel
+            if opt['dist']:
+                self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
+            else:
+                self.netG = DataParallel(self.netG)
+            if self.is_train:
+                if opt['dist']:
+                    self.netD = DistributedDataParallel(self.netD,
+                                                        device_ids=[torch.cuda.current_device()])
+                else:
+                    self.netD = DataParallel(self.netD)
+            self.netG.train()
+            self.netD.train()
+
             # schedulers
             if train_opt['lr_scheme'] == 'MultiStepLR':
                 for optimizer in self.optimizers:
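
Note on the fix: NVIDIA apex's amp.initialize() needs to patch the bare modules and optimizers before they are wrapped in DataParallel or DistributedDataParallel, which is why the wrapping code moves below the AMP call. The following is a minimal sketch of that ordering, using placeholder netG/netD modules and optimizers rather than the project's actual networks; it is an illustration under those assumptions, not the repository's code.

    # Sketch: apex AMP initialization first, parallel wrappers second.
    # netG/netD and the optimizers below are stand-ins, not networks.define_G/define_D.
    import torch
    import torch.nn as nn
    from torch.nn.parallel import DataParallel
    from apex import amp

    netG = nn.Linear(16, 16).cuda()            # placeholder generator
    netD = nn.Linear(16, 1).cuda()             # placeholder discriminator
    optG = torch.optim.Adam(netG.parameters())
    optD = torch.optim.Adam(netD.parameters())

    # 1) Let apex patch the raw modules and optimizers (lists in, lists out).
    [netG, netD], [optG, optD] = amp.initialize([netG, netD], [optG, optD],
                                                opt_level='O1', num_losses=3)

    # 2) Only then wrap for multi-GPU use, mirroring the order the patch enforces.
    netG = DataParallel(netG)
    netD = DataParallel(netD)
    netG.train()
    netD.train()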