diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 5ea6c113..d7efdfb6 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -176,10 +176,10 @@ class GeneratorGanLoss(ConfigurableLoss): if self.detach_real: pred_d_real = pred_d_real.detach() pred_g_fake = netD(*fake) - d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True) + d_fake_diff = pred_g_fake - torch.mean(pred_d_real) self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) + - d_fake_diff) / 2 + self.criterion(d_fake_diff, True)) / 2 else: raise NotImplementedError if self.min_loss != 0: