diff --git a/codes/models/networks.py b/codes/models/networks.py index a6841346..2c232fbc 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -126,7 +126,7 @@ def define_G(opt, net_key='network_G', scale=None): elif which_model == 'chained_gen': netG = ChainedEmbeddingGen(depth=opt_net['depth']) elif which_model == 'chained_gen_structured': - netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth']) + netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False) elif which_model == 'chained_gen_structuredr2': netG = ChainedEmbeddingGenWithStructureR2(depth=opt_net['depth']) elif which_model == "flownet2":