From 8d061a268783c163e9de790be141a04768f4661f Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 16 Jul 2020 10:10:09 -0600 Subject: [PATCH] Add u-net discriminator with feature output --- codes/models/SRGAN_model.py | 21 +++++- codes/models/archs/discriminator_vgg_arch.py | 69 ++++++++++++++++++++ codes/models/loss.py | 4 +- codes/models/networks.py | 2 + codes/train.py | 2 +- codes/utils/numeric_stability.py | 9 ++- 6 files changed, 100 insertions(+), 7 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index a47d007d..0a1a4cf5 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -271,7 +271,7 @@ class SRGANModel(BaseModel): # it should target this value. if self.l_gan_w > 0: - if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan': + if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']: pred_g_fake = self.netD(fake_GenOut) l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) elif self.opt['train']['gan_type'] == 'ragan': @@ -324,6 +324,17 @@ class SRGANModel(BaseModel): # Apply noise to the inputs to slow discriminator convergence. var_ref = var_ref + noise fake_H = fake_H + noise + l_d_fea_real = torch.zeros(1) + l_d_fea_fake = torch.zeros(1) + if self.opt['train']['gan_type'] == 'pixgan_fea': + # Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better. + disc_fea_scale = .5 + _, fea_real = self.netD(var_ref, output_feature_vector=True) + actual_fea = self.netF(var_ref) + l_d_fea_real = self.cri_fea(fea_real, actual_fea) * disc_fea_scale / self.mega_batch_factor + _, fea_fake = self.netD(fake_H, output_feature_vector=True) + actual_fea = self.netF(fake_H) + l_d_fea_fake = self.cri_fea(fea_fake, actual_fea) * disc_fea_scale / self.mega_batch_factor if self.opt['train']['gan_type'] == 'gan': # need to forward and backward separately, since batch norm statistics differ # real @@ -338,7 +349,7 @@ class SRGANModel(BaseModel): l_d_fake_log = l_d_fake * self.mega_batch_factor with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: l_d_fake_scaled.backward() - if self.opt['train']['gan_type'] == 'pixgan': + if 'pixgan' in self.opt['train']['gan_type']: # randomly determine portions of the image to swap to keep the discriminator honest. pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction) @@ -379,12 +390,14 @@ class SRGANModel(BaseModel): pred_d_real = self.netD(var_ref) l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor l_d_real_log = l_d_real * self.mega_batch_factor + l_d_real += l_d_fea_real with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: l_d_real_scaled.backward() # fake pred_d_fake = self.netD(fake_H) l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor l_d_fake_log = l_d_fake * self.mega_batch_factor + l_d_fake += l_d_fea_fake with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: l_d_fake_scaled.backward() @@ -470,6 +483,10 @@ class SRGANModel(BaseModel): if self.l_gan_w > 0 and step > self.G_warmup: self.add_log_entry('l_d_real', l_d_real_log.item()) self.add_log_entry('l_d_fake', l_d_fake_log.item()) + self.add_log_entry('l_d_fea_fake', l_d_fea_fake.item() * self.mega_batch_factor) + self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor) + self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor) + self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor) self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real)) diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index e20f7e9f..2124aab5 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -237,3 +237,72 @@ class Discriminator_UNet(nn.Module): def pixgan_parameters(self): return 3, 4 + +class Discriminator_UNet_FeaOut(nn.Module): + def __init__(self, in_nc, nf): + super(Discriminator_UNet_FeaOut, self).__init__() + # [64, 128, 128] + self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False) + self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False) + # [64, 64, 64] + self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False) + self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False) + # [128, 32, 32] + self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False) + self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False) + # [256, 16, 16] + self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False) + self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False) + # [512, 8, 8] + self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False) + self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False) + + self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu) + self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False) + self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False) + + self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu) + self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False) + self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False) + + self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu) + self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False) + self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False) + + def forward(self, x, output_feature_vector=False): + fea0 = self.conv0_0(x) + fea0 = self.conv0_1(fea0) + + fea1 = self.conv1_0(fea0) + fea1 = self.conv1_1(fea1) + + fea2 = self.conv2_0(fea1) + fea2 = self.conv2_1(fea2) + + fea3 = self.conv3_0(fea2) + fea3 = self.conv3_1(fea3) + + feat = self.conv4_0(fea3) + fea4 = self.conv4_1(feat) + + # And the pyramid network! + u1 = self.up1(fea4, fea3) + loss1 = self.collapse1(self.proc1(u1)) + u2 = self.up2(u1, fea2) + loss2 = self.collapse2(self.proc2(u2)) + u3 = self.up3(u2, fea1) + loss3 = self.collapse3(self.proc3(u3)) + res = loss3.shape[2:] + + # Compress all of the loss values into the batch dimension. The actual loss attached to this output will + # then know how to handle them. + combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4), + F.interpolate(loss2, scale_factor=2), + F.interpolate(loss3, scale_factor=1)], dim=1) + if output_feature_vector: + return combined_losses.view(-1, 1), feat + else: + return combined_losses.view(-1, 1) + + def pixgan_parameters(self): + return 3, 4 \ No newline at end of file diff --git a/codes/models/loss.py b/codes/models/loss.py index 342aad38..9334f806 100644 --- a/codes/models/loss.py +++ b/codes/models/loss.py @@ -23,7 +23,7 @@ class GANLoss(nn.Module): self.real_label_val = real_label_val self.fake_label_val = fake_label_val - if self.gan_type == 'gan' or self.gan_type == 'ragan' or self.gan_type == 'pixgan': + if self.gan_type == 'gan' or self.gan_type == 'ragan' or self.gan_type == 'pixgan' or self.gan_type == "pixgan_fea": self.loss = nn.BCEWithLogitsLoss() elif self.gan_type == 'lsgan': self.loss = nn.MSELoss() @@ -46,7 +46,7 @@ class GANLoss(nn.Module): return torch.empty_like(input).fill_(self.fake_label_val) def forward(self, input, target_is_real): - if self.gan_type == 'pixgan' and not isinstance(target_is_real, bool): + if 'pixgan' in self.gan_type and not isinstance(target_is_real, bool): target_label = target_is_real else: target_label = self.get_target_label(input, target_is_real) diff --git a/codes/models/networks.py b/codes/models/networks.py index 2dfe0af0..c1732584 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -122,6 +122,8 @@ def define_D(opt): netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf']) elif which_model == "discriminator_unet": netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf']) + elif which_model == "discriminator_unet_fea": + netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD diff --git a/codes/train.py b/codes/train.py index 65395714..89b36d91 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_dual_srg.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_fdisc.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) diff --git a/codes/utils/numeric_stability.py b/codes/utils/numeric_stability.py index dda9f79b..588555b6 100644 --- a/codes/utils/numeric_stability.py +++ b/codes/utils/numeric_stability.py @@ -3,6 +3,7 @@ from torch import nn import models.archs.SRG1_arch as srg1 import models.archs.SwitchedResidualGenerator_arch as srg import models.archs.NestedSwitchGenerator as nsg +import models.archs.discriminator_vgg_arch as disc import functools blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax] @@ -93,6 +94,7 @@ if __name__ == "__main__": torch.randn(1, 3, 64, 64), device='cuda') ''' + ''' test_stability(functools.partial(srg.DualOutputSRG, switch_depth=4, switch_filters=64, @@ -105,7 +107,7 @@ if __name__ == "__main__": upsample_factor=4), torch.randn(1, 3, 32, 32), device='cpu') - + ''' ''' test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator, switch_filters=[32,32,32,32], @@ -125,4 +127,7 @@ if __name__ == "__main__": 64, 16), torch.randn(1, 3, 64, 64), device='cuda') - ''' \ No newline at end of file + ''' + test_stability(functools.partial(disc.Discriminator_UNet_FeaOut, 3, 64), + torch.randn(1,3,128,128), + device='cpu') \ No newline at end of file