forked from mrq/DL-Art-School
Enable ExtensibleTrainer to skip steps when state keys are missing
This commit is contained in:
parent
d1175f0de1
commit
680d635420
|
@ -187,6 +187,14 @@ class ExtensibleTrainer(BaseModel):
|
||||||
# Steps can opt out of early (or late) training, make sure that happens here.
|
# Steps can opt out of early (or late) training, make sure that happens here.
|
||||||
if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
|
if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
|
||||||
continue
|
continue
|
||||||
|
# Steps can choose to not execute if a state key is missing.
|
||||||
|
if 'requires' in s.step_opt.keys():
|
||||||
|
requirements_met = True
|
||||||
|
for requirement in s.step_opt['requires']:
|
||||||
|
if requirement not in state.keys():
|
||||||
|
requirements_met = False
|
||||||
|
if not requirements_met:
|
||||||
|
continue
|
||||||
|
|
||||||
# Only set requires_grad=True for the network being trained.
|
# Only set requires_grad=True for the network being trained.
|
||||||
nets_to_train = s.get_networks_trained()
|
nets_to_train = s.get_networks_trained()
|
||||||
|
|
|
@ -405,7 +405,7 @@ class RecurrentLoss(ConfigurableLoss):
|
||||||
st['_real'] = real[:, i]
|
st['_real'] = real[:, i]
|
||||||
st['_fake'] = state[self.opt['fake']][:, i]
|
st['_fake'] = state[self.opt['fake']][:, i]
|
||||||
subloss = self.loss(net, st)
|
subloss = self.loss(net, st)
|
||||||
if isinstance(self.recurrent_weights, list);
|
if isinstance(self.recurrent_weights, list):
|
||||||
subloss = subloss * self.recurrent_weights[i]
|
subloss = subloss * self.recurrent_weights[i]
|
||||||
total_loss += subloss
|
total_loss += subloss
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|
Loading…
Reference in New Issue
Block a user