forked from mrq/DL-Art-School
Fix initial_stride on lowdim models
This commit is contained in:
parent
532704af40
commit
48532a0a8a
|
@ -39,16 +39,18 @@ def define_G(opt, net_key='network_G'):
|
|||
init_temperature=opt_net['temperature'],
|
||||
final_temperature_step=opt_net['temperature_final_step']))
|
||||
elif which_model == 'LowDimRRDBNet':
|
||||
gen_scale = scale * opt_net['initial_stride']
|
||||
rrdb = functools.partial(RRDBNet_arch.LowDimRRDB, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim'])
|
||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, rrdb_block_f=rrdb)
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride'])
|
||||
elif which_model == "LowDimRRDBWithMultiHeadSwitching":
|
||||
gen_scale = scale * opt_net['initial_stride']
|
||||
switcher = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=opt_net['num_convs'], num_heads=opt_net['num_heads'],
|
||||
init_temperature=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'])
|
||||
rrdb = functools.partial(RRDBNet_arch.LowDimRRDBWrapper, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim'],
|
||||
partial_rrdb=switcher)
|
||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, rrdb_block_f=rrdb)
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride'])
|
||||
elif which_model == 'PixRRDBNet':
|
||||
block_f = None
|
||||
if opt_net['attention']:
|
||||
|
|
Loading…
Reference in New Issue
Block a user