forked from mrq/DL-Art-School
Add spread loss
Experimental loss that peaks around 0.
parent 9b9a6e5925
commit b28e4d9cc7
@@ -18,6 +18,14 @@ class CharbonnierLoss(nn.Module):
         return loss
 
 
+class ZeroSpreadLoss(nn.Module):
+    def __init__(self):
+        super(ZeroSpreadLoss, self).__init__()
+
+    def forward(self, x, _):
+        return 2 * torch.nn.functional.sigmoid(1 / torch.abs(torch.mean(x))) - 1
+
+
 # Define GAN loss: [vanilla | lsgan]
 class GANLoss(nn.Module):
     def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
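The new criterion depends only on the mean of its input, so a quick standalone check of the formula (a sketch, not part of the commit) shows the shape described in the commit message: the value approaches 1 as mean(x) approaches 0 and decays toward 0 as |mean(x)| grows.

import torch

def zero_spread(x):
    # Same expression as ZeroSpreadLoss.forward (torch.sigmoid is the
    # non-deprecated equivalent of torch.nn.functional.sigmoid).
    return 2 * torch.sigmoid(1 / torch.abs(torch.mean(x))) - 1

for m in [0.01, 0.1, 1.0, 10.0]:
    print(m, zero_spread(torch.full((4, 4), m)).item())
# 0.01 -> ~1.00, 0.1 -> ~1.00, 1.0 -> ~0.46, 10.0 -> ~0.05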
@@ -30,6 +38,8 @@ class GANLoss(nn.Module):
             self.loss = nn.BCEWithLogitsLoss()
         elif self.gan_type == 'lsgan':
             self.loss = nn.MSELoss()
+        elif self.gan_type == 'max_spread':
+            self.loss = ZeroSpreadLoss()
         else:
             raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
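Since GANLoss now maps 'max_spread' to the new module, instantiating it needs nothing beyond the constructor arguments visible in the diff; a minimal sketch (the random logits are illustrative only, and GANLoss's own forward is not shown in this commit, so the wrapped loss is called directly):

criterion = GANLoss('max_spread')      # self.loss is now ZeroSpreadLoss()
d_out = torch.randn(8, 1)              # stand-in for discriminator logits
value = criterion.loss(d_out, True)    # ZeroSpreadLoss ignores its target argument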
@@ -226,7 +226,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
             l_fake = self.criterion(d_fake, False)
             l_total = l_real + l_fake
             loss = l_total
-        elif self.opt['gan_type'] == 'ragan':
+        elif self.opt['gan_type'] == 'ragan' or self.opt['gan_type'] == 'max_spread':
             d_fake_diff = d_fake - torch.mean(d_real)
             self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
             loss = (self.criterion(d_real - torch.mean(d_fake), True) +
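On the options side, the diff only reads self.opt['gan_type'], so selecting the new behaviour should come down to that one key; a sketch of the resulting routing (any option layout beyond gan_type is assumed, not shown in the commit):

opt = {'gan_type': 'max_spread'}       # only this key appears in the diff

# With this setting, DiscriminatorGanLoss takes the ragan-style branch, so the
# criterion is applied to d_real - mean(d_fake) and d_fake - mean(d_real).
# Because ZeroSpreadLoss peaks near zero and falls off for large magnitudes,
# minimising it pushes the mean of each difference away from zero, i.e. it
# rewards spreading the real and fake scores apart.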