Allow passthrough discriminator to have passthrough disabled from config
This commit is contained in:
parent
67139602f5
commit
af1968f9e5
|
@ -107,11 +107,13 @@ class FixupBottleneck(nn.Module):
|
||||||
|
|
||||||
class FixupResNet(nn.Module):
|
class FixupResNet(nn.Module):
|
||||||
|
|
||||||
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False):
|
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False,
|
||||||
|
disable_passthrough=False):
|
||||||
super(FixupResNet, self).__init__()
|
super(FixupResNet, self).__init__()
|
||||||
self.num_layers = sum(layers)
|
self.num_layers = sum(layers)
|
||||||
self.inplanes = 3
|
self.inplanes = 3
|
||||||
self.number_skips = number_skips
|
self.number_skips = number_skips
|
||||||
|
self.disable_passthrough = disable_passthrough
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5)
|
self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5)
|
||||||
if number_skips > 0:
|
if number_skips > 0:
|
||||||
|
@ -163,6 +165,11 @@ class FixupResNet(nn.Module):
|
||||||
# Or just a tuple with only the high res input (this assumes number_skips was set right).
|
# Or just a tuple with only the high res input (this assumes number_skips was set right).
|
||||||
x = x[0]
|
x = x[0]
|
||||||
|
|
||||||
|
if self.disable_passthrough:
|
||||||
|
if self.number_skips > 0:
|
||||||
|
med_skip = torch.zeros_like(med_skip)
|
||||||
|
if self.number_skips > 1:
|
||||||
|
lo_skip = torch.zeros_like(lo_skip)
|
||||||
x = self.layer0(x)
|
x = self.layer0(x)
|
||||||
if self.number_skips > 0:
|
if self.number_skips > 0:
|
||||||
x = torch.cat([x, med_skip], dim=1)
|
x = torch.cat([x, med_skip], dim=1)
|
||||||
|
|
|
@ -75,7 +75,8 @@ def define_D(opt):
|
||||||
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
|
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
|
||||||
elif which_model == 'discriminator_resnet_passthrough':
|
elif which_model == 'discriminator_resnet_passthrough':
|
||||||
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz,
|
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz,
|
||||||
number_skips=opt_net['number_skips'], use_bn=True)
|
number_skips=opt_net['number_skips'], use_bn=True,
|
||||||
|
disable_passthrough=opt_net['disable_passthrough'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||||
return netD
|
return netD
|
||||||
|
|
Loading…
Reference in New Issue
Block a user