forked from mrq/DL-Art-School
Enable ExtensibleTrainer to skip steps when state keys are missing
This commit is contained in:
parent
d1175f0de1
commit
680d635420
|
@ -187,6 +187,14 @@ class ExtensibleTrainer(BaseModel):
|
|||
# Steps can opt out of early (or late) training, make sure that happens here.
|
||||
if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
|
||||
continue
|
||||
# Steps can choose to not execute if a state key is missing.
|
||||
if 'requires' in s.step_opt.keys():
|
||||
requirements_met = True
|
||||
for requirement in s.step_opt['requires']:
|
||||
if requirement not in state.keys():
|
||||
requirements_met = False
|
||||
if not requirements_met:
|
||||
continue
|
||||
|
||||
# Only set requires_grad=True for the network being trained.
|
||||
nets_to_train = s.get_networks_trained()
|
||||
|
|
|
@ -405,7 +405,7 @@ class RecurrentLoss(ConfigurableLoss):
|
|||
st['_real'] = real[:, i]
|
||||
st['_fake'] = state[self.opt['fake']][:, i]
|
||||
subloss = self.loss(net, st)
|
||||
if isinstance(self.recurrent_weights, list);
|
||||
if isinstance(self.recurrent_weights, list):
|
||||
subloss = subloss * self.recurrent_weights[i]
|
||||
total_loss += subloss
|
||||
return total_loss
|
||||
|
|
Loading…
Reference in New Issue
Block a user