Fix reporting of d_fake_diff for generators
This commit is contained in:
parent
3676f26d94
commit
a51daacde2
|
@ -176,10 +176,10 @@ class GeneratorGanLoss(ConfigurableLoss):
|
||||||
if self.detach_real:
|
if self.detach_real:
|
||||||
pred_d_real = pred_d_real.detach()
|
pred_d_real = pred_d_real.detach()
|
||||||
pred_g_fake = netD(*fake)
|
pred_g_fake = netD(*fake)
|
||||||
d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True)
|
d_fake_diff = pred_g_fake - torch.mean(pred_d_real)
|
||||||
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
||||||
loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
|
loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||||
d_fake_diff) / 2
|
self.criterion(d_fake_diff, True)) / 2
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
if self.min_loss != 0:
|
if self.min_loss != 0:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user