Compute grad GAN loss against both the branch and final target, simplify pixel loss

Also fixes a memory leak issue where we weren't detaching our loss stats when
logging them. This stabilizes memory usage substantially.
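
Editorial note: the detach fix follows a general PyTorch rule — any loss value kept around for logging must be detached from the autograd graph (or reduced to a plain Python number), otherwise each stored stat pins its entire backward graph in memory across iterations. A minimal sketch of that pattern, using a hypothetical standalone add_log_entry rather than this model's method:

    import torch

    log_buffer = {}  # hypothetical running log, keyed by stat name

    def add_log_entry(key, value):
        # Store only graph-free values so logged stats never retain autograd state.
        if torch.is_tensor(value):
            value = value.detach().item() if value.numel() == 1 else value.detach()
        log_buffer.setdefault(key, []).append(value)

    # Usage: a live loss tensor goes in, a plain float is what gets stored.
    pred = torch.randn(4, requires_grad=True)
    loss = (pred ** 2).mean()
    add_log_entry('l_g_pix', loss)                 # stored as a float
    add_log_entry('D_fake', pred.detach().mean())  # stored as a graph-free tensor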
James Betker 2020-08-05 12:08:15 -06:00
parent 299ee13988
commit 26a6a5d512


@@ -393,14 +393,13 @@ class SRGANModel(BaseModel):
         var_ref_skips = []
         for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
             if self.spsr_enabled:
-                using_gan_img = False
                 # SPSR models have outputs from three different branches.
                 fake_H_branch, fake_GenOut, grad_LR = self.netG(var_L)
                 fea_GenOut = fake_GenOut
+                using_gan_img = False
+                self.spsr_grad_GenOut.append(fake_H_branch)
                 # Get image gradients for later use.
                 fake_H_grad = self.get_grad_nopadding(fake_GenOut)
-                var_H_grad_nopadding = self.get_grad_nopadding(var_H)
-                self.spsr_grad_GenOut.append(fake_H_branch)
             else:
                 if random.random() > self.gan_lq_img_use_prob:
                     fea_GenOut, fake_GenOut = self.netG(var_L)
@@ -428,11 +427,16 @@ class SRGANModel(BaseModel):
                 l_g_pix_log = l_g_pix / self.l_pix_w
                 l_g_total += l_g_pix
             if self.spsr_enabled and self.cri_pix_grad: # gradient pixel loss
+                var_H_grad_nopadding = self.get_grad_nopadding(var_H)
                 l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad_nopadding)
                 l_g_total += l_g_pix_grad
             if self.spsr_enabled and self.cri_pix_branch: # branch pixel loss
-                l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(fake_H_branch,
-                                                                                var_H_grad_nopadding)
+                # The point of this loss is that the core structure of the grad image does not get mutated. Therefore,
+                # downsample and compare against the input. The GAN loss will take care of the details in HR-space.
+                var_L_grad = self.get_grad_nopadding(var_L)
+                downsampled_H_branch = F.interpolate(fake_H_branch, size=var_L_grad.shape[2:], mode="nearest")
+                l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(downsampled_H_branch,
+                                                                                var_L_grad)
                 l_g_total += l_g_pix_grad_branch
             if self.fdpl_enabled and not using_gan_img:
                 l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
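
The simplified branch pixel loss above can be read in isolation: instead of matching the gradient branch against the HR gradient map, the branch output is nearest-downsampled to LR resolution and compared against the gradient of the LR input, so only coarse structure is pinned down and the GAN terms are left to supply HR detail. A rough standalone sketch, with F.l1_loss standing in for cri_pix_branch and a simple finite-difference gradient standing in for the model's get_grad_nopadding:

    import torch
    import torch.nn.functional as F

    def image_gradient(x):
        # Hypothetical stand-in: L1 magnitude of horizontal/vertical differences,
        # padded back to the input size.
        dx = x[:, :, :, 1:] - x[:, :, :, :-1]
        dy = x[:, :, 1:, :] - x[:, :, :-1, :]
        return F.pad(dx.abs(), (0, 1, 0, 0)) + F.pad(dy.abs(), (0, 0, 0, 1))

    def branch_pixel_loss(fake_H_branch, var_L, weight=1.0):
        # Constrain only coarse structure: downsample the HR-space gradient branch
        # to LR size and compare it against the LR input's gradient map.
        var_L_grad = image_gradient(var_L)
        down = F.interpolate(fake_H_branch, size=var_L_grad.shape[2:], mode="nearest")
        return weight * F.l1_loss(down, var_L_grad)

    # Usage with dummy tensors (batch of 2, 3 channels, 4x upscale):
    lr = torch.rand(2, 3, 32, 32)
    branch = torch.rand(2, 3, 128, 128, requires_grad=True)
    branch_pixel_loss(branch, lr).backward()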
@@ -480,14 +484,19 @@ class SRGANModel(BaseModel):
             if self.spsr_enabled and self.cri_grad_gan:
                 pred_g_fake_grad = self.netD_grad(fake_H_grad)
+                pred_g_fake_grad_branch = self.netD_grad(fake_H_branch)
                 if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
                     l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
+                    l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
                 elif self.opt['train']['gan_type'] == 'ragan':
                     pred_g_real_grad = self.netD(self.get_grad_nopadding(var_ref)).detach()
                     l_g_gan_grad = self.l_gan_w * (
                         self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) +
                         self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2
-                l_g_total += l_g_gan_grad
+                    l_g_gan_grad_branch = self.l_gan_w * (
+                        self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad_branch), False) +
+                        self.cri_gan(pred_g_fake_grad_branch - torch.mean(pred_g_real_grad), True)) / 2
+                l_g_total += l_g_gan_grad + l_g_gan_grad_branch
             # Scale the loss down by the batch factor.
             l_g_total_log = l_g_total
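
For the grad-GAN term, the change mirrors the existing generator GAN loss: the gradient discriminator now scores both the gradient of the final output and the raw gradient branch, and in the relativistic ('ragan') case each prediction gets its own relativistic pairing against the real gradient predictions. A compact sketch of that generator-side term, assuming cri_gan behaves like a GANLoss criterion called as cri_gan(prediction, target_is_real):

    import torch

    def ragan_generator_loss(cri_gan, pred_real, pred_fake):
        # Relativistic average GAN, generator side: the fake should score above the
        # mean real prediction, and the real should score below the mean fake.
        return (cri_gan(pred_real - torch.mean(pred_fake), False) +
                cri_gan(pred_fake - torch.mean(pred_real), True)) / 2

    # Applied to both gradient-space predictions (names from the diff above):
    # l_g_gan_grad        = l_gan_w * ragan_generator_loss(cri_gan, pred_g_real_grad, pred_g_fake_grad)
    # l_g_gan_grad_branch = l_gan_w * ragan_generator_loss(cri_gan, pred_g_real_grad, pred_g_fake_grad_branch)
    # l_g_total += l_g_gan_grad + l_g_gan_grad_branch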
@@ -703,36 +712,39 @@ class SRGANModel(BaseModel):
         # Log metrics
         if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
             if self.cri_pix and l_g_pix_log is not None:
-                self.add_log_entry('l_g_pix', l_g_pix_log.item())
+                self.add_log_entry('l_g_pix', l_g_pix_log.detach().item())
             if self.fdpl_enabled and l_g_fdpl is not None:
-                self.add_log_entry('l_g_fdpl', l_g_fdpl.item())
+                self.add_log_entry('l_g_fdpl', l_g_fdpl.detach().item())
             if self.cri_fea and l_g_fea_log is not None:
                 self.add_log_entry('feature_weight', fea_w)
-                self.add_log_entry('l_g_fea', l_g_fea_log.item())
-                self.add_log_entry('l_g_fix_disc', l_g_fix_disc.item())
+                self.add_log_entry('l_g_fea', l_g_fea_log.detach().item())
+                self.add_log_entry('l_g_fix_disc', l_g_fix_disc.detach().item())
             if self.l_gan_w > 0:
-                self.add_log_entry('l_g_gan', l_g_gan_log.item())
-            self.add_log_entry('l_g_total', l_g_total_log.item())
+                self.add_log_entry('l_g_gan', l_g_gan_log.detach().item())
+            self.add_log_entry('l_g_total', l_g_total_log.detach().item())
             if self.opt['train']['gan_type'] == 'pixgan_fea':
-                self.add_log_entry('l_d_fea_fake', l_d_fea_fake.item() * self.mega_batch_factor)
-                self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor)
-            self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
-            self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_fea_fake', l_d_fea_fake.detach().item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_fea_real', l_d_fea_real.detach().item() * self.mega_batch_factor)
+            self.add_log_entry('l_d_fake_total', l_d_fake.detach().item() * self.mega_batch_factor)
+            self.add_log_entry('l_d_real_total', l_d_real.detach().item() * self.mega_batch_factor)
             if self.spsr_enabled:
                 if self.cri_pix_grad:
-                    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad.item())
+                    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad.detach().item())
                 if self.cri_pix_branch:
-                    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.item())
+                    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.detach().item())
+                if self.cri_grad_gan:
+                    self.add_log_entry('l_g_gan_grad', l_g_gan_grad.detach().item())
+                    self.add_log_entry('l_g_gan_grad_branch', l_g_gan_grad_branch.detach().item())
         if self.l_gan_w > 0 and step >= self.G_warmup:
-            self.add_log_entry('l_d_real', l_d_real_log.item())
-            self.add_log_entry('l_d_fake', l_d_fake_log.item())
+            self.add_log_entry('l_d_real', l_d_real_log.detach().item())
+            self.add_log_entry('l_d_fake', l_d_fake_log.detach().item())
             self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
-            self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
+            self.add_log_entry('D_diff', torch.mean(pred_d_fake.detach()) - torch.mean(pred_d_real.detach()))
             if self.spsr_enabled:
-                self.add_log_entry('l_d_real_grad', l_d_real_grad.item())
-                self.add_log_entry('l_d_fake_grad', l_d_fake_grad.item())
+                self.add_log_entry('l_d_real_grad', l_d_real_grad.detach().item())
+                self.add_log_entry('l_d_fake_grad', l_d_fake_grad.detach().item())
                 self.add_log_entry('D_fake', torch.mean(pred_d_fake_grad.detach()))
-                self.add_log_entry('D_diff', torch.mean(pred_d_fake_grad) - torch.mean(pred_d_real_grad))
+                self.add_log_entry('D_diff', torch.mean(pred_d_fake_grad.detach()) - torch.mean(pred_d_real_grad.detach()))
         # Log learning rates.
         for i, pg in enumerate(self.optimizer_G.param_groups):