From aa4fd890185c875ff355215b955a02753a2f99fd Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 1 Oct 2020 15:49:28 -0600 Subject: [PATCH] resnext with groupnorm --- codes/models/networks.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/codes/models/networks.py b/codes/models/networks.py index b9c9653a..febad7c2 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -13,6 +13,7 @@ import models.archs.StructuredSwitchedGenerator as ssg import models.archs.rcan as rcan from collections import OrderedDict import torchvision +import functools logger = logging.getLogger('base') @@ -128,7 +129,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False): number_skips=opt_net['number_skips'], use_bn=True, disable_passthrough=opt_net['disable_passthrough']) elif which_model == 'resnext': - netD = torchvision.models.resnext50_32x4d(pretrained=True) + netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8)) + state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True) + netD.load_state_dict(state_dict, strict=False) netD.fc = torch.nn.Linear(512 * 4, 1) elif which_model == 'discriminator_pix': netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])