Fix initial_stride on lowdim models

This commit is contained in:
James Betker 2020-06-14 11:02:16 -06:00
parent 532704af40
commit 48532a0a8a

View File

@ -39,16 +39,18 @@ def define_G(opt, net_key='network_G'):
init_temperature=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step']))
elif which_model == 'LowDimRRDBNet':
gen_scale = scale * opt_net['initial_stride']
rrdb = functools.partial(RRDBNet_arch.LowDimRRDB, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim'])
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, rrdb_block_f=rrdb)
nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride'])
elif which_model == "LowDimRRDBWithMultiHeadSwitching":
gen_scale = scale * opt_net['initial_stride']
switcher = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=opt_net['num_convs'], num_heads=opt_net['num_heads'],
init_temperature=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'])
rrdb = functools.partial(RRDBNet_arch.LowDimRRDBWrapper, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim'],
partial_rrdb=switcher)
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, rrdb_block_f=rrdb)
nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride'])
elif which_model == 'PixRRDBNet':
block_f = None
if opt_net['attention']: