diff --git a/codes/models/networks.py b/codes/models/networks.py index 44ef9a94..72ebd8b1 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -39,16 +39,18 @@ def define_G(opt, net_key='network_G'): init_temperature=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'])) elif which_model == 'LowDimRRDBNet': + gen_scale = scale * opt_net['initial_stride'] rrdb = functools.partial(RRDBNet_arch.LowDimRRDB, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim']) netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, rrdb_block_f=rrdb) + nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride']) elif which_model == "LowDimRRDBWithMultiHeadSwitching": + gen_scale = scale * opt_net['initial_stride'] switcher = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=opt_net['num_convs'], num_heads=opt_net['num_heads'], init_temperature=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step']) rrdb = functools.partial(RRDBNet_arch.LowDimRRDBWrapper, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim'], partial_rrdb=switcher) netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, rrdb_block_f=rrdb) + nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride']) elif which_model == 'PixRRDBNet': block_f = None if opt_net['attention']: