Allow passthrough discriminator to have passthrough disabled from config

This commit is contained in:
James Betker 2020-05-19 09:41:16 -06:00
parent 67139602f5
commit af1968f9e5
2 changed files with 10 additions and 2 deletions

View File

@ -107,11 +107,13 @@ class FixupBottleneck(nn.Module):
class FixupResNet(nn.Module): class FixupResNet(nn.Module):
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False): def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False,
disable_passthrough=False):
super(FixupResNet, self).__init__() super(FixupResNet, self).__init__()
self.num_layers = sum(layers) self.num_layers = sum(layers)
self.inplanes = 3 self.inplanes = 3
self.number_skips = number_skips self.number_skips = number_skips
self.disable_passthrough = disable_passthrough
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5) self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5)
if number_skips > 0: if number_skips > 0:
@ -163,6 +165,11 @@ class FixupResNet(nn.Module):
# Or just a tuple with only the high res input (this assumes number_skips was set right). # Or just a tuple with only the high res input (this assumes number_skips was set right).
x = x[0] x = x[0]
if self.disable_passthrough:
if self.number_skips > 0:
med_skip = torch.zeros_like(med_skip)
if self.number_skips > 1:
lo_skip = torch.zeros_like(lo_skip)
x = self.layer0(x) x = self.layer0(x)
if self.number_skips > 0: if self.number_skips > 0:
x = torch.cat([x, med_skip], dim=1) x = torch.cat([x, med_skip], dim=1)

View File

@ -75,7 +75,8 @@ def define_D(opt):
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'discriminator_resnet_passthrough': elif which_model == 'discriminator_resnet_passthrough':
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz,
number_skips=opt_net['number_skips'], use_bn=True) number_skips=opt_net['number_skips'], use_bn=True,
disable_passthrough=opt_net['disable_passthrough'])
else: else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD return netD