More RAGAN fixes

parent 26a6a5d512
commit be272248af
@@ -489,7 +489,7 @@ class SRGANModel(BaseModel):
                 l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
                 l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
             elif self.opt['train']['gan_type'] == 'ragan':
-                pred_g_real_grad = self.netD(self.get_grad_nopadding(var_ref)).detach()
+                pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach()
                 l_g_gan_grad = self.l_gan_w * (
                     self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) +
                     self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2
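Note on this hunk: the generator's RAGAN term for the gradient branch was scoring the reference gradient map with the image discriminator netD; the fix routes it through the gradient discriminator netD_grad, so both relativistic terms compare logits from the same network. A minimal sketch of the relativistic average GAN generator loss being assembled here, assuming cri_gan wraps BCE-with-logits (ragan_g_loss is a hypothetical helper, not a function from this repo):

    import torch
    import torch.nn.functional as F

    def ragan_g_loss(pred_real, pred_fake):
        # Generator wants fake logits above the mean real logit and
        # real logits below the mean fake logit.
        real_rel = pred_real - pred_fake.mean()   # target "not real" -> 0
        fake_rel = pred_fake - pred_real.mean()   # target "real" -> 1
        l_real = F.binary_cross_entropy_with_logits(real_rel, torch.zeros_like(real_rel))
        l_fake = F.binary_cross_entropy_with_logits(fake_rel, torch.ones_like(fake_rel))
        return (l_real + l_fake) / 2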
@@ -631,7 +631,7 @@ class SRGANModel(BaseModel):
                 fake_disc_images.append(pdf.view(disc_output_shape))

             elif self.opt['train']['gan_type'] == 'ragan':
-                pred_d_fake = self.netD(fake_H).detach()
+                pred_d_fake = self.netD(fake_H)
                 pred_d_real = self.netD(var_ref)
                 l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
                 l_d_real_log = l_d_real
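Note on this hunk: in the discriminator update, detaching pred_d_fake cuts the fake logits out of the graph, so the torch.mean(pred_d_fake) term in l_d_real contributes no gradient to netD. In a D step the detach belongs on the generator output, not on the discriminator logits (this sketch assumes fake_H is already detached from G when the D step runs; netG, netD, lq, ref are hypothetical names):

    def d_step_inputs(netG, netD, lq, ref):
        # Detach the generator output, not the discriminator logits.
        fake = netG(lq).detach()   # no gradients flow back into G
        pred_fake = netD(fake)     # logits stay in netD's graph
        pred_real = netD(ref)
        return pred_real, pred_fake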
@@ -655,10 +655,10 @@ class SRGANModel(BaseModel):
                 p.requires_grad = True
             self.optimizer_D_grad.zero_grad()
             for var_ref, fake_H in zip(var_ref_skips, self.fake_H):
-                fake_H_grad = self.get_grad_nopadding(fake_H)
+                fake_H_grad = self.get_grad_nopadding(fake_H).detach()
                 var_ref_grad = self.get_grad_nopadding(var_ref)
                 pred_d_real_grad = self.netD_grad(var_ref_grad)
-                pred_d_fake_grad = self.netD_grad(fake_H_grad.detach())  # detach to avoid BP to G
+                pred_d_fake_grad = self.netD_grad(fake_H_grad)  # detach to avoid BP to G
                 if self.opt['train']['gan_type'] == 'gan':
                     l_d_real_grad = self.cri_gan(pred_d_real_grad, True) / self.mega_batch_factor
                     l_d_fake_grad = self.cri_gan(pred_d_fake_grad, False) / self.mega_batch_factor
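Note on this hunk: the .detach() moves from the netD_grad call to the point where the fake gradient map is computed, so every later use of fake_H_grad in this loop (including the ragan branch below) is guaranteed to be cut off from the generator. For reference, a sketch of a per-channel gradient-magnitude operator in the spirit of get_grad_nopadding; this is an assumption based on the SPSR design, and the repo's actual kernels may differ:

    import torch
    import torch.nn.functional as F

    def grad_map(x, eps=1e-6):
        # x: (N, C, H, W). Central-difference kernels applied per channel.
        kv = torch.tensor([[0., -1., 0.], [0., 0., 0.], [0., 1., 0.]],
                          device=x.device).view(1, 1, 3, 3)
        kh = kv.transpose(2, 3)
        chans = []
        for i in range(x.shape[1]):
            xi = x[:, i:i + 1]
            gv = F.conv2d(xi, kv, padding=1)
            gh = F.conv2d(xi, kh, padding=1)
            chans.append(torch.sqrt(gv * gv + gh * gh + eps))
        return torch.cat(chans, dim=1)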
@@ -668,10 +668,8 @@ class SRGANModel(BaseModel):
                     l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real)
                     l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake)
                 elif self.opt['train']['gan_type'] == 'ragan':
-                    pred_g_fake_grad = self.netD_grad(fake_H_grad)
-                    pred_d_real_grad = self.netD_grad(var_ref_grad).detach()
-                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), True)
-                    l_d_fake_grad = self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), False)
+                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
+                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)

                 l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
                 l_d_total_grad /= self.mega_batch_factor
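Note on this hunk: the old ragan branch re-ran the gradient discriminator under a new name and, worse, recomputed pred_d_real_grad with a .detach(), which stopped the real term from training netD_grad at all. The fix reuses the pred_d_real_grad and pred_d_fake_grad already computed above. The discriminator-side counterpart of the generator sketch earlier, under the same BCE-with-logits assumption and the same imports (ragan_d_loss is a hypothetical helper):

    def ragan_d_loss(pred_real, pred_fake):
        # Discriminator wants real logits above the mean fake logit
        # and fake logits below the mean real logit.
        real_rel = pred_real - pred_fake.mean()   # target "real" -> 1
        fake_rel = pred_fake - pred_real.mean()   # target "fake" -> 0
        l_real = F.binary_cross_entropy_with_logits(real_rel, torch.ones_like(real_rel))
        l_fake = F.binary_cross_entropy_with_logits(fake_rel, torch.zeros_like(fake_rel))
        return (l_real + l_fake) / 2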
@@ -743,8 +741,8 @@ class SRGANModel(BaseModel):
             if self.spsr_enabled:
                 self.add_log_entry('l_d_real_grad', l_d_real_grad.detach().item())
                 self.add_log_entry('l_d_fake_grad', l_d_fake_grad.detach().item())
-                self.add_log_entry('D_fake', torch.mean(pred_d_fake_grad.detach()))
-                self.add_log_entry('D_diff', torch.mean(pred_d_fake_grad.detach()) - torch.mean(pred_d_real_grad.detach()))
+                self.add_log_entry('D_fake_grad', torch.mean(pred_d_fake_grad.detach()))
+                self.add_log_entry('D_diff_grad', torch.mean(pred_d_fake_grad.detach()) - torch.mean(pred_d_real_grad.detach()))

             # Log learning rates.
             for i, pg in enumerate(self.optimizer_G.param_groups):
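Note on this hunk: the gradient-branch statistics were being logged under the same 'D_fake' and 'D_diff' keys that the image discriminator presumably uses elsewhere, so one set of curves would overwrite the other; the '_grad' suffix keeps the two discriminators' logs distinct.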