diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 19665e67..fa1dea4d 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -347,13 +347,15 @@ class SRGANModel(BaseModel): if self.cri_fea: self.add_log_entry('feature_weight', self.l_fea_w) self.add_log_entry('l_g_fea', l_g_fea.item()) - self.add_log_entry('l_g_gan', l_g_gan.item()) + if self.l_gan_w > 0: + self.add_log_entry('l_g_gan', l_g_gan.item()) self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor) - self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor) - self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor) - self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) - self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real)) - self.add_log_entry('noise_theta', noise_theta) + if self.l_gan_w > 0: + self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor) + self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor) + self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) + self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real)) + self.add_log_entry('noise_theta', noise_theta) if step % self.corruptor_swapout_steps == 0 and step > 0: self.load_random_corruptor()