SRGAN model supports dist training
parent 9d949b838e
commit 0098663b6b
@@ -69,8 +69,7 @@ class SRGANModel(BaseModel):
             if self.cri_fea:  # load VGG perceptual loss
                 self.netF = networks.define_F(opt, use_bn=False).to(self.device)
                 if opt['dist']:
-                    self.netF = DistributedDataParallel(self.netF,
-                                                        device_ids=[torch.cuda.current_device()])
+                    pass  # do not need to use DistributedDataParallel for netF
                 else:
                     self.netF = DataParallel(self.netF)

@@ -151,11 +150,12 @@ class SRGANModel(BaseModel):
                 l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                 l_g_total += l_g_fea

-            pred_g_fake = self.netD(self.fake_H)
             if self.opt['train']['gan_type'] == 'gan':
+                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
             elif self.opt['train']['gan_type'] == 'ragan':
                 pred_d_real = self.netD(self.var_ref).detach()
+                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * (
                     self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
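For reference, the ragan branch computes the relativistic average GAN generator loss: the fake prediction is pushed above the mean real prediction and the real prediction below the mean fake one, with pred_d_real detached so the generator step does not backpropagate into the discriminator. A standalone sketch, assuming cri_gan is BCEWithLogitsLoss with ones/zeros targets (the repo's criterion is configurable; netD, fake_H, var_ref and the weight default are placeholders):

    import torch
    import torch.nn as nn

    bce = nn.BCEWithLogitsLoss()

    def ragan_generator_loss(netD, fake_H, var_ref, l_gan_w=5e-3):
        pred_d_real = netD(var_ref).detach()   # constant: no grad into D here
        pred_g_fake = netD(fake_H)             # grad flows back into G via fake_H
        l_real = bce(pred_d_real - torch.mean(pred_g_fake),
                     torch.zeros_like(pred_d_real))  # real should look "less real"
        l_fake = bce(pred_g_fake - torch.mean(pred_d_real),
                     torch.ones_like(pred_g_fake))   # fake should look "more real"
        return l_gan_w * (l_real + l_fake) / 2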
@@ -169,19 +169,30 @@ class SRGANModel(BaseModel):
             p.requires_grad = True

         self.optimizer_D.zero_grad()
-        l_d_total = 0
-        pred_d_real = self.netD(self.var_ref)
-        pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
         if self.opt['train']['gan_type'] == 'gan':
+            # need to forward and backward separately, since batch norm statistics differ
+            # real
+            pred_d_real = self.netD(self.var_ref)
             l_d_real = self.cri_gan(pred_d_real, True)
+            l_d_real.backward()
+            # fake
+            pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
             l_d_fake = self.cri_gan(pred_d_fake, False)
-            l_d_total = l_d_real + l_d_fake
+            l_d_fake.backward()
         elif self.opt['train']['gan_type'] == 'ragan':
-            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
-            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
-            l_d_total = (l_d_real + l_d_fake) / 2
-
-        l_d_total.backward()
+            # pred_d_real = self.netD(self.var_ref)
+            # pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
+            # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
+            # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
+            # l_d_total = (l_d_real + l_d_fake) / 2
+            # l_d_total.backward()
+            pred_d_fake = self.netD(self.fake_H.detach()).detach()
+            pred_d_real = self.netD(self.var_ref)
+            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
+            l_d_real.backward()
+            pred_d_fake = self.netD(self.fake_H.detach())
+            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
+            l_d_fake.backward()
         self.optimizer_D.step()

         # set log
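The rewritten ragan branch trades one fused backward for two per-term backwards: the first netD(fake_H.detach()).detach() forward only supplies the constant mean for the real-side term, each 0.5-weighted loss is backpropagated as soon as it is computed (freeing that term's graph and keeping per-forward batch-norm statistics consistent, per the comment in the gan branch), and gradients accumulate in netD before the single optimizer step. A standalone sketch under the same BCEWithLogitsLoss assumption as above (netD, var_ref, fake_H, optimizer_D are placeholders):

    import torch
    import torch.nn as nn

    bce = nn.BCEWithLogitsLoss()

    def ragan_discriminator_step(netD, var_ref, fake_H, optimizer_D):
        optimizer_D.zero_grad()
        # fully detached fake forward: used only as a constant mean below
        pred_d_fake = netD(fake_H.detach()).detach()
        pred_d_real = netD(var_ref)
        l_d_real = bce(pred_d_real - torch.mean(pred_d_fake),
                       torch.ones_like(pred_d_real)) * 0.5
        l_d_real.backward()  # frees the real-side graph immediately
        pred_d_fake = netD(fake_H.detach())  # fresh forward, this one keeps a graph
        l_d_fake = bce(pred_d_fake - torch.mean(pred_d_real.detach()),
                       torch.zeros_like(pred_d_fake)) * 0.5
        l_d_fake.backward()  # gradients accumulate with the real-side ones
        optimizer_D.step()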