Log total gen loss

This commit is contained in:
James Betker 2020-04-22 14:02:10 -06:00
parent 79aff886b5
commit ea5f432f5a

View File

@ -224,10 +224,9 @@ class SRGANModel(BaseModel):
if self.cri_fea: if self.cri_fea:
self.log_dict['l_g_fea'] = l_g_fea.item() self.log_dict['l_g_fea'] = l_g_fea.item()
self.log_dict['l_g_gan'] = l_g_gan.item() self.log_dict['l_g_gan'] = l_g_gan.item()
self.log_dict['l_g_total'] = l_g_total.item()
self.log_dict['l_d_real'] = l_d_real.item() self.log_dict['l_d_real'] = l_d_real.item()
self.log_dict['l_d_fake'] = l_d_fake.item() self.log_dict['l_d_fake'] = l_d_fake.item()
self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
def test(self): def test(self):