forked from mrq/DL-Art-School
Add spread loss
Experimental loss that peaks around 0.
This commit is contained in:
parent
9b9a6e5925
commit
b28e4d9cc7
|
@ -18,6 +18,14 @@ class CharbonnierLoss(nn.Module):
|
|||
return loss
|
||||
|
||||
|
||||
class ZeroSpreadLoss(nn.Module):
|
||||
def __init__(self):
|
||||
super(ZeroSpreadLoss, self).__init__()
|
||||
|
||||
def forward(self, x, _):
|
||||
return 2 * torch.nn.functional.sigmoid(1 / torch.abs(torch.mean(x))) - 1
|
||||
|
||||
|
||||
# Define GAN loss: [vanilla | lsgan]
|
||||
class GANLoss(nn.Module):
|
||||
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
|
||||
|
@ -30,6 +38,8 @@ class GANLoss(nn.Module):
|
|||
self.loss = nn.BCEWithLogitsLoss()
|
||||
elif self.gan_type == 'lsgan':
|
||||
self.loss = nn.MSELoss()
|
||||
elif self.gan_type == 'max_spread':
|
||||
self.loss = ZeroSpreadLoss()
|
||||
else:
|
||||
raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
|
||||
|
||||
|
|
|
@ -226,7 +226,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
l_fake = self.criterion(d_fake, False)
|
||||
l_total = l_real + l_fake
|
||||
loss = l_total
|
||||
elif self.opt['gan_type'] == 'ragan':
|
||||
elif self.opt['gan_type'] == 'ragan' or self.opt['gan_type'] == 'max_spread':
|
||||
d_fake_diff = d_fake - torch.mean(d_real)
|
||||
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
||||
loss = (self.criterion(d_real - torch.mean(d_fake), True) +
|
||||
|
|
Loading…
Reference in New Issue
Block a user