diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 2bfdbda3..c722161e 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -339,15 +339,11 @@ class SRGANModel(BaseModel):
                 l_d_fake_scaled.backward()
             if self.opt['train']['gan_type'] == 'pixgan':
                 # randomly determine portions of the image to swap to keep the discriminator honest.
-
-                # We're making some assumptions about the underlying pixel-discriminator here. This is a
-                # necessary evil for now, but if this turns out well we might want to make this configurable.
-                PIXDISC_CHANNELS = 3
-                PIXDISC_OUTPUT_REDUCTION = 8
-                disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
+                pixdisc_channels, pixdisc_output_reduction = self.netD.pixgan_parameters()
+                disc_output_shape = (var_ref[0].shape[0], pixdisc_channels, var_ref[0].shape[2] // pixdisc_output_reduction, var_ref[0].shape[3] // pixdisc_output_reduction)
                 b, _, w, h = var_ref[0].shape
-                real = torch.ones((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
-                fake = torch.zeros((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
+                real = torch.ones((b, pixdisc_channels, w, h), device=var_ref[0].device)
+                fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref[0].device)
                 SWAP_MAX_DIM = w // 4
                 SWAP_MIN_DIM = 16
                 assert SWAP_MAX_DIM > 0
diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index ddd38a05..3f110edb 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torchvision
-from models.archs.arch_util import ConvBnLelu, ConvGnLelu
+from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock
 import torch.nn.functional as F


@@ -169,3 +169,74 @@ class Discriminator_VGG_PixLoss(nn.Module):
         combined_losses = torch.cat([loss, loss3, loss2], dim=1)
         return combined_losses.view(-1, 1)

+    def pixgan_parameters(self):
+        return 3, 8
+
+
+class Discriminator_UNet(nn.Module):
+    def __init__(self, in_nc, nf):
+        super(Discriminator_UNet, self).__init__()
+        # [64, 128, 128]
+        self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, gn=False)
+        self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
+        # [64, 64, 64]
+        self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
+        self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
+        # [128, 32, 32]
+        self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
+        self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
+        # [256, 16, 16]
+        self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
+        self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+        # [512, 8, 8]
+        self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
+        self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+
+        self.up1 = ExpansionBlock(nf * 8, block=ConvGnLelu)
+        self.proc1 = ConvGnLelu(nf * 4, nf * 4, bias=False)
+        self.collapse1 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
+
+        self.up2 = ExpansionBlock(nf * 4, block=ConvGnLelu)
+        self.proc2 = ConvGnLelu(nf * 2, nf * 2, bias=False)
+        self.collapse2 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
+
+        self.up3 = ExpansionBlock(nf * 2, block=ConvGnLelu)
+        self.proc3 = ConvGnLelu(nf, nf, bias=False)
+        self.collapse3 = ConvGnLelu(nf, 1, bias=True, norm=False, activation=False)
+
+    def forward(self, x, flatten=True):
+        x = x[0]
+        fea0 = self.conv0_0(x)
+        fea0 = self.conv0_1(fea0)
+
+        fea1 = self.conv1_0(fea0)
+        fea1 = self.conv1_1(fea1)
+
+        fea2 = self.conv2_0(fea1)
+        fea2 = self.conv2_1(fea2)
+
+        fea3 = self.conv3_0(fea2)
+        fea3 = self.conv3_1(fea3)
+
+        fea4 = self.conv4_0(fea3)
+        fea4 = self.conv4_1(fea4)
+
+        # And the pyramid network!
+        u1 = self.up1(fea4, fea3)
+        loss1 = self.collapse1(self.proc1(u1))
+        u2 = self.up2(u1, fea2)
+        loss2 = self.collapse2(self.proc2(u2))
+        u3 = self.up3(u2, fea1)
+        loss3 = self.collapse3(self.proc3(u3))
+        res = loss3.shape[2:]
+
+        # Compress all of the loss values into the batch dimension. The actual loss attached to this output will
+        # then know how to handle them.
+        combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
+                                     F.interpolate(loss2, scale_factor=2),
+                                     F.interpolate(loss3, scale_factor=1)], dim=1)
+        return combined_losses.view(-1, 1)
+
+    def pixgan_parameters(self):
+        return 3, 4
+
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 3495dd16..7507add2 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -111,6 +111,8 @@ def define_D(opt):
                                               disable_passthrough=opt_net['disable_passthrough'])
     elif which_model == 'discriminator_pix':
         netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
+    elif which_model == "discriminator_unet":
+        netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
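Note: below is a minimal, self-contained sketch of how the new pixgan_parameters() contract is consumed, mirroring the SRGAN_model.py hunk above. The UNetStub class and the 128x128 reference batch are illustrative assumptions and not part of the patch; per the diff, Discriminator_UNet reports (3, 4) and Discriminator_VGG_PixLoss reports (3, 8).

import torch
import torch.nn as nn


class UNetStub(nn.Module):
    # Stand-in for Discriminator_UNet: advertises how many channels its per-pixel
    # output has and by what factor that output is smaller than its input.
    def pixgan_parameters(self):
        return 3, 4


netD = UNetStub()
var_ref = torch.zeros(2, 3, 128, 128)  # illustrative HR reference batch

# The discriminator, rather than hard-coded constants, now determines the shape
# of the per-pixel real/fake label maps used by the pixgan loss.
pixdisc_channels, pixdisc_output_reduction = netD.pixgan_parameters()
disc_output_shape = (var_ref.shape[0], pixdisc_channels,
                     var_ref.shape[2] // pixdisc_output_reduction,
                     var_ref.shape[3] // pixdisc_output_reduction)

b, _, w, h = var_ref.shape
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
print(disc_output_shape)       # (2, 3, 32, 32) with a reduction factor of 4
print(real.shape, fake.shape)  # full-resolution label maps, torch.Size([2, 3, 128, 128]) each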