Allow recurrence to specified for chainedgen
This commit is contained in:
parent
fc4c064867
commit
cf8118a85b
|
@ -126,7 +126,7 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
elif which_model == 'chained_gen':
|
elif which_model == 'chained_gen':
|
||||||
netG = ChainedEmbeddingGen(depth=opt_net['depth'])
|
netG = ChainedEmbeddingGen(depth=opt_net['depth'])
|
||||||
elif which_model == 'chained_gen_structured':
|
elif which_model == 'chained_gen_structured':
|
||||||
netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'])
|
netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False)
|
||||||
elif which_model == 'chained_gen_structuredr2':
|
elif which_model == 'chained_gen_structuredr2':
|
||||||
netG = ChainedEmbeddingGenWithStructureR2(depth=opt_net['depth'])
|
netG = ChainedEmbeddingGenWithStructureR2(depth=opt_net['depth'])
|
||||||
elif which_model == "flownet2":
|
elif which_model == "flownet2":
|
||||||
|
|
Loading…
Reference in New Issue
Block a user