SRGAN model supprots dist training
This commit is contained in:
parent
9d949b838e
commit
0098663b6b
|
@ -69,8 +69,7 @@ class SRGANModel(BaseModel):
|
||||||
if self.cri_fea: # load VGG perceptual loss
|
if self.cri_fea: # load VGG perceptual loss
|
||||||
self.netF = networks.define_F(opt, use_bn=False).to(self.device)
|
self.netF = networks.define_F(opt, use_bn=False).to(self.device)
|
||||||
if opt['dist']:
|
if opt['dist']:
|
||||||
self.netF = DistributedDataParallel(self.netF,
|
pass # do not need to use DistributedDataParallel for netF
|
||||||
device_ids=[torch.cuda.current_device()])
|
|
||||||
else:
|
else:
|
||||||
self.netF = DataParallel(self.netF)
|
self.netF = DataParallel(self.netF)
|
||||||
|
|
||||||
|
@ -151,11 +150,12 @@ class SRGANModel(BaseModel):
|
||||||
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
|
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
|
||||||
l_g_total += l_g_fea
|
l_g_total += l_g_fea
|
||||||
|
|
||||||
pred_g_fake = self.netD(self.fake_H)
|
|
||||||
if self.opt['train']['gan_type'] == 'gan':
|
if self.opt['train']['gan_type'] == 'gan':
|
||||||
|
pred_g_fake = self.netD(self.fake_H)
|
||||||
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
||||||
elif self.opt['train']['gan_type'] == 'ragan':
|
elif self.opt['train']['gan_type'] == 'ragan':
|
||||||
pred_d_real = self.netD(self.var_ref).detach()
|
pred_d_real = self.netD(self.var_ref).detach()
|
||||||
|
pred_g_fake = self.netD(self.fake_H)
|
||||||
l_g_gan = self.l_gan_w * (
|
l_g_gan = self.l_gan_w * (
|
||||||
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||||
|
@ -169,19 +169,30 @@ class SRGANModel(BaseModel):
|
||||||
p.requires_grad = True
|
p.requires_grad = True
|
||||||
|
|
||||||
self.optimizer_D.zero_grad()
|
self.optimizer_D.zero_grad()
|
||||||
l_d_total = 0
|
|
||||||
pred_d_real = self.netD(self.var_ref)
|
|
||||||
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
|
|
||||||
if self.opt['train']['gan_type'] == 'gan':
|
if self.opt['train']['gan_type'] == 'gan':
|
||||||
|
# need to forward and backward separately, since batch norm statistics differ
|
||||||
|
# real
|
||||||
|
pred_d_real = self.netD(self.var_ref)
|
||||||
l_d_real = self.cri_gan(pred_d_real, True)
|
l_d_real = self.cri_gan(pred_d_real, True)
|
||||||
|
l_d_real.backward()
|
||||||
|
# fake
|
||||||
|
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
|
||||||
l_d_fake = self.cri_gan(pred_d_fake, False)
|
l_d_fake = self.cri_gan(pred_d_fake, False)
|
||||||
l_d_total = l_d_real + l_d_fake
|
l_d_fake.backward()
|
||||||
elif self.opt['train']['gan_type'] == 'ragan':
|
elif self.opt['train']['gan_type'] == 'ragan':
|
||||||
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
|
# pred_d_real = self.netD(self.var_ref)
|
||||||
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
|
# pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
|
||||||
l_d_total = (l_d_real + l_d_fake) / 2
|
# l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
|
||||||
|
# l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
|
||||||
l_d_total.backward()
|
# l_d_total = (l_d_real + l_d_fake) / 2
|
||||||
|
# l_d_total.backward()
|
||||||
|
pred_d_fake = self.netD(self.fake_H.detach()).detach()
|
||||||
|
pred_d_real = self.netD(self.var_ref)
|
||||||
|
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
|
||||||
|
l_d_real.backward()
|
||||||
|
pred_d_fake = self.netD(self.fake_H.detach())
|
||||||
|
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
|
||||||
|
l_d_fake.backward()
|
||||||
self.optimizer_D.step()
|
self.optimizer_D.step()
|
||||||
|
|
||||||
# set log
|
# set log
|
||||||
|
|
Loading…
Reference in New Issue
Block a user