resnext with groupnorm

This commit is contained in:
James Betker 2020-10-01 15:49:28 -06:00
parent 8beaa47933
commit aa4fd89018

View File

@ -13,6 +13,7 @@ import models.archs.StructuredSwitchedGenerator as ssg
import models.archs.rcan as rcan import models.archs.rcan as rcan
from collections import OrderedDict from collections import OrderedDict
import torchvision import torchvision
import functools
logger = logging.getLogger('base') logger = logging.getLogger('base')
@ -128,7 +129,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
number_skips=opt_net['number_skips'], use_bn=True, number_skips=opt_net['number_skips'], use_bn=True,
disable_passthrough=opt_net['disable_passthrough']) disable_passthrough=opt_net['disable_passthrough'])
elif which_model == 'resnext': elif which_model == 'resnext':
netD = torchvision.models.resnext50_32x4d(pretrained=True) netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8))
state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True)
netD.load_state_dict(state_dict, strict=False)
netD.fc = torch.nn.Linear(512 * 4, 1) netD.fc = torch.nn.Linear(512 * 4, 1)
elif which_model == 'discriminator_pix': elif which_model == 'discriminator_pix':
netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf']) netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])