From bae18c05e6ca475b5fa5b88f9156003c683c2341 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 25 Aug 2020 17:58:20 -0600 Subject: [PATCH] wrap disc grad --- codes/models/SRGAN_model.py | 2 +- codes/models/networks.py | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 9a141ab1..aa0fffb2 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -77,7 +77,7 @@ class SRGANModel(BaseModel): if self.is_train: self.netD = networks.define_D(opt).to(self.device) if self.spsr_enabled: - self.netD_grad = networks.define_D(opt).to(self.device) # D_grad + self.netD_grad = networks.define_D(opt, wrap=True).to(self.device) # D_grad if 'network_C' in opt.keys(): self.netC = networks.define_G(opt, net_key='network_C').to(self.device) diff --git a/codes/models/networks.py b/codes/models/networks.py index deb26144..19387325 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -140,7 +140,15 @@ def define_G(opt, net_key='network_G', scale=None): return netG -def define_D_net(opt_net, img_sz=None): +class GradDiscWrapper(torch.nn.Module): + def __init__(self, m): + super(GradDiscWrapper, self).__init__() + self.m = m + + def forward(self, x, lr): + return self.m(x, lr) + +def define_D_net(opt_net, img_sz=None, wrap=False): which_model = opt_net['which_model_D'] if which_model == 'discriminator_vgg_128': @@ -164,15 +172,17 @@ def define_D_net(opt_net, img_sz=None): final_temperature_step=opt_net['final_temperature_step']) elif which_model == "cross_compare_vgg128": netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'], nf=opt_net['nf'], scale=opt_net['scale']) + if wrap: + netD = GradDiscWrapper(netD) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD # Discriminator -def define_D(opt): +def define_D(opt, wrap=False): img_sz = opt['datasets']['train']['target_size'] opt_net = opt['network_D'] - return define_D_net(opt_net, img_sz) + return define_D_net(opt_net, img_sz, wrap=wrap) def define_fixed_D(opt): # Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.