diff --git a/codes/models/networks.py b/codes/models/networks.py index 5f28cf72..22dea7d7 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -120,6 +120,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False): netD = GradDiscWrapper(netD) elif which_model == 'discriminator_resnet': netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) + elif which_model == 'discriminator_resnet_50': + netD = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) elif which_model == 'discriminator_resnet_passthrough': netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, number_skips=opt_net['number_skips'], use_bn=True,