From f40beb5460c741227f67689f24df42c3a03c3a81 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 22 Sep 2020 17:03:22 -0600
Subject: [PATCH] Add 'before' and 'after' defs to injections, steps and optimizers

---
 codes/models/ExtensibleTrainer.py                 |  5 +++++
 codes/models/archs/StructuredSwitchedGenerator.py |  3 +--
 codes/models/steps/steps.py                       | 12 ++++++++++++
 3 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 3a56cdaf..cc189a11 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -180,6 +180,9 @@ class ExtensibleTrainer(BaseModel):
             # Skip steps if mod_step doesn't line up.
             if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0:
                 continue
+            # Steps can opt out of early (or late) training, make sure that happens here.
+            if 'after' in s.opt.keys() and step < s.opt['after'] or 'before' in s.opt.keys() and step > s.opt['before']:
+                continue
 
             # Only set requires_grad=True for the network being trained.
             nets_to_train = s.get_networks_trained()
@@ -226,6 +229,8 @@ class ExtensibleTrainer(BaseModel):
         if 'visuals' in self.opt['logger'].keys():
             sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
             for v in self.opt['logger']['visuals']:
+                if v not in state.keys():
+                    continue  # This can happen for several reasons (ex: 'after' defs), just ignore it.
                 if step % self.opt['logger']['visual_debug_rate'] == 0:
                     for i, dbgv in enumerate(state[v]):
                         if dbgv.shape[1] > 3:
diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py
index 714df7a6..2cb36363 100644
--- a/codes/models/archs/StructuredSwitchedGenerator.py
+++ b/codes/models/archs/StructuredSwitchedGenerator.py
@@ -344,8 +344,7 @@ class SSGNoEmbedding(nn.Module):
         x_grad = self.get_g_nopadding(x)
         x = self.model_fea_conv(x)
 
-        x1 = x
-        x1, a1 = self.sw1(x1, True, identity=x)
+        x1, a1 = self.sw1(x, True)
 
         x_grad = self.grad_conv(x_grad)
         x_grad_identity = x_grad
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index c8c7de38..84d635b9 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -76,6 +76,7 @@ class ConfigurableStep(Module):
         elif self.step_opt['optimizer'] == 'novograd':
             opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
                            betas=(opt_config['beta1'], opt_config['beta2']))
+        opt._config = opt_config  # This is a bit seedy, but we will need these configs later.
         self.optimizers.append(opt)
 
     # Returns all optimizers used in this step.
@@ -116,6 +117,10 @@ class ConfigurableStep(Module):
             # Likewise, don't do injections tagged with train unless we are not in eval.
             if not train and 'train' in inj.opt.keys() and inj.opt['train']:
                 continue
+            # Don't do injections tagged with 'after' or 'before' when we are out of spec.
+            if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
+               'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
+                continue
             injected = inj(local_state)
             local_state.update(injected)
             new_state.update(injected)
@@ -155,6 +160,13 @@ class ConfigurableStep(Module):
     # all self.optimizers.
     def do_step(self):
         for opt in self.optimizers:
+            # Optimizers can be opted out in the early stages of training.
+            after = opt._config['after'] if 'after' in opt._config.keys() else 0
+            if self.env['step'] < after:
+                continue
+            before = opt._config['before'] if 'before' in opt._config.keys() else -1
+            if before != -1 and self.env['step'] > before:
+                continue
            opt.step()
 
     def get_metrics(self):
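
Note (not part of the patch): a minimal standalone sketch of the gating behavior these hunks introduce. The helper name in_training_window is invented for illustration; the real checks are inlined in ExtensibleTrainer.py and steps.py as shown above.

    def in_training_window(step, opt):
        # A component with an 'after' key is skipped until that iteration is reached...
        if 'after' in opt and step < opt['after']:
            return False
        # ...and one with a 'before' key is skipped once training has passed it.
        if 'before' in opt and step > opt['before']:
            return False
        return True

    # Example: in_training_window(500, {'after': 1000})                  -> False
    #          in_training_window(2000, {'after': 1000, 'before': 5000}) -> True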