From 680d63542047805700b1dff4a3fb2e8932aad953 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 21 Oct 2020 22:22:28 -0600 Subject: [PATCH] Enable ExtensibleTrainer to skip steps when state keys are missing --- codes/models/ExtensibleTrainer.py | 8 ++++++++ codes/models/steps/losses.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 32b627d3..0ef0afd5 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -187,6 +187,14 @@ class ExtensibleTrainer(BaseModel): # Steps can opt out of early (or late) training, make sure that happens here. if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']: continue + # Steps can choose to not execute if a state key is missing. + if 'requires' in s.step_opt.keys(): + requirements_met = True + for requirement in s.step_opt['requires']: + if requirement not in state.keys(): + requirements_met = False + if not requirements_met: + continue # Only set requires_grad=True for the network being trained. nets_to_train = s.get_networks_trained() diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index c5419a78..9f0093e4 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -405,7 +405,7 @@ class RecurrentLoss(ConfigurableLoss): st['_real'] = real[:, i] st['_fake'] = state[self.opt['fake']][:, i] subloss = self.loss(net, st) - if isinstance(self.recurrent_weights, list); + if isinstance(self.recurrent_weights, list): subloss = subloss * self.recurrent_weights[i] total_loss += subloss return total_loss