From b28e4d9cc7b9087e53432fea8da6b51b068612df Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 19 Oct 2020 11:31:19 -0600 Subject: [PATCH] Add spread loss Experimental loss that peaks around 0. --- codes/models/loss.py | 10 ++++++++++ codes/models/steps/losses.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/codes/models/loss.py b/codes/models/loss.py index ef387144..ef61b84c 100644 --- a/codes/models/loss.py +++ b/codes/models/loss.py @@ -18,6 +18,14 @@ class CharbonnierLoss(nn.Module): return loss +class ZeroSpreadLoss(nn.Module): + def __init__(self): + super(ZeroSpreadLoss, self).__init__() + + def forward(self, x, _): + return 2 * torch.nn.functional.sigmoid(1 / torch.abs(torch.mean(x))) - 1 + + # Define GAN loss: [vanilla | lsgan] class GANLoss(nn.Module): def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): @@ -30,6 +38,8 @@ class GANLoss(nn.Module): self.loss = nn.BCEWithLogitsLoss() elif self.gan_type == 'lsgan': self.loss = nn.MSELoss() + elif self.gan_type == 'max_spread': + self.loss = ZeroSpreadLoss() else: raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 5b2cb41c..d70a2ee6 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -226,7 +226,7 @@ class DiscriminatorGanLoss(ConfigurableLoss): l_fake = self.criterion(d_fake, False) l_total = l_real + l_fake loss = l_total - elif self.opt['gan_type'] == 'ragan': + elif self.opt['gan_type'] == 'ragan' or self.opt['gan_type'] == 'max_spread': d_fake_diff = d_fake - torch.mean(d_real) self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) loss = (self.criterion(d_real - torch.mean(d_fake), True) +