From 240f2542632b014193e5cd740dd05481af304153 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 16 Jul 2020 10:45:50 -0600 Subject: [PATCH] More loss fixes --- codes/models/SRGAN_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 26f1b589..6402efcd 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -483,12 +483,13 @@ class SRGANModel(BaseModel): if self.l_gan_w > 0 and step > self.G_warmup: self.add_log_entry('l_d_real', l_d_real_log.item()) self.add_log_entry('l_d_fake', l_d_fake_log.item()) + self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) + self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real)) + if isinstance(l_d_fea_real, torch.tensor): self.add_log_entry('l_d_fea_fake', l_d_fea_fake.item() * self.mega_batch_factor) self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor) self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor) self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor) - self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) - self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real)) if step % self.corruptor_swapout_steps == 0 and step > 0: self.load_random_corruptor()