Allow recurrence to specified for chainedgen

This commit is contained in:
James Betker 2020-10-17 08:32:29 -06:00
parent fc4c064867
commit cf8118a85b

View File

@ -126,7 +126,7 @@ def define_G(opt, net_key='network_G', scale=None):
elif which_model == 'chained_gen':
netG = ChainedEmbeddingGen(depth=opt_net['depth'])
elif which_model == 'chained_gen_structured':
netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'])
netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False)
elif which_model == 'chained_gen_structuredr2':
netG = ChainedEmbeddingGenWithStructureR2(depth=opt_net['depth'])
elif which_model == "flownet2":