Only log discriminator data when gan is activated

This commit is contained in:
James Betker 2020-06-01 15:48:16 -06:00
parent f1a1fd14b1
commit 8355f3d1b3

View File

@ -347,8 +347,10 @@ class SRGANModel(BaseModel):
if self.cri_fea: if self.cri_fea:
self.add_log_entry('feature_weight', self.l_fea_w) self.add_log_entry('feature_weight', self.l_fea_w)
self.add_log_entry('l_g_fea', l_g_fea.item()) self.add_log_entry('l_g_fea', l_g_fea.item())
if self.l_gan_w > 0:
self.add_log_entry('l_g_gan', l_g_gan.item()) self.add_log_entry('l_g_gan', l_g_gan.item())
self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor) self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor)
if self.l_gan_w > 0:
self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor) self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor)
self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor) self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor)
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))