Add pix_grad_branch loss to metrics

This commit is contained in:
James Betker 2020-08-03 16:21:05 -06:00
parent 0d070b47a7
commit c7e5d3888a

View File

@ -726,6 +726,8 @@ class SRGANModel(BaseModel):
self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
if self.spsr_enabled:
if self.cri_pix_grad:
self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad.item())
if self.cri_pix_branch:
self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.item())
if self.l_gan_w > 0 and step >= self.G_warmup: