From 0b7193392fb11e611a66c64efcb3cf90477252e4 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 10 Jul 2020 16:16:03 -0600
Subject: [PATCH] Implement unet disc

The latest discriminator architecture was already pretty much a U-Net.
This one makes that official and uses shared layers. It also upsamples
one additional time and throws out the lowest upsampling result.

The intent is to delete the old VGG pixdisc, but I'll keep it around
for a bit since I'm still trying out a few models with it.
---
 codes/models/SRGAN_model.py                  | 12 ++--
 codes/models/archs/discriminator_vgg_arch.py | 73 +++++++++++++++++++-
 codes/models/networks.py                     |  2 +
 3 files changed, 78 insertions(+), 9 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 2bfdbda3..c722161e 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -339,15 +339,11 @@ class SRGANModel(BaseModel):
                     l_d_fake_scaled.backward()
                 if self.opt['train']['gan_type'] == 'pixgan':
                     # randomly determine portions of the image to swap to keep the discriminator honest.
-
-                    # We're making some assumptions about the underlying pixel-discriminator here. This is a
-                    # necessary evil for now, but if this turns out well we might want to make this configurable.
-                    PIXDISC_CHANNELS = 3
-                    PIXDISC_OUTPUT_REDUCTION = 8
-                    disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
+                    pixdisc_channels, pixdisc_output_reduction = self.netD.pixgan_parameters()
+                    disc_output_shape = (var_ref[0].shape[0], pixdisc_channels, var_ref[0].shape[2] // pixdisc_output_reduction, var_ref[0].shape[3] // pixdisc_output_reduction)
                     b, _, w, h = var_ref[0].shape
-                    real = torch.ones((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
-                    fake = torch.zeros((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
+                    real = torch.ones((b, pixdisc_channels, w, h), device=var_ref[0].device)
+                    fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref[0].device)
                     SWAP_MAX_DIM = w // 4
                     SWAP_MIN_DIM = 16
                     assert SWAP_MAX_DIM > 0
diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index ddd38a05..3f110edb 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torchvision
-from models.archs.arch_util import ConvBnLelu, ConvGnLelu
+from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock
 import torch.nn.functional as F
 
 
@@ -169,3 +169,74 @@ class Discriminator_VGG_PixLoss(nn.Module):
         combined_losses = torch.cat([loss, loss3, loss2], dim=1)
         return combined_losses.view(-1, 1)
 
+    def pixgan_parameters(self):
+        return 3, 8
+
+
+class Discriminator_UNet(nn.Module):
+    def __init__(self, in_nc, nf):
+        super(Discriminator_UNet, self).__init__()
+        # [64, 128, 128]
+        self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, gn=False)
+        self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
+        # [64, 64, 64]
+        self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
+        self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
+        # [128, 32, 32]
+        self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
+        self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
+        # [256, 16, 16]
+        self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
+        self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+        # [512, 8, 8]
+        self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
+        self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+
+        self.up1 = ExpansionBlock(nf * 8, block=ConvGnLelu)
+        self.proc1 = ConvGnLelu(nf * 4, nf * 4, bias=False)
+        self.collapse1 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
+
+        self.up2 = ExpansionBlock(nf * 4, block=ConvGnLelu)
+        self.proc2 = ConvGnLelu(nf * 2, nf * 2, bias=False)
+        self.collapse2 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
+
+        self.up3 = ExpansionBlock(nf * 2, block=ConvGnLelu)
+        self.proc3 = ConvGnLelu(nf, nf, bias=False)
+        self.collapse3 = ConvGnLelu(nf, 1, bias=True, norm=False, activation=False)
+
+    def forward(self, x, flatten=True):
+        x = x[0]
+        fea0 = self.conv0_0(x)
+        fea0 = self.conv0_1(fea0)
+
+        fea1 = self.conv1_0(fea0)
+        fea1 = self.conv1_1(fea1)
+
+        fea2 = self.conv2_0(fea1)
+        fea2 = self.conv2_1(fea2)
+
+        fea3 = self.conv3_0(fea2)
+        fea3 = self.conv3_1(fea3)
+
+        fea4 = self.conv4_0(fea3)
+        fea4 = self.conv4_1(fea4)
+
+        # And the pyramid network!
+        u1 = self.up1(fea4, fea3)
+        loss1 = self.collapse1(self.proc1(u1))
+        u2 = self.up2(u1, fea2)
+        loss2 = self.collapse2(self.proc2(u2))
+        u3 = self.up3(u2, fea1)
+        loss3 = self.collapse3(self.proc3(u3))
+        res = loss3.shape[2:]
+
+        # Compress all of the loss values into the batch dimension. The actual loss attached to this output will
+        # then know how to handle them.
+        combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
+                                     F.interpolate(loss2, scale_factor=2),
+                                     F.interpolate(loss3, scale_factor=1)], dim=1)
+        return combined_losses.view(-1, 1)
+
+    def pixgan_parameters(self):
+        return 3, 4
+
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 3495dd16..7507add2 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -111,6 +111,8 @@ def define_D(opt):
                                                      disable_passthrough=opt_net['disable_passthrough'])
     elif which_model == 'discriminator_pix':
         netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
+    elif which_model == "discriminator_unet":
+        netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
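
Editor's note, not part of the patch: a shape sanity-check for the new
discriminator. Discriminator_UNet downsamples five times (the stride-2
convs, to 1/32 resolution at the bottleneck), and the three ExpansionBlock
stages upsample back to 1/4 resolution, so after F.interpolate the three
1-channel collapse outputs share one grid and the concatenated map flattens
to (b * 3 * (h//4) * (w//4), 1). The (3, 4) returned by pixgan_parameters()
is exactly this channel count and output reduction, which is what lets
SRGAN_model.py drop its hard-coded PIXDISC_* constants. The sketch below
assumes the repo's codes/ directory is on the import path, a 128x128 input,
and that ConvGnLelu and ExpansionBlock behave as the layer comments in the
diff indicate.

    import torch
    from models.archs.discriminator_vgg_arch import Discriminator_UNet

    netD = Discriminator_UNet(in_nc=3, nf=64)
    img = torch.randn(4, 3, 128, 128)

    # forward() indexes into its argument (x = x[0]), so wrap the batch
    # in a list, as SRGAN_model does for its discriminators.
    out = netD([img])

    # Three 1-channel maps, interpolated to the 32x32 grid of the final
    # upsample stage, concatenated on dim 1, then flattened to (N, 1).
    channels, reduction = netD.pixgan_parameters()  # (3, 4)
    assert out.shape == (4 * channels * (128 // reduction) ** 2, 1)

Querying the network for these parameters keeps the pixgan swap logic in
SRGANModel agnostic to which pixel discriminator is configured.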