wrap disc grad

This commit is contained in:
James Betker 2020-08-25 17:58:20 -06:00
parent f85f1e21db
commit bae18c05e6
2 changed files with 14 additions and 4 deletions

View File

@ -77,7 +77,7 @@ class SRGANModel(BaseModel):
if self.is_train:
self.netD = networks.define_D(opt).to(self.device)
if self.spsr_enabled:
self.netD_grad = networks.define_D(opt).to(self.device) # D_grad
self.netD_grad = networks.define_D(opt, wrap=True).to(self.device) # D_grad
if 'network_C' in opt.keys():
self.netC = networks.define_G(opt, net_key='network_C').to(self.device)

View File

@ -140,7 +140,15 @@ def define_G(opt, net_key='network_G', scale=None):
return netG
def define_D_net(opt_net, img_sz=None):
class GradDiscWrapper(torch.nn.Module):
def __init__(self, m):
super(GradDiscWrapper, self).__init__()
self.m = m
def forward(self, x, lr):
return self.m(x, lr)
def define_D_net(opt_net, img_sz=None, wrap=False):
which_model = opt_net['which_model_D']
if which_model == 'discriminator_vgg_128':
@ -164,15 +172,17 @@ def define_D_net(opt_net, img_sz=None):
final_temperature_step=opt_net['final_temperature_step'])
elif which_model == "cross_compare_vgg128":
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'], nf=opt_net['nf'], scale=opt_net['scale'])
if wrap:
netD = GradDiscWrapper(netD)
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD
# Discriminator
def define_D(opt):
def define_D(opt, wrap=False):
img_sz = opt['datasets']['train']['target_size']
opt_net = opt['network_D']
return define_D_net(opt_net, img_sz)
return define_D_net(opt_net, img_sz, wrap=wrap)
def define_fixed_D(opt):
# Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.