forked from mrq/DL-Art-School
26a4a66d1c
- Removed a bunch of unnecessary image loggers. These were just consuming space and were never viewed. - Removed the artificial var_ref support. The new pixdisc is what I wanted to implement back then, and it is much better. - Added the pixgan GAN mechanism. It is purpose-built for the pixdisc and is intended to promote a healthy discriminator. - Megabatchfactor was being applied twice to metrics; fixed that. Adds pix_gan (untested), which swaps a portion of the fake and real images with each other and then expects the discriminator to correctly discriminate the swapped regions.
78 lines
2.7 KiB
Python
78 lines
2.7 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class CharbonnierLoss(nn.Module):
    """Charbonnier loss, a smooth, differentiable approximation of the L1 loss.

    Computes ``sum(sqrt((x - y) ** 2 + eps))`` over every element. Note that
    the result is a sum, not a mean.

    Args:
        eps: constant added under the square root. With ``eps > 0`` the
            gradient stays finite where ``x == y``.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        """Return the summed Charbonnier penalty between ``x`` and ``y``."""
        residual = x - y
        return (residual * residual + self.eps).sqrt().sum()
|
|
|
|
|
|
# Define GAN loss: [gan | ragan | pixgan | lsgan | wgan-gp]
|
|
class GANLoss(nn.Module):
    """Adversarial loss that selects its criterion from ``gan_type``.

    ``gan_type`` is matched case-insensitively:

    * ``'gan'``, ``'ragan'``, ``'pixgan'`` -> ``nn.BCEWithLogitsLoss``
    * ``'lsgan'`` -> ``nn.MSELoss``
    * ``'wgan-gp'`` -> mean of the critic output, negated for real targets

    Any other value raises ``NotImplementedError``.

    Args:
        gan_type: name of the GAN formulation (see above).
        real_label_val: fill value of the target tensor for real samples.
        fake_label_val: fill value of the target tensor for fake samples.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
        super().__init__()
        self.gan_type = gan_type.lower()
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        kind = self.gan_type
        if kind in ('gan', 'ragan', 'pixgan'):
            self.loss = nn.BCEWithLogitsLoss()
        elif kind == 'lsgan':
            self.loss = nn.MSELoss()
        elif kind == 'wgan-gp':
            def wgan_loss(critic_out, target):
                # ``target`` is the plain bool flag, not a tensor.
                mean_score = critic_out.mean()
                return -1 * mean_score if target else mean_score

            self.loss = wgan_loss
        else:
            raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))

    def get_target_label(self, input, target_is_real):
        """Build the target argument for ``self.loss``.

        For ``'wgan-gp'`` the flag ``target_is_real`` is returned unchanged.
        Otherwise the result is a tensor like ``input``, filled with
        ``real_label_val`` or ``fake_label_val`` depending on the flag.
        """
        if self.gan_type == 'wgan-gp':
            return target_is_real
        label_val = self.real_label_val if target_is_real else self.fake_label_val
        return torch.empty_like(input).fill_(label_val)

    def forward(self, input, target_is_real):
        """Evaluate the adversarial loss of ``input``.

        For ``'pixgan'``, ``target_is_real`` is used directly as the BCE
        target, so it must already be a label tensor compatible with
        ``input``. For every other type it is a truthy/falsy flag.
        """
        target_label = (target_is_real if self.gan_type == 'pixgan'
                        else self.get_target_label(input, target_is_real))
        return self.loss(input, target_label)
|
|
|
|
|
|
class GradientPenaltyLoss(nn.Module):
    """WGAN-GP gradient penalty: ``mean((||d interp_crit / d interp||_2 - 1) ** 2)``.

    The L2 norm is taken per sample, i.e. over all non-batch dimensions of the
    gradient of ``interp_crit`` with respect to ``interp``.

    Args:
        device: device for the initial (empty) ``grad_outputs`` buffer. The
            buffer is reallocated on demand to match the critic output, so
            this is only the initial placement.
    """

    def __init__(self, device=torch.device('cpu')):
        super(GradientPenaltyLoss, self).__init__()
        # Cached all-ones tensor fed to autograd.grad as ``grad_outputs``. It is
        # kept as a registered buffer so the module's buffer layout is unchanged.
        self.register_buffer('grad_outputs', torch.empty(0, device=device))

    def get_grad_outputs(self, input):
        """Return an all-ones tensor matching ``input`` in size, dtype and device.

        The cached buffer is reused only when it matches all three. Otherwise
        a fresh tensor is allocated. The old tensor is not resized in place,
        because a previous ``create_graph=True`` graph may still reference it.
        """
        cached = self.grad_outputs
        if (cached.size() != input.size()
                or cached.dtype != input.dtype
                or cached.device != input.device):
            cached = torch.ones_like(input)
            self.grad_outputs = cached
        return cached

    def forward(self, interp, interp_crit):
        """Compute the gradient penalty.

        Args:
            interp: tensor the critic was evaluated on. It must be part of the
                autograd graph of ``interp_crit`` (e.g. ``requires_grad=True``).
            interp_crit: critic output computed from ``interp``.

        Returns:
            Scalar tensor. The gradient is taken with ``create_graph=True``, so
            the penalty itself can be backpropagated.
        """
        grad_outputs = self.get_grad_outputs(interp_crit)
        grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
                                          grad_outputs=grad_outputs, create_graph=True,
                                          retain_graph=True)[0]
        # reshape rather than view: autograd may hand back a non-contiguous
        # gradient (e.g. for channels_last or transposed inputs), which view()
        # rejects. The per-sample norm does not depend on element order.
        grad_interp = grad_interp.reshape(grad_interp.size(0), -1)
        grad_interp_norm = grad_interp.norm(2, dim=1)

        return ((grad_interp_norm - 1) ** 2).mean()
|