Add 'before' and 'after' defs to injections, steps, and optimizers

James Betker 2020-09-22 17:03:22 -06:00
parent 419f77ec19
commit f40beb5460
3 changed files with 18 additions and 2 deletions


@@ -180,6 +180,9 @@ class ExtensibleTrainer(BaseModel):
             # Skip steps if mod_step doesn't line up.
             if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0:
                 continue
+            # Steps can opt out of early (or late) training; make sure that happens here.
+            if 'after' in s.opt.keys() and step < s.opt['after'] or 'before' in s.opt.keys() and step > s.opt['before']:
+                continue
             # Only set requires_grad=True for the network being trained.
             nets_to_train = s.get_networks_trained()
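
For context, a minimal sketch of the gating rule this hunk adds: 'after' is a lower bound on the global step and 'before' an upper bound, with either key optional. The step names and thresholds below are hypothetical, not taken from the repo's configs.

step_opts = [
    {'name': 'generator'},                     # no gates: always runs
    {'name': 'discriminator', 'after': 1000},  # skipped until step 1000
    {'name': 'teacher', 'before': 50000},      # skipped after step 50000
]

def step_enabled(opt, step):
    # Mirrors the new check: skip when before 'after' or past 'before'.
    if 'after' in opt and step < opt['after']:
        return False
    if 'before' in opt and step > opt['before']:
        return False
    return True

for step in (0, 1000, 60000):
    print(step, [o['name'] for o in step_opts if step_enabled(o, step)])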
@@ -226,6 +229,8 @@ class ExtensibleTrainer(BaseModel):
         if 'visuals' in self.opt['logger'].keys():
             sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
             for v in self.opt['logger']['visuals']:
+                if v not in state.keys():
+                    continue  # This can happen for several reasons (e.g. 'after' defs); just ignore it.
                 if step % self.opt['logger']['visual_debug_rate'] == 0:
                     for i, dbgv in enumerate(state[v]):
                         if dbgv.shape[1] > 3:
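
The guard matters because a step gated by 'after' never runs early in training, so the visuals it would have produced are simply absent from `state`. A small illustration, with made-up key names:

# 'fake_H' is produced by a step that has not started yet, so the logger
# must skip it rather than raise a KeyError.
state = {'lq': ['tensor'], 'hq': ['tensor']}
for v in ['lq', 'hq', 'fake_H']:
    if v not in state.keys():
        continue  # same guard as the diff adds
    print(v, state[v])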


@@ -344,8 +344,7 @@ class SSGNoEmbedding(nn.Module):
         x_grad = self.get_g_nopadding(x)
         x = self.model_fea_conv(x)
-        x1 = x
-        x1, a1 = self.sw1(x1, True, identity=x)
+        x1, a1 = self.sw1(x, True)
         x_grad = self.grad_conv(x_grad)
         x_grad_identity = x_grad


@@ -76,6 +76,7 @@ class ConfigurableStep(Module):
         elif self.step_opt['optimizer'] == 'novograd':
             opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
                            betas=(opt_config['beta1'], opt_config['beta2']))
+        opt._config = opt_config  # This is a bit seedy, but we will need these configs later.
         self.optimizers.append(opt)

     # Returns all optimizers used in this step.
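
Stashing the raw config on the optimizer is plain attribute assignment, so `do_step()` can read the gates back later without extra bookkeeping. A hedged sketch of the pattern; SGD and the config values stand in for the real optimizer and step config:

import torch

opt_config = {'lr': 1e-4, 'after': 1000}
params = [torch.nn.Parameter(torch.zeros(3))]
opt = torch.optim.SGD(params, lr=opt_config['lr'])
opt._config = opt_config  # plain Python attribute, nothing torch-specific
assert opt._config.get('after', 0) == 1000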
@@ -116,6 +117,10 @@ class ConfigurableStep(Module):
             # Likewise, don't do injections tagged with train unless we are not in eval.
             if not train and 'train' in inj.opt.keys() and inj.opt['train']:
                 continue
+            # Don't do injections tagged with 'after' or 'before' when we are out of spec.
+            if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
+                    'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
+                continue
             injected = inj(local_state)
             local_state.update(injected)
             new_state.update(injected)
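
One subtlety in the compound condition: Python's `and` binds tighter than `or`, so it parses as (has 'after' and too early) or (has 'before' and too late), which is the intended reading. A quick check with placeholder values:

step = 5
opt = {'after': 10}
skip = 'after' in opt and step < opt['after'] or \
       'before' in opt and step > opt['before']
print(skip)  # True: step 5 is earlier than 'after' = 10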
@@ -155,6 +160,13 @@ class ConfigurableStep(Module):
     # all self.optimizers.
     def do_step(self):
         for opt in self.optimizers:
+            # Optimizers can be opted out of during the early (or late) stages of training.
+            after = opt._config['after'] if 'after' in opt._config.keys() else 0
+            if self.env['step'] < after:
+                continue
+            before = opt._config['before'] if 'before' in opt._config.keys() else -1
+            if before != -1 and self.env['step'] > before:
+                continue
             opt.step()

     def get_metrics(self):
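
The defaults make this gate a no-op when neither key is set: 'after' falls back to 0 (eligible from the first step) and 'before' to -1, which is treated as "no upper bound". A standalone restatement of the same logic, with example config values:

def optimizer_enabled(config, step):
    after = config.get('after', 0)     # default: eligible from step 0
    if step < after:
        return False
    before = config.get('before', -1)  # default -1: no upper bound
    if before != -1 and step > before:
        return False
    return True

print(optimizer_enabled({}, 0))                 # True: no gates set
print(optimizer_enabled({'after': 100}, 50))    # False: too early
print(optimizer_enabled({'before': 200}, 500))  # False: too late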