More loss fixes

This commit is contained in:
James Betker 2020-07-16 10:45:50 -06:00
parent 6cfa67d831
commit 240f254263

View File

@ -483,12 +483,13 @@ class SRGANModel(BaseModel):
if self.l_gan_w > 0 and step > self.G_warmup: if self.l_gan_w > 0 and step > self.G_warmup:
self.add_log_entry('l_d_real', l_d_real_log.item()) self.add_log_entry('l_d_real', l_d_real_log.item())
self.add_log_entry('l_d_fake', l_d_fake_log.item()) self.add_log_entry('l_d_fake', l_d_fake_log.item())
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
if isinstance(l_d_fea_real, torch.tensor):
self.add_log_entry('l_d_fea_fake', l_d_fea_fake.item() * self.mega_batch_factor) self.add_log_entry('l_d_fea_fake', l_d_fea_fake.item() * self.mega_batch_factor)
self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor) self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor)
self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor) self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor) self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
if step % self.corruptor_swapout_steps == 0 and step > 0: if step % self.corruptor_swapout_steps == 0 and step > 0:
self.load_random_corruptor() self.load_random_corruptor()