SRGAN model supports dist training

XintaoWang 2019-09-01 22:14:29 +08:00
parent 9d949b838e
commit 0098663b6b


@@ -69,8 +69,7 @@ class SRGANModel(BaseModel):
             if self.cri_fea:  # load VGG perceptual loss
                 self.netF = networks.define_F(opt, use_bn=False).to(self.device)
                 if opt['dist']:
-                    self.netF = DistributedDataParallel(self.netF,
-                                                        device_ids=[torch.cuda.current_device()])
+                    pass  # do not need to use DistributedDataParallel for netF
                 else:
                     self.netF = DataParallel(self.netF)
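Why netF can skip the DDP wrapper: netF is the pretrained VGG network used only to compute the perceptual loss. Its weights are frozen and never receive gradients, so there is nothing for DistributedDataParallel to all-reduce; each rank can simply run its own local copy. Below is a minimal sketch of such a frozen feature extractor, assuming a torchvision VGG19; it is illustrative, not the repository's define_F.

```python
import torch.nn as nn
import torchvision

class VGGFeatureExtractor(nn.Module):
    """Frozen VGG19 features for a perceptual loss (similar in spirit to netF)."""
    def __init__(self, last_layer=34):
        super().__init__()
        vgg = torchvision.models.vgg19(pretrained=True)
        self.features = nn.Sequential(*list(vgg.features.children())[:last_layer + 1])
        for p in self.features.parameters():
            p.requires_grad = False  # frozen: DDP would have no gradients to sync
        self.features.eval()

    def forward(self, x):
        return self.features(x)
```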
@@ -151,11 +150,12 @@ class SRGANModel(BaseModel):
                 l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                 l_g_total += l_g_fea
 
-            pred_g_fake = self.netD(self.fake_H)
             if self.opt['train']['gan_type'] == 'gan':
+                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
             elif self.opt['train']['gan_type'] == 'ragan':
                 pred_d_real = self.netD(self.var_ref).detach()
+                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * (
                     self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
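Moving pred_g_fake inside each branch keeps every netD forward next to the backward that consumes it, which matters once netD is wrapped in DistributedDataParallel. For reference, the 'ragan' branch computes the relativistic average GAN generator loss: G is rewarded when D(real) - E[D(fake)] is judged fake and D(fake) - E[D(real)] is judged real, with pred_d_real detached so only G is updated. A standalone sketch, assuming cri_gan is BCE-with-logits ('gan' mode); the function name is illustrative:

```python
import torch
import torch.nn.functional as F

def ragan_generator_loss(pred_d_real, pred_g_fake):
    """Relativistic average GAN loss for G (sketch of the 'ragan' branch above)."""
    # pred_d_real is detached in the diff, so gradients reach only the generator
    l_real = F.binary_cross_entropy_with_logits(
        pred_d_real - pred_g_fake.mean(), torch.zeros_like(pred_d_real))  # real looks relatively fake
    l_fake = F.binary_cross_entropy_with_logits(
        pred_g_fake - pred_d_real.mean(), torch.ones_like(pred_g_fake))   # fake looks relatively real
    return (l_real + l_fake) / 2

# e.g. l_g_gan = l_gan_w * ragan_generator_loss(netD(var_ref).detach(), netD(fake_H))
```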
@@ -169,19 +169,30 @@ class SRGANModel(BaseModel):
             p.requires_grad = True
 
         self.optimizer_D.zero_grad()
-        l_d_total = 0
-        pred_d_real = self.netD(self.var_ref)
-        pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
         if self.opt['train']['gan_type'] == 'gan':
+            # need to forward and backward separately, since batch norm statistics differ
+            # real
+            pred_d_real = self.netD(self.var_ref)
             l_d_real = self.cri_gan(pred_d_real, True)
+            l_d_real.backward()
+            # fake
+            pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
             l_d_fake = self.cri_gan(pred_d_fake, False)
-            l_d_total = l_d_real + l_d_fake
+            l_d_fake.backward()
         elif self.opt['train']['gan_type'] == 'ragan':
-            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
-            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
-            l_d_total = (l_d_real + l_d_fake) / 2
-        l_d_total.backward()
+            # pred_d_real = self.netD(self.var_ref)
+            # pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
+            # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
+            # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
+            # l_d_total = (l_d_real + l_d_fake) / 2
+            # l_d_total.backward()
+            pred_d_fake = self.netD(self.fake_H.detach()).detach()
+            pred_d_real = self.netD(self.var_ref)
+            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
+            l_d_real.backward()
+            pred_d_fake = self.netD(self.fake_H.detach())
+            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
+            l_d_fake.backward()
         self.optimizer_D.step()
 
         # set log
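The discriminator update is now split into separate forward/backward passes instead of one combined loss. The inline comment notes that batch-norm statistics differ between the real and fake batches; splitting also avoids the two-forwards-then-one-backward pattern that DDP's gradient reducer handles poorly, which fits the commit's goal of distributed training. Gradients from both backward calls accumulate in .grad, so a single optimizer step applies their sum, and the * 0.5 factors reproduce the old (l_d_real + l_d_fake) / 2 averaging. A minimal, runnable sketch of this accumulate-then-step pattern, following the 'gan' branch; netD, cri_gan, var_ref, and fake_H are illustrative stand-ins:

```python
import torch
import torch.nn as nn

# Stand-ins for the diff's netD / cri_gan / var_ref / fake_H (illustrative only)
netD = nn.Sequential(nn.Flatten(), nn.Linear(16, 1))
bce = nn.BCEWithLogitsLoss()
cri_gan = lambda pred, is_real: bce(pred, torch.full_like(pred, float(is_real)))
optimizer_D = torch.optim.Adam(netD.parameters(), lr=1e-4)
var_ref = torch.randn(4, 16)                      # "real" batch
fake_H = torch.randn(4, 16, requires_grad=True)   # pretend generator output

optimizer_D.zero_grad()

# real pass: backward() frees this graph before the fake pass builds its own
pred_d_real = netD(var_ref)
l_d_real = cri_gan(pred_d_real, True)
l_d_real.backward()

# fake pass: detach() so no gradient flows back into the generator
pred_d_fake = netD(fake_H.detach())
l_d_fake = cri_gan(pred_d_fake, False)
l_d_fake.backward()

# .grad now holds the sum of both passes' gradients; one step applies both
optimizer_D.step()
```

The 'ragan' branch follows the same pattern but needs an extra detached forward: each relativistic term references the mean of the opposite batch, so pred_d_fake is recomputed and pred_d_real is detached in the second term to keep each backward confined to its own graph.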