Update how branch GAN grad is disseminated
parent 1f21c02f8b
commit 30b16d5235
@@ -159,6 +159,8 @@ class SRGANModel(BaseModel):
             # GD gan loss
             self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
             self.l_gan_w = train_opt['gan_weight']
+            if train_opt['gan_type'] == 'pixgan':
+                self.do_pixgan_swap = True if 'do_pixgan_swap' not in train_opt.keys() else train_opt['do_pixgan_swap']
             # D_update_ratio and D_init_iters
             self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
             self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
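The new option keeps the old behaviour as the default: do_pixgan_swap is True unless a config explicitly sets it to False. A minimal sketch of how the flag resolves (resolve_do_pixgan_swap is an illustrative helper, not something in the repo):

    def resolve_do_pixgan_swap(train_opt: dict) -> bool:
        # Defaults to True when the key is absent, so existing pixgan configs
        # keep their previous behaviour unless they opt out.
        return True if 'do_pixgan_swap' not in train_opt.keys() else train_opt['do_pixgan_swap']

    assert resolve_do_pixgan_swap({'gan_type': 'pixgan'}) is True
    assert resolve_do_pixgan_swap({'gan_type': 'pixgan', 'do_pixgan_swap': False}) is False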
@@ -431,12 +433,12 @@ class SRGANModel(BaseModel):
                     l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad_nopadding)
                     l_g_total += l_g_pix_grad
                 if self.spsr_enabled and self.cri_pix_branch:  # branch pixel loss
-                    # The point of this loss is that the core structure of the grad image does not get mutated. Therefore,
-                    # downsample and compare against the input. The GAN loss will take care of the details in HR-space.
-                    var_L_grad = self.get_grad_nopadding(var_L)
-                    downsampled_H_branch = F.interpolate(fake_H_branch, size=var_L_grad.shape[2:], mode="nearest")
+                    grad_truth = self.get_grad_nopadding(var_L)
+                    downsampled_H_branch = fake_H_branch
+                    if grad_truth.shape != fake_H_branch.shape:
+                        downsampled_H_branch = F.interpolate(downsampled_H_branch, size=grad_truth.shape[2:], mode="nearest")
                     l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(downsampled_H_branch,
-                                                                                    var_L_grad)
+                                                                                    grad_truth)
                     l_g_total += l_g_pix_grad_branch
                 if self.fdpl_enabled and not using_gan_img:
                     l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
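In isolation, the new branch pixel-loss flow looks roughly like the sketch below. The structure mirrors the diff; F.l1_loss stands in for cri_pix_branch (which is configurable in the real model), and get_grad_nopadding is passed in rather than read from the class.

    import torch
    import torch.nn.functional as F

    def branch_pixel_loss(fake_H_branch, var_L, get_grad_nopadding, weight=1.0):
        # The gradient map of the LR input is the ground truth for the branch output.
        grad_truth = get_grad_nopadding(var_L)
        downsampled = fake_H_branch
        if grad_truth.shape != fake_H_branch.shape:
            # Only resample when the shapes actually differ (the new behaviour).
            downsampled = F.interpolate(downsampled, size=grad_truth.shape[2:], mode="nearest")
        return weight * F.l1_loss(downsampled, grad_truth)

    # Toy usage; an identity "gradient" stands in for get_grad_nopadding.
    lr = torch.rand(1, 3, 16, 16)
    branch_out = torch.rand(1, 3, 32, 32)
    loss = branch_pixel_loss(branch_out, lr, get_grad_nopadding=lambda x: x)

The practical difference from the old code is the shape check: the branch output is only resampled when it does not already match the LR gradient map.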
@@ -487,7 +489,8 @@ class SRGANModel(BaseModel):
                     pred_g_fake_grad_branch = self.netD_grad(fake_H_branch)
                     if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
                         l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
-                        l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
+                        # Uncomment to compute a discriminator loss against the grad branch.
+                        #l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
                     elif self.opt['train']['gan_type'] == 'ragan':
                         pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach()
                         l_g_gan_grad = self.l_gan_w * (
@@ -576,7 +579,7 @@ class SRGANModel(BaseModel):
                 b, _, w, h = var_ref.shape
                 real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
                 fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
-                if not self.disjoint_data:
+                if self.do_pixgan_swap and not self.disjoint_data:
                     # randomly determine portions of the image to swap to keep the discriminator honest.
                     SWAP_MAX_DIM = w // 4
                     SWAP_MIN_DIM = 16
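do_pixgan_swap now gates the patch-swapping trick referenced by the comment in this hunk. The real implementation follows further down in the same method; the sketch below only illustrates the general idea, under the assumption that a swap copies fake content into the reference image and flips the matching region of the per-pixel label map (function name and details are illustrative, not from the repo).

    import random
    import torch

    def pixgan_swap(var_ref, fake_H, real_label_map, swap_min=16):
        # Copy one random square from the fake image into the real one and mark
        # that region as fake in the per-pixel label map, so the discriminator
        # has to judge local texture rather than whole-image statistics.
        _, _, h, w = var_ref.shape
        swap_max = max(swap_min, min(h, w) // 4)
        dim = min(random.randint(swap_min, swap_max), h, w)
        top = random.randint(0, h - dim)
        left = random.randint(0, w - dim)
        var_ref = var_ref.clone()
        real_label_map = real_label_map.clone()
        var_ref[:, :, top:top + dim, left:left + dim] = fake_H[:, :, top:top + dim, left:left + dim]
        real_label_map[:, :, top:top + dim, left:left + dim] = 0.0
        return var_ref, real_label_map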
@@ -654,22 +657,27 @@ class SRGANModel(BaseModel):
             for p in self.netD_grad.parameters():
                 p.requires_grad = True
             self.optimizer_D_grad.zero_grad()
-            for var_ref, fake_H in zip(var_ref_skips, self.fake_H):
+            for var_ref, fake_H, fake_H_grad_branch in zip(var_ref_skips, self.fake_H, self.spsr_grad_GenOut):
                 fake_H_grad = self.get_grad_nopadding(fake_H).detach()
                 var_ref_grad = self.get_grad_nopadding(var_ref)
                 pred_d_real_grad = self.netD_grad(var_ref_grad)
-                pred_d_fake_grad = self.netD_grad(fake_H_grad)  # detach to avoid BP to G
+                pred_d_fake_grad = self.netD_grad(fake_H_grad)  # Tensor already detached above.
+                # var_ref and fake_H already have noise added to them. We **must** add noise to fake_H_grad_branch too.
+                fake_H_grad_branch = fake_H_grad_branch.detach() + noise
+                pred_d_fake_grad_branch = self.netD_grad(fake_H_grad_branch)
                 if self.opt['train']['gan_type'] == 'gan':
-                    l_d_real_grad = self.cri_gan(pred_d_real_grad, True) / self.mega_batch_factor
-                    l_d_fake_grad = self.cri_gan(pred_d_fake_grad, False) / self.mega_batch_factor
+                    l_d_real_grad = self.cri_gan(pred_d_real_grad, True)
+                    l_d_fake_grad = (self.cri_gan(pred_d_fake_grad, False) + self.cri_gan(pred_d_fake_grad_branch, False)) / 2
                 elif self.opt['train']['gan_type'] == 'pixgan':
                     real = torch.ones_like(pred_d_real_grad)
                     fake = torch.zeros_like(pred_d_fake_grad)
                     l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real)
-                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake)
+                    l_d_fake_grad = (self.cri_grad_gan(pred_d_fake_grad, fake) + \
+                                     self.cri_grad_gan(pred_d_fake_grad_branch, fake)) / 2
                 elif self.opt['train']['gan_type'] == 'ragan':
                     l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
-                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
+                    l_d_fake_grad = (self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False) + \
+                                     self.cri_grad_gan(pred_d_fake_grad_branch - torch.mean(pred_d_real_grad), False)) / 2

                 l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
                 l_d_total_grad /= self.mega_batch_factor
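The net effect of the discriminator-side change, condensed: the gradient discriminator also scores the detached grad-branch output, and the fake loss becomes the mean of both fake predictions. A sketch of the plain 'gan' path, with bce_gan standing in for cri_gan (assumed here to behave like BCE-with-logits against ones/zeros targets):

    import torch
    import torch.nn.functional as F

    def bce_gan(pred, target_is_real):
        # Stand-in for cri_gan: BCE-with-logits against an all-ones / all-zeros target.
        target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
        return F.binary_cross_entropy_with_logits(pred, target)

    def d_grad_loss(netD_grad, var_ref_grad, fake_H_grad, fake_H_grad_branch, noise, mega_batch_factor=1):
        pred_real = netD_grad(var_ref_grad)
        pred_fake = netD_grad(fake_H_grad.detach())
        # The branch output gets the same treatment as the other fakes: detach and add noise.
        pred_fake_branch = netD_grad(fake_H_grad_branch.detach() + noise)
        l_real = bce_gan(pred_real, True)
        # Fake loss now averages the full gradient map and the branch output.
        l_fake = (bce_gan(pred_fake, False) + bce_gan(pred_fake_branch, False)) / 2
        return ((l_real + l_fake) / 2) / mega_batch_factor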