import torch
import torch.nn as nn


# Define GAN loss: [vanilla | lsgan]
class GANLoss(nn.Module):
    """Adversarial loss that builds its own real/fake target tensor.

    Two variants are supported, selected case-insensitively by ``gan_type``:

    * ``'vanilla'`` -- ``nn.BCEWithLogitsLoss`` (the sigmoid is applied inside
      the criterion, so the discriminator output is passed in as-is).
    * ``'lsgan'`` -- ``nn.MSELoss``.

    Both criteria are created with their default arguments.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
        """Select the criterion and remember the label values.

        Args:
            gan_type: Name of the loss variant, ``'vanilla'`` or ``'lsgan'``
                (any letter case).
            real_label_val: Value used to fill the target for "real" inputs.
            fake_label_val: Value used to fill the target for "fake" inputs.

        Raises:
            NotImplementedError: If ``gan_type`` is not a supported variant.
        """
        super(GANLoss, self).__init__()
        self.gan_type = gan_type.lower()
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            raise NotImplementedError(
                'GAN type [{:s}] is not found'.format(self.gan_type))

    def get_target_label(self, input, target_is_real):
        """Return a constant target tensor matching ``input``.

        Args:
            input: Tensor whose shape, dtype and device the target copies.
            target_is_real: If truthy the target is filled with
                ``real_label_val``, otherwise with ``fake_label_val``.

        Returns:
            A new tensor like ``input`` filled with the selected label value.
        """
        label_val = (self.real_label_val if target_is_real
                     else self.fake_label_val)
        # One allocate-and-fill call instead of empty_like(...).fill_(...).
        return torch.full_like(input, label_val)

    def forward(self, input, target_is_real):
        """Compute the loss of ``input`` against an all-real or all-fake target.

        Args:
            input: Tensor of predictions fed to the criterion.
            target_is_real: Whether ``input`` should be scored against the
                real label (truthy) or the fake label (falsy).

        Returns:
            The value returned by the configured criterion.
        """
        target_label = self.get_target_label(input, target_is_real)
        return self.loss(input, target_label)