diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 9af996f0..a1419a89 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -224,10 +224,9 @@ class SRGANModel(BaseModel): if self.cri_fea: self.log_dict['l_g_fea'] = l_g_fea.item() self.log_dict['l_g_gan'] = l_g_gan.item() - + self.log_dict['l_g_total'] = l_g_total.item() self.log_dict['l_d_real'] = l_d_real.item() self.log_dict['l_d_fake'] = l_d_fake.item() - self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) def test(self):