From 8beaa479335a71aa9b201fa1e4b471ce5a9e4626 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 1 Oct 2020 11:48:14 -0600 Subject: [PATCH] resnext discriminator --- codes/models/networks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/codes/models/networks.py b/codes/models/networks.py index 22dea7d7..b9c9653a 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -12,6 +12,7 @@ import models.archs.SPSR_arch as spsr import models.archs.StructuredSwitchedGenerator as ssg import models.archs.rcan as rcan from collections import OrderedDict +import torchvision logger = logging.getLogger('base') @@ -126,6 +127,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False): netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, number_skips=opt_net['number_skips'], use_bn=True, disable_passthrough=opt_net['disable_passthrough']) + elif which_model == 'resnext': + netD = torchvision.models.resnext50_32x4d(pretrained=True) + netD.fc = torch.nn.Linear(512 * 4, 1) elif which_model == 'discriminator_pix': netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf']) elif which_model == "discriminator_unet":