diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index 2124aab5..b1abc9d0 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -239,7 +239,7 @@ class Discriminator_UNet(nn.Module): class Discriminator_UNet_FeaOut(nn.Module): - def __init__(self, in_nc, nf): + def __init__(self, in_nc, nf, feature_mode=False): super(Discriminator_UNet_FeaOut, self).__init__() # [64, 128, 128] self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False) @@ -269,6 +269,8 @@ class Discriminator_UNet_FeaOut(nn.Module): self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False) self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False) + self.feature_mode = feature_mode + def forward(self, x, output_feature_vector=False): fea0 = self.conv0_0(x) fea0 = self.conv0_1(fea0) @@ -294,15 +296,21 @@ class Discriminator_UNet_FeaOut(nn.Module): loss3 = self.collapse3(self.proc3(u3)) res = loss3.shape[2:] - # Compress all of the loss values into the batch dimension. The actual loss attached to this output will - # then know how to handle them. - combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4), - F.interpolate(loss2, scale_factor=2), - F.interpolate(loss3, scale_factor=1)], dim=1) + if self.feature_mode: + combined_losses = F.interpolate(loss1, scale_factor=4) + else: + # Compress all of the loss values into the batch dimension. The actual loss attached to this output will + # then know how to handle them. + combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4), + F.interpolate(loss2, scale_factor=2), + F.interpolate(loss3, scale_factor=1)], dim=1) if output_feature_vector: return combined_losses.view(-1, 1), feat else: return combined_losses.view(-1, 1) def pixgan_parameters(self): - return 3, 4 \ No newline at end of file + if self.feature_mode: + return 1, 4 + else: + return 3, 4 \ No newline at end of file diff --git a/codes/models/networks.py b/codes/models/networks.py index 1bca11e3..84243f7e 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -123,7 +123,7 @@ def define_D(opt): elif which_model == "discriminator_unet": netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf']) elif which_model == "discriminator_unet_fea": - netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf']) + netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD