Enable ExtensibleTrainer to skip steps when state keys are missing

This commit is contained in:
James Betker 2020-10-21 22:22:28 -06:00
parent d1175f0de1
commit 680d635420
2 changed files with 9 additions and 1 deletions

View File

@ -187,6 +187,14 @@ class ExtensibleTrainer(BaseModel):
# Steps can opt out of early (or late) training, make sure that happens here. # Steps can opt out of early (or late) training, make sure that happens here.
if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']: if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
continue continue
# Steps can choose to not execute if a state key is missing.
if 'requires' in s.step_opt.keys():
requirements_met = True
for requirement in s.step_opt['requires']:
if requirement not in state.keys():
requirements_met = False
if not requirements_met:
continue
# Only set requires_grad=True for the network being trained. # Only set requires_grad=True for the network being trained.
nets_to_train = s.get_networks_trained() nets_to_train = s.get_networks_trained()

View File

@ -405,7 +405,7 @@ class RecurrentLoss(ConfigurableLoss):
st['_real'] = real[:, i] st['_real'] = real[:, i]
st['_fake'] = state[self.opt['fake']][:, i] st['_fake'] = state[self.opt['fake']][:, i]
subloss = self.loss(net, st) subloss = self.loss(net, st)
if isinstance(self.recurrent_weights, list); if isinstance(self.recurrent_weights, list):
subloss = subloss * self.recurrent_weights[i] subloss = subloss * self.recurrent_weights[i]
total_loss += subloss total_loss += subloss
return total_loss return total_loss