Allow feature discriminator unet to only output closest layer to feature output
This commit is contained in:
parent
8a9f215653
commit
46aa776fbb
|
@ -239,7 +239,7 @@ class Discriminator_UNet(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Discriminator_UNet_FeaOut(nn.Module):
|
class Discriminator_UNet_FeaOut(nn.Module):
|
||||||
def __init__(self, in_nc, nf):
|
def __init__(self, in_nc, nf, feature_mode=False):
|
||||||
super(Discriminator_UNet_FeaOut, self).__init__()
|
super(Discriminator_UNet_FeaOut, self).__init__()
|
||||||
# [64, 128, 128]
|
# [64, 128, 128]
|
||||||
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
|
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
|
||||||
|
@ -269,6 +269,8 @@ class Discriminator_UNet_FeaOut(nn.Module):
|
||||||
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
|
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
|
||||||
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
|
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
|
||||||
|
|
||||||
|
self.feature_mode = feature_mode
|
||||||
|
|
||||||
def forward(self, x, output_feature_vector=False):
|
def forward(self, x, output_feature_vector=False):
|
||||||
fea0 = self.conv0_0(x)
|
fea0 = self.conv0_0(x)
|
||||||
fea0 = self.conv0_1(fea0)
|
fea0 = self.conv0_1(fea0)
|
||||||
|
@ -294,15 +296,21 @@ class Discriminator_UNet_FeaOut(nn.Module):
|
||||||
loss3 = self.collapse3(self.proc3(u3))
|
loss3 = self.collapse3(self.proc3(u3))
|
||||||
res = loss3.shape[2:]
|
res = loss3.shape[2:]
|
||||||
|
|
||||||
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
|
if self.feature_mode:
|
||||||
# then know how to handle them.
|
combined_losses = F.interpolate(loss1, scale_factor=4)
|
||||||
combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
|
else:
|
||||||
F.interpolate(loss2, scale_factor=2),
|
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
|
||||||
F.interpolate(loss3, scale_factor=1)], dim=1)
|
# then know how to handle them.
|
||||||
|
combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
|
||||||
|
F.interpolate(loss2, scale_factor=2),
|
||||||
|
F.interpolate(loss3, scale_factor=1)], dim=1)
|
||||||
if output_feature_vector:
|
if output_feature_vector:
|
||||||
return combined_losses.view(-1, 1), feat
|
return combined_losses.view(-1, 1), feat
|
||||||
else:
|
else:
|
||||||
return combined_losses.view(-1, 1)
|
return combined_losses.view(-1, 1)
|
||||||
|
|
||||||
def pixgan_parameters(self):
|
def pixgan_parameters(self):
|
||||||
return 3, 4
|
if self.feature_mode:
|
||||||
|
return 1, 4
|
||||||
|
else:
|
||||||
|
return 3, 4
|
|
@ -123,7 +123,7 @@ def define_D(opt):
|
||||||
elif which_model == "discriminator_unet":
|
elif which_model == "discriminator_unet":
|
||||||
netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
|
netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
|
||||||
elif which_model == "discriminator_unet_fea":
|
elif which_model == "discriminator_unet_fea":
|
||||||
netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
|
netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||||
return netD
|
return netD
|
||||||
|
|
Loading…
Reference in New Issue
Block a user