From 48532a0a8a3eda63e2d4b1b3249202365eaf0fea Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 14 Jun 2020 11:02:16 -0600 Subject: [PATCH] Fix initial_stride on lowdim models --- codes/models/networks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/codes/models/networks.py b/codes/models/networks.py index 44ef9a94..72ebd8b1 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -39,16 +39,18 @@ def define_G(opt, net_key='network_G'): init_temperature=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'])) elif which_model == 'LowDimRRDBNet': + gen_scale = scale * opt_net['initial_stride'] rrdb = functools.partial(RRDBNet_arch.LowDimRRDB, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim']) netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, rrdb_block_f=rrdb) + nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride']) elif which_model == "LowDimRRDBWithMultiHeadSwitching": + gen_scale = scale * opt_net['initial_stride'] switcher = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=opt_net['num_convs'], num_heads=opt_net['num_heads'], init_temperature=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step']) rrdb = functools.partial(RRDBNet_arch.LowDimRRDBWrapper, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim'], partial_rrdb=switcher) netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, rrdb_block_f=rrdb) + nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride']) elif which_model == 'PixRRDBNet': block_f = None if opt_net['attention']: