diff --git a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py index 462261f7..236b1647 100644 --- a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py +++ b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py @@ -107,11 +107,13 @@ class FixupBottleneck(nn.Module): class FixupResNet(nn.Module): - def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False): + def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False, + disable_passthrough=False): super(FixupResNet, self).__init__() self.num_layers = sum(layers) self.inplanes = 3 self.number_skips = number_skips + self.disable_passthrough = disable_passthrough self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5) if number_skips > 0: @@ -163,6 +165,11 @@ class FixupResNet(nn.Module): # Or just a tuple with only the high res input (this assumes number_skips was set right). x = x[0] + if self.disable_passthrough: + if self.number_skips > 0: + med_skip = torch.zeros_like(med_skip) + if self.number_skips > 1: + lo_skip = torch.zeros_like(lo_skip) x = self.layer0(x) if self.number_skips > 0: x = torch.cat([x, med_skip], dim=1) diff --git a/codes/models/networks.py b/codes/models/networks.py index 446322bc..98c17b46 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -75,7 +75,8 @@ def define_D(opt): netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) elif which_model == 'discriminator_resnet_passthrough': netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, - number_skips=opt_net['number_skips'], use_bn=True) + number_skips=opt_net['number_skips'], use_bn=True, + disable_passthrough=opt_net['disable_passthrough']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD