Allow feature discriminator unet to only output closest layer to feature output
This commit is contained in:
parent
8a9f215653
commit
46aa776fbb
|
@ -239,7 +239,7 @@ class Discriminator_UNet(nn.Module):
|
|||
|
||||
|
||||
class Discriminator_UNet_FeaOut(nn.Module):
|
||||
def __init__(self, in_nc, nf):
|
||||
def __init__(self, in_nc, nf, feature_mode=False):
|
||||
super(Discriminator_UNet_FeaOut, self).__init__()
|
||||
# [64, 128, 128]
|
||||
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
|
||||
|
@ -269,6 +269,8 @@ class Discriminator_UNet_FeaOut(nn.Module):
|
|||
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
|
||||
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
self.feature_mode = feature_mode
|
||||
|
||||
def forward(self, x, output_feature_vector=False):
|
||||
fea0 = self.conv0_0(x)
|
||||
fea0 = self.conv0_1(fea0)
|
||||
|
@ -294,15 +296,21 @@ class Discriminator_UNet_FeaOut(nn.Module):
|
|||
loss3 = self.collapse3(self.proc3(u3))
|
||||
res = loss3.shape[2:]
|
||||
|
||||
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
|
||||
# then know how to handle them.
|
||||
combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
|
||||
F.interpolate(loss2, scale_factor=2),
|
||||
F.interpolate(loss3, scale_factor=1)], dim=1)
|
||||
if self.feature_mode:
|
||||
combined_losses = F.interpolate(loss1, scale_factor=4)
|
||||
else:
|
||||
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
|
||||
# then know how to handle them.
|
||||
combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
|
||||
F.interpolate(loss2, scale_factor=2),
|
||||
F.interpolate(loss3, scale_factor=1)], dim=1)
|
||||
if output_feature_vector:
|
||||
return combined_losses.view(-1, 1), feat
|
||||
else:
|
||||
return combined_losses.view(-1, 1)
|
||||
|
||||
def pixgan_parameters(self):
|
||||
return 3, 4
|
||||
if self.feature_mode:
|
||||
return 1, 4
|
||||
else:
|
||||
return 3, 4
|
|
@ -123,7 +123,7 @@ def define_D(opt):
|
|||
elif which_model == "discriminator_unet":
|
||||
netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
|
||||
elif which_model == "discriminator_unet_fea":
|
||||
netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
|
||||
netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode'])
|
||||
else:
|
||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||
return netD
|
||||
|
|
Loading…
Reference in New Issue
Block a user