Compute grad GAN loss against both the branch and final target, simplify pixel loss

Also fixes a memory leak where loss stats were not being detached before logging. This stabilizes memory usage substantially.
parent 299ee13988
commit 26a6a5d512
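The memory-leak half of this change boils down to detaching loss tensors before recording them: a logged tensor that is still attached to the autograd graph keeps that whole graph (and its activations) alive for as long as the log entry does. A minimal sketch of the pattern, using a dict-backed stand-in for the model's `add_log_entry` helper (the names below are illustrative, not the repo's code):

```python
import torch
import torch.nn.functional as F

# Stand-in for SRGANModel.add_log_entry: a rolling store of scalar stats.
log_buffer = {}

def add_log_entry(key, value):
    log_buffer.setdefault(key, []).append(value)

# A toy loss attached to an autograd graph.
pred = torch.randn(4, 3, 16, 16, requires_grad=True)
target = torch.randn(4, 3, 16, 16)
l_g_pix = F.l1_loss(pred, target)

# Leaky pattern: storing the loss tensor itself retains its entire autograd
# graph for as long as the log entry lives.
add_log_entry('l_g_pix_leaky', l_g_pix)

# Pattern this commit switches to: detach, then convert to a plain Python
# float, so the logged value holds no reference to the graph.
add_log_entry('l_g_pix', l_g_pix.detach().item())
```

`.item()` on its own already yields a plain float; the explicit `.detach()` makes the intent obvious and also matters for the entries in this diff that log tensors rather than floats (e.g. the `D_fake`/`D_diff` means).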
@@ -393,14 +393,13 @@ class SRGANModel(BaseModel):
        var_ref_skips = []
        for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
            if self.spsr_enabled:
                using_gan_img = False
                # SPSR models have outputs from three different branches.
                fake_H_branch, fake_GenOut, grad_LR = self.netG(var_L)
                fea_GenOut = fake_GenOut
                using_gan_img = False
                self.spsr_grad_GenOut.append(fake_H_branch)
                # Get image gradients for later use.
                fake_H_grad = self.get_grad_nopadding(fake_GenOut)
                var_H_grad_nopadding = self.get_grad_nopadding(var_H)
                self.spsr_grad_GenOut.append(fake_H_branch)
            else:
                if random.random() > self.gan_lq_img_use_prob:
                    fea_GenOut, fake_GenOut = self.netG(var_L)
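This hunk (and the ones below) pulls image-gradient maps with `self.get_grad_nopadding(...)`, whose body is not part of this diff. Purely to make the gradient-space losses concrete, here is a minimal depthwise central-difference gradient magnitude; it is an assumption for illustration, not the repo's implementation:

```python
import torch
import torch.nn.functional as F

def image_grad_magnitude(x, eps=1e-6):
    """Per-channel gradient magnitude of a (batch, channels, h, w) image,
    in the spirit of get_grad_nopadding (a sketch, not the repo's code)."""
    c = x.shape[1]
    # Central-difference kernels, applied depthwise to every channel.
    kx = x.new_tensor([[0., 0., 0.], [-1., 0., 1.], [0., 0., 0.]]).view(1, 1, 3, 3).repeat(c, 1, 1, 1)
    ky = x.new_tensor([[0., -1., 0.], [0., 0., 0.], [0., 1., 0.]]).view(1, 1, 3, 3).repeat(c, 1, 1, 1)
    gx = F.conv2d(x, kx, padding=1, groups=c)
    gy = F.conv2d(x, ky, padding=1, groups=c)
    return torch.sqrt(gx ** 2 + gy ** 2 + eps)
```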
@@ -428,11 +427,16 @@ class SRGANModel(BaseModel):
                    l_g_pix_log = l_g_pix / self.l_pix_w
                    l_g_total += l_g_pix
                if self.spsr_enabled and self.cri_pix_grad:  # gradient pixel loss
                    var_H_grad_nopadding = self.get_grad_nopadding(var_H)
                    l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad_nopadding)
                    l_g_total += l_g_pix_grad
                if self.spsr_enabled and self.cri_pix_branch:  # branch pixel loss
                    l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(fake_H_branch, var_H_grad_nopadding)
                    # The point of this loss is that the core structure of the grad image does not get mutated. Therefore,
                    # downsample and compare against the input. The GAN loss will take care of the details in HR-space.
                    var_L_grad = self.get_grad_nopadding(var_L)
                    downsampled_H_branch = F.interpolate(fake_H_branch, size=var_L_grad.shape[2:], mode="nearest")
                    l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(downsampled_H_branch, var_L_grad)
                    l_g_total += l_g_pix_grad_branch
                if self.fdpl_enabled and not using_gan_img:
                    l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
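The comment in the hunk above explains the simplified branch pixel loss: instead of comparing the gradient branch against HR-space gradients, the branch output is downsampled to the LR gradient's resolution and penalized there, leaving HR-space detail to the GAN terms. A sketch of that term with made-up shapes, assuming `cri_pix_branch` is an L1 criterion:

```python
import torch
import torch.nn.functional as F

l_pix_branch_w = 1.0
cri_pix_branch = torch.nn.L1Loss()  # assumed criterion type

fake_H_branch = torch.randn(2, 3, 128, 128, requires_grad=True)  # gradient-branch output (HR size)
var_L_grad = torch.randn(2, 3, 32, 32)                           # gradient map of the LR input

# Downsample the branch output to the LR gradient's spatial size and compare
# there, so only the coarse structure of the gradient image is constrained.
downsampled_H_branch = F.interpolate(fake_H_branch, size=var_L_grad.shape[2:], mode="nearest")
l_g_pix_grad_branch = l_pix_branch_w * cri_pix_branch(downsampled_H_branch, var_L_grad)
```

`mode="nearest"` matches the diff; per the comment, the intent is that this term only pins down coarse gradient structure while the GAN losses handle HR detail.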
@@ -480,14 +484,19 @@ class SRGANModel(BaseModel):

                if self.spsr_enabled and self.cri_grad_gan:
                    pred_g_fake_grad = self.netD_grad(fake_H_grad)
                    pred_g_fake_grad_branch = self.netD_grad(fake_H_branch)
                    if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
                        l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
                        l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
                    elif self.opt['train']['gan_type'] == 'ragan':
                        pred_g_real_grad = self.netD(self.get_grad_nopadding(var_ref)).detach()
                        l_g_gan_grad = self.l_gan_w * (
                            self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) +
                            self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2
                        l_g_total += l_g_gan_grad
                        l_g_gan_grad_branch = self.l_gan_w * (
                            self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad_branch), False) +
                            self.cri_gan(pred_g_fake_grad_branch - torch.mean(pred_g_real_grad), True)) / 2
                    l_g_total += l_g_gan_grad + l_g_gan_grad_branch

                # Scale the loss down by the batch factor.
                l_g_total_log = l_g_total
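The hunk above is the heart of the commit title: the gradient discriminator's GAN loss is now computed for both the gradient of the final output (`pred_g_fake_grad`) and the gradient-branch output (`pred_g_fake_grad_branch`); in the `ragan` case both are relativized against the same detached real prediction, and both terms feed `l_g_total`. A compressed sketch with placeholder tensors and an assumed BCE-with-logits criterion (`netD_grad`, the criterion, and the weights are stand-ins):

```python
import torch

def ragan_g_loss(cri_gan, pred_real, pred_fake):
    """Relativistic-average GAN generator loss, as in the hunk above."""
    return (cri_gan(pred_real - torch.mean(pred_fake), False) +
            cri_gan(pred_fake - torch.mean(pred_real), True)) / 2

# Placeholder discriminator outputs on gradient images (stand-ins for netD_grad / netD).
pred_g_fake_grad = torch.randn(4, 1, requires_grad=True)
pred_g_fake_grad_branch = torch.randn(4, 1, requires_grad=True)
pred_g_real_grad = torch.randn(4, 1).detach()

# Assumed criterion shape: cri_gan(prediction, target_is_real).
bce = torch.nn.BCEWithLogitsLoss()
cri_gan = lambda pred, target_is_real: bce(pred, torch.full_like(pred, float(target_is_real)))

l_gan_w = 1.0
l_g_gan_grad = l_gan_w * ragan_g_loss(cri_gan, pred_g_real_grad, pred_g_fake_grad)
l_g_gan_grad_branch = l_gan_w * ragan_g_loss(cri_gan, pred_g_real_grad, pred_g_fake_grad_branch)

# Both terms now enter the generator objective, matching the new
# `l_g_total += l_g_gan_grad + l_g_gan_grad_branch` line.
l_g_total = l_g_gan_grad + l_g_gan_grad_branch
```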
@@ -703,36 +712,39 @@ class SRGANModel(BaseModel):
        # Log metrics
        if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
            if self.cri_pix and l_g_pix_log is not None:
                self.add_log_entry('l_g_pix', l_g_pix_log.item())
                self.add_log_entry('l_g_pix', l_g_pix_log.detach().item())
            if self.fdpl_enabled and l_g_fdpl is not None:
                self.add_log_entry('l_g_fdpl', l_g_fdpl.item())
                self.add_log_entry('l_g_fdpl', l_g_fdpl.detach().item())
            if self.cri_fea and l_g_fea_log is not None:
                self.add_log_entry('feature_weight', fea_w)
                self.add_log_entry('l_g_fea', l_g_fea_log.item())
                self.add_log_entry('l_g_fix_disc', l_g_fix_disc.item())
                self.add_log_entry('l_g_fea', l_g_fea_log.detach().item())
                self.add_log_entry('l_g_fix_disc', l_g_fix_disc.detach().item())
            if self.l_gan_w > 0:
                self.add_log_entry('l_g_gan', l_g_gan_log.item())
                self.add_log_entry('l_g_total', l_g_total_log.item())
                self.add_log_entry('l_g_gan', l_g_gan_log.detach().item())
                self.add_log_entry('l_g_total', l_g_total_log.detach().item())
            if self.opt['train']['gan_type'] == 'pixgan_fea':
                self.add_log_entry('l_d_fea_fake', l_d_fea_fake.item() * self.mega_batch_factor)
                self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor)
                self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
                self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
                self.add_log_entry('l_d_fea_fake', l_d_fea_fake.detach().item() * self.mega_batch_factor)
                self.add_log_entry('l_d_fea_real', l_d_fea_real.detach().item() * self.mega_batch_factor)
                self.add_log_entry('l_d_fake_total', l_d_fake.detach().item() * self.mega_batch_factor)
                self.add_log_entry('l_d_real_total', l_d_real.detach().item() * self.mega_batch_factor)
            if self.spsr_enabled:
                if self.cri_pix_grad:
                    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad.item())
                    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad.detach().item())
                if self.cri_pix_branch:
                    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.item())
                    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.detach().item())
                if self.cri_grad_gan:
                    self.add_log_entry('l_g_gan_grad', l_g_gan_grad.detach().item())
                    self.add_log_entry('l_g_gan_grad_branch', l_g_gan_grad_branch.detach().item())
        if self.l_gan_w > 0 and step >= self.G_warmup:
            self.add_log_entry('l_d_real', l_d_real_log.item())
            self.add_log_entry('l_d_fake', l_d_fake_log.item())
            self.add_log_entry('l_d_real', l_d_real_log.detach().item())
            self.add_log_entry('l_d_fake', l_d_fake_log.detach().item())
            self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
            self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
            self.add_log_entry('D_diff', torch.mean(pred_d_fake.detach()) - torch.mean(pred_d_real.detach()))
            if self.spsr_enabled:
                self.add_log_entry('l_d_real_grad', l_d_real_grad.item())
                self.add_log_entry('l_d_fake_grad', l_d_fake_grad.item())
                self.add_log_entry('l_d_real_grad', l_d_real_grad.detach().item())
                self.add_log_entry('l_d_fake_grad', l_d_fake_grad.detach().item())
                self.add_log_entry('D_fake', torch.mean(pred_d_fake_grad.detach()))
                self.add_log_entry('D_diff', torch.mean(pred_d_fake_grad) - torch.mean(pred_d_real_grad))
                self.add_log_entry('D_diff', torch.mean(pred_d_fake_grad.detach()) - torch.mean(pred_d_real_grad.detach()))

        # Log learning rates.
        for i, pg in enumerate(self.optimizer_G.param_groups):