forked from mrq/DL-Art-School
Add 'before' and 'after' defs to injections, steps and optimizers
This commit is contained in:
parent
419f77ec19
commit
f40beb5460
|
@ -180,6 +180,9 @@ class ExtensibleTrainer(BaseModel):
|
|||
# Skip steps if mod_step doesn't line up.
|
||||
if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0:
|
||||
continue
|
||||
# Steps can opt out of early (or late) training, make sure that happens here.
|
||||
if 'after' in s.opt.keys() and step < s.opt['after'] or 'before' in s.opt.keys() and step > s.opt['before']:
|
||||
continue
|
||||
|
||||
# Only set requires_grad=True for the network being trained.
|
||||
nets_to_train = s.get_networks_trained()
|
||||
|
@ -226,6 +229,8 @@ class ExtensibleTrainer(BaseModel):
|
|||
if 'visuals' in self.opt['logger'].keys():
|
||||
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
|
||||
for v in self.opt['logger']['visuals']:
|
||||
if v not in state.keys():
|
||||
continue # This can happen for several reasons (ex: 'after' defs), just ignore it.
|
||||
if step % self.opt['logger']['visual_debug_rate'] == 0:
|
||||
for i, dbgv in enumerate(state[v]):
|
||||
if dbgv.shape[1] > 3:
|
||||
|
|
|
@ -344,8 +344,7 @@ class SSGNoEmbedding(nn.Module):
|
|||
x_grad = self.get_g_nopadding(x)
|
||||
|
||||
x = self.model_fea_conv(x)
|
||||
x1 = x
|
||||
x1, a1 = self.sw1(x1, True, identity=x)
|
||||
x1, a1 = self.sw1(x, True)
|
||||
|
||||
x_grad = self.grad_conv(x_grad)
|
||||
x_grad_identity = x_grad
|
||||
|
|
|
@ -76,6 +76,7 @@ class ConfigurableStep(Module):
|
|||
elif self.step_opt['optimizer'] == 'novograd':
|
||||
opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
|
||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||
opt._config = opt_config # This is a bit seedy, but we will need these configs later.
|
||||
self.optimizers.append(opt)
|
||||
|
||||
# Returns all optimizers used in this step.
|
||||
|
@ -116,6 +117,10 @@ class ConfigurableStep(Module):
|
|||
# Likewise, don't do injections tagged with train unless we are not in eval.
|
||||
if not train and 'train' in inj.opt.keys() and inj.opt['train']:
|
||||
continue
|
||||
# Don't do injections tagged with 'after' or 'before' when we are out of spec.
|
||||
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
|
||||
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
|
||||
continue
|
||||
injected = inj(local_state)
|
||||
local_state.update(injected)
|
||||
new_state.update(injected)
|
||||
|
@ -155,6 +160,13 @@ class ConfigurableStep(Module):
|
|||
# all self.optimizers.
|
||||
def do_step(self):
|
||||
for opt in self.optimizers:
|
||||
# Optimizers can be opted out in the early stages of training.
|
||||
after = opt._config['after'] if 'after' in opt._config.keys() else 0
|
||||
if self.env['step'] < after:
|
||||
continue
|
||||
before = opt._config['before'] if 'before' in opt._config.keys() else -1
|
||||
if before != -1 and self.env['step'] > before:
|
||||
continue
|
||||
opt.step()
|
||||
|
||||
def get_metrics(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user