Update how branch GAN grad is disseminated

James Betker 2020-08-06 11:13:02 -06:00
parent 1f21c02f8b
commit 30b16d5235


@@ -159,6 +159,8 @@ class SRGANModel(BaseModel):
# GD gan loss
self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
self.l_gan_w = train_opt['gan_weight']
if train_opt['gan_type'] == 'pixgan':
self.do_pixgan_swap = True if 'do_pixgan_swap' not in train_opt.keys() else train_opt['do_pixgan_swap']
# D_update_ratio and D_init_iters
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
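Note on the new option: when 'do_pixgan_swap' is missing from the training options it resolves to True, so existing pixgan configs keep the patch-swapping behavior unless they explicitly disable it. A minimal sketch of how the default resolves (the train_opt dict here is an illustrative assumption, not a full config):

train_opt = {'gan_type': 'pixgan', 'gan_weight': 1.0}  # no 'do_pixgan_swap' key present
do_pixgan_swap = True if 'do_pixgan_swap' not in train_opt.keys() else train_opt['do_pixgan_swap']
print(do_pixgan_swap)  # True: swapping stays enabled by default

train_opt['do_pixgan_swap'] = False  # explicitly opt out in the config
do_pixgan_swap = True if 'do_pixgan_swap' not in train_opt.keys() else train_opt['do_pixgan_swap']
print(do_pixgan_swap)  # False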
@@ -431,12 +433,12 @@ class SRGANModel(BaseModel):
l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad_nopadding)
l_g_total += l_g_pix_grad
if self.spsr_enabled and self.cri_pix_branch: # branch pixel loss
# The point of this loss is that the core structure of the grad image does not get mutated. Therefore,
# downsample and compare against the input. The GAN loss will take care of the details in HR-space.
var_L_grad = self.get_grad_nopadding(var_L)
downsampled_H_branch = F.interpolate(fake_H_branch, size=var_L_grad.shape[2:], mode="nearest")
grad_truth = self.get_grad_nopadding(var_L)
downsampled_H_branch = fake_H_branch
if grad_truth.shape != fake_H_branch.shape:
downsampled_H_branch = F.interpolate(downsampled_H_branch, size=grad_truth.shape[2:], mode="nearest")
l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(downsampled_H_branch,
var_L_grad)
grad_truth)
l_g_total += l_g_pix_grad_branch
if self.fdpl_enabled and not using_gan_img:
l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
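The reworked branch pixel loss compares the grad branch output against the gradient of the LR input, resampling only when the two actually differ in spatial size. A self-contained sketch of that shape guard (the tensor sizes and the L1 criterion standing in for cri_pix_branch are assumptions):

import torch
import torch.nn.functional as F

fake_H_branch = torch.randn(1, 1, 64, 64)  # grad-branch output from the generator
grad_truth = torch.randn(1, 1, 16, 16)     # gradient map of the LR input

downsampled_H_branch = fake_H_branch
if grad_truth.shape != fake_H_branch.shape:
    # Nearest-neighbor resize down to the LR gradient's spatial size before comparing.
    downsampled_H_branch = F.interpolate(downsampled_H_branch, size=grad_truth.shape[2:], mode="nearest")

l_g_pix_grad_branch = F.l1_loss(downsampled_H_branch, grad_truth)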
@@ -487,7 +489,8 @@ class SRGANModel(BaseModel):
pred_g_fake_grad_branch = self.netD_grad(fake_H_branch)
if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
# Uncomment to re-enable the generator's GAN loss from the grad discriminator on the branch output.
#l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
elif self.opt['train']['gan_type'] == 'ragan':
pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach()
l_g_gan_grad = self.l_gan_w * (
@@ -576,7 +579,7 @@ class SRGANModel(BaseModel):
b, _, w, h = var_ref.shape
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
if not self.disjoint_data:
if self.do_pixgan_swap and not self.disjoint_data:
# randomly determine portions of the image to swap to keep the discriminator honest.
SWAP_MAX_DIM = w // 4
SWAP_MIN_DIM = 16
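For context, the swap that do_pixgan_swap now gates pastes random real patches into the fake batch and flips the matching pixel labels, so the pixel discriminator cannot rely on global statistics alone. A hedged sketch of the idea (the helper name, region count, and bounds are assumptions, not the repository's exact code):

import random
import torch

def pixgan_region_swap(fake, real, fake_labels, swap_min_dim=16, swap_max_dim=None):
    # fake/real: (b, c, w, h) image batches; fake_labels: (b, 1, w, h) per-pixel targets for D.
    b, c, w, h = real.shape
    swap_max_dim = swap_max_dim or max(swap_min_dim, w // 4)
    for _ in range(random.randint(0, 4)):
        dim = random.randint(swap_min_dim, swap_max_dim)
        x = random.randint(0, w - dim)
        y = random.randint(0, h - dim)
        # Paste the real patch into the fake image and mark those pixels as "real" for D.
        fake[:, :, x:x + dim, y:y + dim] = real[:, :, x:x + dim, y:y + dim]
        fake_labels[:, :, x:x + dim, y:y + dim] = 1.0
    return fake, fake_labels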
@@ -654,22 +657,27 @@ class SRGANModel(BaseModel):
for p in self.netD_grad.parameters():
p.requires_grad = True
self.optimizer_D_grad.zero_grad()
for var_ref, fake_H in zip(var_ref_skips, self.fake_H):
for var_ref, fake_H, fake_H_grad_branch in zip(var_ref_skips, self.fake_H, self.spsr_grad_GenOut):
fake_H_grad = self.get_grad_nopadding(fake_H).detach()
var_ref_grad = self.get_grad_nopadding(var_ref)
pred_d_real_grad = self.netD_grad(var_ref_grad)
pred_d_fake_grad = self.netD_grad(fake_H_grad) # detach to avoid BP to G
pred_d_fake_grad = self.netD_grad(fake_H_grad) # Tensor already detached above.
# var_ref and fake_H already have noise added to them. We **must** add noise to fake_H_grad_branch too.
fake_H_grad_branch = fake_H_grad_branch.detach() + noise
pred_d_fake_grad_branch = self.netD_grad(fake_H_grad_branch)
if self.opt['train']['gan_type'] == 'gan':
l_d_real_grad = self.cri_gan(pred_d_real_grad, True) / self.mega_batch_factor
l_d_fake_grad = self.cri_gan(pred_d_fake_grad, False) / self.mega_batch_factor
l_d_real_grad = self.cri_gan(pred_d_real_grad, True)
l_d_fake_grad = (self.cri_gan(pred_d_fake_grad, False) + self.cri_gan(pred_d_fake_grad_branch, False)) / 2
elif self.opt['train']['gan_type'] == 'pixgan':
real = torch.ones_like(pred_d_real_grad)
fake = torch.zeros_like(pred_d_fake_grad)
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real)
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake)
l_d_fake_grad = (self.cri_grad_gan(pred_d_fake_grad, fake) + \
self.cri_grad_gan(pred_d_fake_grad_branch, fake)) / 2
elif self.opt['train']['gan_type'] == 'ragan':
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
l_d_fake_grad = (self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False) + \
self.cri_grad_gan(pred_d_fake_grad_branch - torch.mean(pred_d_real_grad), False)) / 2
l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
l_d_total_grad /= self.mega_batch_factor
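The net effect of the discriminator-side change: netD_grad now scores both the main gradient map and the SPSR grad branch output, and the two fake terms are averaged so the branch contributes without doubling the fake-side weight. A compact sketch of that averaging for the plain 'gan' case (the tensors and the BCE criterion standing in for cri_gan are illustrative stand-ins):

import torch

# Stand-ins for discriminator outputs on the real gradient and on the two fake sources.
pred_d_real_grad = torch.randn(4, 1)
pred_d_fake_grad = torch.randn(4, 1)
pred_d_fake_grad_branch = torch.randn(4, 1)

bce = torch.nn.BCEWithLogitsLoss()  # assumed stand-in for cri_gan with 1.0/0.0 targets
l_d_real_grad = bce(pred_d_real_grad, torch.ones_like(pred_d_real_grad))
l_d_fake_grad = (bce(pred_d_fake_grad, torch.zeros_like(pred_d_fake_grad)) +
                 bce(pred_d_fake_grad_branch, torch.zeros_like(pred_d_fake_grad_branch))) / 2
l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2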