forked from mrq/DL-Art-School
wrap disc grad
This commit is contained in:
parent
f85f1e21db
commit
bae18c05e6
|
@ -77,7 +77,7 @@ class SRGANModel(BaseModel):
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
self.netD = networks.define_D(opt).to(self.device)
|
self.netD = networks.define_D(opt).to(self.device)
|
||||||
if self.spsr_enabled:
|
if self.spsr_enabled:
|
||||||
self.netD_grad = networks.define_D(opt).to(self.device) # D_grad
|
self.netD_grad = networks.define_D(opt, wrap=True).to(self.device) # D_grad
|
||||||
|
|
||||||
if 'network_C' in opt.keys():
|
if 'network_C' in opt.keys():
|
||||||
self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
|
self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
|
||||||
|
|
|
@ -140,7 +140,15 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
return netG
|
return netG
|
||||||
|
|
||||||
|
|
||||||
def define_D_net(opt_net, img_sz=None):
|
class GradDiscWrapper(torch.nn.Module):
|
||||||
|
def __init__(self, m):
|
||||||
|
super(GradDiscWrapper, self).__init__()
|
||||||
|
self.m = m
|
||||||
|
|
||||||
|
def forward(self, x, lr):
|
||||||
|
return self.m(x, lr)
|
||||||
|
|
||||||
|
def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||||
which_model = opt_net['which_model_D']
|
which_model = opt_net['which_model_D']
|
||||||
|
|
||||||
if which_model == 'discriminator_vgg_128':
|
if which_model == 'discriminator_vgg_128':
|
||||||
|
@ -164,15 +172,17 @@ def define_D_net(opt_net, img_sz=None):
|
||||||
final_temperature_step=opt_net['final_temperature_step'])
|
final_temperature_step=opt_net['final_temperature_step'])
|
||||||
elif which_model == "cross_compare_vgg128":
|
elif which_model == "cross_compare_vgg128":
|
||||||
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'], nf=opt_net['nf'], scale=opt_net['scale'])
|
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'], nf=opt_net['nf'], scale=opt_net['scale'])
|
||||||
|
if wrap:
|
||||||
|
netD = GradDiscWrapper(netD)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||||
return netD
|
return netD
|
||||||
|
|
||||||
# Discriminator
|
# Discriminator
|
||||||
def define_D(opt):
|
def define_D(opt, wrap=False):
|
||||||
img_sz = opt['datasets']['train']['target_size']
|
img_sz = opt['datasets']['train']['target_size']
|
||||||
opt_net = opt['network_D']
|
opt_net = opt['network_D']
|
||||||
return define_D_net(opt_net, img_sz)
|
return define_D_net(opt_net, img_sz, wrap=wrap)
|
||||||
|
|
||||||
def define_fixed_D(opt):
|
def define_fixed_D(opt):
|
||||||
# Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.
|
# Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user