From b1238d29cb8bbddb9fb42d211f413b3bad6cea9b Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 5 Sep 2020 20:31:26 -0600 Subject: [PATCH] Fix trainable not applying to discriminators --- codes/models/ExtensibleTrainer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 8cdc77f1..001082a9 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -39,12 +39,12 @@ class ExtensibleTrainer(BaseModel): self.netsD = {} self.netF = networks.define_F().to(self.device) # Used to compute feature loss. for name, net in opt['networks'].items(): + # Trainable is a required parameter, but the default is simply true. Set it here. + if 'trainable' not in net.keys(): + net['trainable'] = True + if net['type'] == 'generator': new_net = networks.define_G(net, None, opt['scale']).to(self.device) - if 'trainable' not in net.keys(): - net['trainable'] = True - if not net['trainable']: - new_net.eval() self.netsG[name] = new_net elif net['type'] == 'discriminator': new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device) @@ -52,6 +52,9 @@ class ExtensibleTrainer(BaseModel): else: raise NotImplementedError("Can only handle generators and discriminators") + if not net['trainable']: + new_net.eval() + # Initialize the train/eval steps self.steps = [] for step_name, step in opt['steps'].items():