diff --git a/codes/models/networks.py b/codes/models/networks.py index 22dea7d7..b9c9653a 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -12,6 +12,7 @@ import models.archs.SPSR_arch as spsr import models.archs.StructuredSwitchedGenerator as ssg import models.archs.rcan as rcan from collections import OrderedDict +import torchvision logger = logging.getLogger('base') @@ -126,6 +127,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False): netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, number_skips=opt_net['number_skips'], use_bn=True, disable_passthrough=opt_net['disable_passthrough']) + elif which_model == 'resnext': + netD = torchvision.models.resnext50_32x4d(pretrained=True) + netD.fc = torch.nn.Linear(512 * 4, 1) elif which_model == 'discriminator_pix': netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf']) elif which_model == "discriminator_unet":