From 0098663b6b2262b24a13d109b906f622697ce405 Mon Sep 17 00:00:00 2001
From: XintaoWang
Date: Sun, 1 Sep 2019 22:14:29 +0800
Subject: [PATCH] SRGAN model supports dist training

---
 codes/models/SRGAN_model.py | 35 +++++++++++++++++++++++------------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 77bb0fca..051f5076 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -69,8 +69,7 @@ class SRGANModel(BaseModel):
         if self.cri_fea:  # load VGG perceptual loss
             self.netF = networks.define_F(opt, use_bn=False).to(self.device)
             if opt['dist']:
-                self.netF = DistributedDataParallel(self.netF,
-                                                    device_ids=[torch.cuda.current_device()])
+                pass  # do not need to use DistributedDataParallel for netF
             else:
                 self.netF = DataParallel(self.netF)
 
@@ -151,11 +150,12 @@ class SRGANModel(BaseModel):
                 l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                 l_g_total += l_g_fea
 
-            pred_g_fake = self.netD(self.fake_H)
             if self.opt['train']['gan_type'] == 'gan':
+                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
             elif self.opt['train']['gan_type'] == 'ragan':
                 pred_d_real = self.netD(self.var_ref).detach()
+                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * (
                     self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
@@ -169,19 +169,30 @@ class SRGANModel(BaseModel):
             p.requires_grad = True
 
         self.optimizer_D.zero_grad()
-        l_d_total = 0
-        pred_d_real = self.netD(self.var_ref)
-        pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
         if self.opt['train']['gan_type'] == 'gan':
+            # need to forward and backward separately, since batch norm statistics differ
+            # real
+            pred_d_real = self.netD(self.var_ref)
             l_d_real = self.cri_gan(pred_d_real, True)
+            l_d_real.backward()
+            # fake
+            pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
             l_d_fake = self.cri_gan(pred_d_fake, False)
-            l_d_total = l_d_real + l_d_fake
+            l_d_fake.backward()
         elif self.opt['train']['gan_type'] == 'ragan':
-            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
-            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
-            l_d_total = (l_d_real + l_d_fake) / 2
-
-        l_d_total.backward()
+            # pred_d_real = self.netD(self.var_ref)
+            # pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
+            # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
+            # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
+            # l_d_total = (l_d_real + l_d_fake) / 2
+            # l_d_total.backward()
+            pred_d_fake = self.netD(self.fake_H.detach()).detach()
+            pred_d_real = self.netD(self.var_ref)
+            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
+            l_d_real.backward()
+            pred_d_fake = self.netD(self.fake_H.detach())
+            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
+            l_d_fake.backward()
         self.optimizer_D.step()
 
         # set log
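
The core change above is the new discriminator update: the real and fake batches are forwarded and backpropagated separately (so batch norm statistics are computed per batch), and the RaGAN terms detach the opposite prediction before each backward call. Below is a minimal, self-contained sketch of that pattern; netD, cri_gan, var_ref and fake_H are simplified stand-ins for the SRGANModel attributes, not the project's actual networks or losses.

import torch
import torch.nn as nn

# Sketch only: a tiny discriminator and a vanilla GAN criterion stand in for
# the real netD and the configured cri_gan.
netD = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8),
                     nn.LeakyReLU(0.2), nn.AdaptiveAvgPool2d(1),
                     nn.Flatten(), nn.Linear(8, 1))
bce = nn.BCEWithLogitsLoss()

def cri_gan(pred, target_is_real):
    # placeholder GAN loss: BCE against all-ones (real) or all-zeros (fake)
    target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
    return bce(pred, target)

optimizer_D = torch.optim.Adam(netD.parameters(), lr=1e-4)
var_ref = torch.randn(4, 3, 32, 32)                     # real (reference) batch
fake_H = torch.randn(4, 3, 32, 32, requires_grad=True)  # generator output
gan_type = 'ragan'

optimizer_D.zero_grad()
if gan_type == 'gan':
    # forward/backward real and fake separately, since batch norm statistics differ
    pred_d_real = netD(var_ref)
    cri_gan(pred_d_real, True).backward()
    pred_d_fake = netD(fake_H.detach())  # detach to avoid BP to G
    cri_gan(pred_d_fake, False).backward()
elif gan_type == 'ragan':
    # real term: fake prediction fully detached, so no gradient flows through it
    pred_d_fake = netD(fake_H.detach()).detach()
    pred_d_real = netD(var_ref)
    (cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5).backward()
    # fake term: re-forward the fake batch and detach the real prediction
    pred_d_fake = netD(fake_H.detach())
    (cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5).backward()
optimizer_D.step()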