wrap disc grad
This commit is contained in:
parent
f85f1e21db
commit
bae18c05e6
|
@ -77,7 +77,7 @@ class SRGANModel(BaseModel):
|
|||
if self.is_train:
|
||||
self.netD = networks.define_D(opt).to(self.device)
|
||||
if self.spsr_enabled:
|
||||
self.netD_grad = networks.define_D(opt).to(self.device) # D_grad
|
||||
self.netD_grad = networks.define_D(opt, wrap=True).to(self.device) # D_grad
|
||||
|
||||
if 'network_C' in opt.keys():
|
||||
self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
|
||||
|
|
|
@ -140,7 +140,15 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
return netG
|
||||
|
||||
|
||||
def define_D_net(opt_net, img_sz=None):
|
||||
class GradDiscWrapper(torch.nn.Module):
|
||||
def __init__(self, m):
|
||||
super(GradDiscWrapper, self).__init__()
|
||||
self.m = m
|
||||
|
||||
def forward(self, x, lr):
|
||||
return self.m(x, lr)
|
||||
|
||||
def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||
which_model = opt_net['which_model_D']
|
||||
|
||||
if which_model == 'discriminator_vgg_128':
|
||||
|
@ -164,15 +172,17 @@ def define_D_net(opt_net, img_sz=None):
|
|||
final_temperature_step=opt_net['final_temperature_step'])
|
||||
elif which_model == "cross_compare_vgg128":
|
||||
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'], nf=opt_net['nf'], scale=opt_net['scale'])
|
||||
if wrap:
|
||||
netD = GradDiscWrapper(netD)
|
||||
else:
|
||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||
return netD
|
||||
|
||||
# Discriminator
|
||||
def define_D(opt):
|
||||
def define_D(opt, wrap=False):
|
||||
img_sz = opt['datasets']['train']['target_size']
|
||||
opt_net = opt['network_D']
|
||||
return define_D_net(opt_net, img_sz)
|
||||
return define_D_net(opt_net, img_sz, wrap=wrap)
|
||||
|
||||
def define_fixed_D(opt):
|
||||
# Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.
|
||||
|
|
Loading…
Reference in New Issue
Block a user