diff --git a/codes/models/networks.py b/codes/models/networks.py index f29d5450..78196ef1 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -1,4 +1,5 @@ import torch +import logging import models.archs.SRResNet_arch as SRResNet_arch import models.archs.discriminator_vgg_arch as SRGAN_arch import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch @@ -16,6 +17,8 @@ import models.archs.arch_util as arch_util import functools from collections import OrderedDict +logger = logging.getLogger('base') + # Generator def define_G(opt, net_key='network_G', scale=None): if net_key is not None: @@ -143,11 +146,11 @@ def define_G(opt, net_key='network_G', scale=None): class GradDiscWrapper(torch.nn.Module): def __init__(self, m): super(GradDiscWrapper, self).__init__() - print("Wrapping a discriminator..") + logger.info("Wrapping a discriminator..") self.m = m - def forward(self, x, lr): - return self.m(x, lr) + def forward(self, x): + return self.m(x) def define_D_net(opt_net, img_sz=None, wrap=False): which_model = opt_net['which_model_D'] @@ -156,6 +159,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False): netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128, extra_conv=opt_net['extra_conv']) elif which_model == 'discriminator_vgg_128_gn': netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128) + if wrap: + netD = GradDiscWrapper(netD) elif which_model == 'discriminator_resnet': netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) elif which_model == 'discriminator_resnet_passthrough': @@ -173,8 +178,6 @@ def define_D_net(opt_net, img_sz=None, wrap=False): final_temperature_step=opt_net['final_temperature_step']) elif which_model == "cross_compare_vgg128": netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'], nf=opt_net['nf'], scale=opt_net['scale']) - if wrap: - netD = GradDiscWrapper(netD) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD