Add spread loss

Experimental loss that peaks around 0.
This commit is contained in:
James Betker 2020-10-19 11:31:19 -06:00
parent 9b9a6e5925
commit b28e4d9cc7
2 changed files with 11 additions and 1 deletions

View File

@ -18,6 +18,14 @@ class CharbonnierLoss(nn.Module):
return loss
class ZeroSpreadLoss(nn.Module):
def __init__(self):
super(ZeroSpreadLoss, self).__init__()
def forward(self, x, _):
return 2 * torch.nn.functional.sigmoid(1 / torch.abs(torch.mean(x))) - 1
# Define GAN loss: [vanilla | lsgan]
class GANLoss(nn.Module):
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
@ -30,6 +38,8 @@ class GANLoss(nn.Module):
self.loss = nn.BCEWithLogitsLoss()
elif self.gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif self.gan_type == 'max_spread':
self.loss = ZeroSpreadLoss()
else:
raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))

View File

@ -226,7 +226,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
l_fake = self.criterion(d_fake, False)
l_total = l_real + l_fake
loss = l_total
elif self.opt['gan_type'] == 'ragan':
elif self.opt['gan_type'] == 'ragan' or self.opt['gan_type'] == 'max_spread':
d_fake_diff = d_fake - torch.mean(d_real)
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
loss = (self.criterion(d_real - torch.mean(d_fake), True) +