forked from mrq/DL-Art-School
Allow passthrough discriminator to have passthrough disabled from config
This commit is contained in:
parent
67139602f5
commit
af1968f9e5
|
@ -107,11 +107,13 @@ class FixupBottleneck(nn.Module):
|
|||
|
||||
class FixupResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False):
|
||||
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False,
|
||||
disable_passthrough=False):
|
||||
super(FixupResNet, self).__init__()
|
||||
self.num_layers = sum(layers)
|
||||
self.inplanes = 3
|
||||
self.number_skips = number_skips
|
||||
self.disable_passthrough = disable_passthrough
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5)
|
||||
if number_skips > 0:
|
||||
|
@ -163,6 +165,11 @@ class FixupResNet(nn.Module):
|
|||
# Or just a tuple with only the high res input (this assumes number_skips was set right).
|
||||
x = x[0]
|
||||
|
||||
if self.disable_passthrough:
|
||||
if self.number_skips > 0:
|
||||
med_skip = torch.zeros_like(med_skip)
|
||||
if self.number_skips > 1:
|
||||
lo_skip = torch.zeros_like(lo_skip)
|
||||
x = self.layer0(x)
|
||||
if self.number_skips > 0:
|
||||
x = torch.cat([x, med_skip], dim=1)
|
||||
|
|
|
@ -75,7 +75,8 @@ def define_D(opt):
|
|||
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
|
||||
elif which_model == 'discriminator_resnet_passthrough':
|
||||
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz,
|
||||
number_skips=opt_net['number_skips'], use_bn=True)
|
||||
number_skips=opt_net['number_skips'], use_bn=True,
|
||||
disable_passthrough=opt_net['disable_passthrough'])
|
||||
else:
|
||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||
return netD
|
||||
|
|
Loading…
Reference in New Issue
Block a user