Fix reporting of d_fake_diff for generators

This commit is contained in:
James Betker 2020-11-02 08:45:46 -07:00
parent 3676f26d94
commit a51daacde2

View File

@ -176,10 +176,10 @@ class GeneratorGanLoss(ConfigurableLoss):
if self.detach_real: if self.detach_real:
pred_d_real = pred_d_real.detach() pred_d_real = pred_d_real.detach()
pred_g_fake = netD(*fake) pred_g_fake = netD(*fake)
d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True) d_fake_diff = pred_g_fake - torch.mean(pred_d_real)
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) + loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
d_fake_diff) / 2 self.criterion(d_fake_diff, True)) / 2
else: else:
raise NotImplementedError raise NotImplementedError
if self.min_loss != 0: if self.min_loss != 0: