Fix reporting of d_fake_diff for generators
parent 3676f26d94
commit a51daacde2
@@ -176,10 +176,10 @@ class GeneratorGanLoss(ConfigurableLoss):
             if self.detach_real:
                 pred_d_real = pred_d_real.detach()
             pred_g_fake = netD(*fake)
-            d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True)
+            d_fake_diff = pred_g_fake - torch.mean(pred_d_real)
             self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
             loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
-                    d_fake_diff) / 2
+                    self.criterion(d_fake_diff, True)) / 2
         else:
             raise NotImplementedError
         if self.min_loss != 0:
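
A minimal sketch of what the change means for the logged metric, assuming self.criterion is a BCE-with-logits GAN loss whose boolean argument selects an all-real or all-fake target (the GANLoss class and the sample logits below are hypothetical stand-ins, not the repository's actual classes). Before this commit, "d_fake_diff" was reported after passing through the criterion, so the metric was a non-negative cross-entropy value; after it, the metric is the raw relativistic logit gap pred_g_fake - mean(pred_d_real), while the loss term itself is computed from that same quantity and is unchanged in value.

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the repo's GAN criterion: BCE-with-logits
    # against an all-ones target when target_is_real is True, else all-zeros,
    # as implied by the self.criterion(..., True/False) calls in the diff.
    class GANLoss(nn.Module):
        def __init__(self):
            super().__init__()
            self.loss = nn.BCEWithLogitsLoss()

        def forward(self, logits, target_is_real):
            target = torch.ones_like(logits) if target_is_real else torch.zeros_like(logits)
            return self.loss(logits, target)

    criterion = GANLoss()

    # Sample discriminator logits standing in for netD(*real) / netD(*fake).
    pred_d_real = torch.tensor([1.5, 0.8, 2.0])
    pred_g_fake = torch.tensor([0.2, -0.4, 0.9])

    # Relativistic-average gap: how far fake logits sit from the mean real logit.
    d_fake_diff = pred_g_fake - torch.mean(pred_d_real)

    # Old metric: criterion output (sigmoid cross-entropy, always >= 0).
    # New metric: mean of the signed logit gap, which is what the name
    # "d_fake_diff" actually describes.
    metric_old = torch.mean(criterion(d_fake_diff, True))
    metric_new = torch.mean(d_fake_diff)
    print("old d_fake_diff metric:", metric_old.item())
    print("new d_fake_diff metric:", metric_new.item())

    # The generator loss is the same averaged RaGAN expression as before.
    loss = (criterion(pred_d_real - torch.mean(pred_g_fake), False) +
            criterion(d_fake_diff, True)) / 2
    print("generator RaGAN loss:", loss.item())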