Allow feature discriminator unet to only output closest layer to feature output

This commit is contained in:
James Betker 2020-07-19 19:05:08 -06:00
parent 8a9f215653
commit 46aa776fbb
2 changed files with 16 additions and 8 deletions

View File

@ -239,7 +239,7 @@ class Discriminator_UNet(nn.Module):
class Discriminator_UNet_FeaOut(nn.Module): class Discriminator_UNet_FeaOut(nn.Module):
def __init__(self, in_nc, nf): def __init__(self, in_nc, nf, feature_mode=False):
super(Discriminator_UNet_FeaOut, self).__init__() super(Discriminator_UNet_FeaOut, self).__init__()
# [64, 128, 128] # [64, 128, 128]
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False) self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
@ -269,6 +269,8 @@ class Discriminator_UNet_FeaOut(nn.Module):
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False) self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False) self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
self.feature_mode = feature_mode
def forward(self, x, output_feature_vector=False): def forward(self, x, output_feature_vector=False):
fea0 = self.conv0_0(x) fea0 = self.conv0_0(x)
fea0 = self.conv0_1(fea0) fea0 = self.conv0_1(fea0)
@ -294,15 +296,21 @@ class Discriminator_UNet_FeaOut(nn.Module):
loss3 = self.collapse3(self.proc3(u3)) loss3 = self.collapse3(self.proc3(u3))
res = loss3.shape[2:] res = loss3.shape[2:]
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will if self.feature_mode:
# then know how to handle them. combined_losses = F.interpolate(loss1, scale_factor=4)
combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4), else:
F.interpolate(loss2, scale_factor=2), # Compress all of the loss values into the batch dimension. The actual loss attached to this output will
F.interpolate(loss3, scale_factor=1)], dim=1) # then know how to handle them.
combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
F.interpolate(loss2, scale_factor=2),
F.interpolate(loss3, scale_factor=1)], dim=1)
if output_feature_vector: if output_feature_vector:
return combined_losses.view(-1, 1), feat return combined_losses.view(-1, 1), feat
else: else:
return combined_losses.view(-1, 1) return combined_losses.view(-1, 1)
def pixgan_parameters(self): def pixgan_parameters(self):
return 3, 4 if self.feature_mode:
return 1, 4
else:
return 3, 4

View File

@ -123,7 +123,7 @@ def define_D(opt):
elif which_model == "discriminator_unet": elif which_model == "discriminator_unet":
netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf']) netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
elif which_model == "discriminator_unet_fea": elif which_model == "discriminator_unet_fea":
netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf']) netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode'])
else: else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD return netD