When skipping steps via "every", still run nontrainable injection points

James Betker 2020-11-10 16:09:17 -07:00
parent 91d27372e4
commit b742d1e5a5
2 changed files with 31 additions and 26 deletions
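In short: a step configured with 'every: N' was previously skipped outright via `continue` on off-iterations, which also skipped its injection points. Because later steps may depend on the state those injectors produce, the trainer now records `train_step = False` instead and still executes the forward pass, suppressing only gradient setup, the backward pass, and the optimizer step.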


@@ -172,9 +172,11 @@ class ExtensibleTrainer(BaseModel):
         # Iterate through the steps, performing them one at a time.
         state = self.dstate
         for step_num, s in enumerate(self.steps):
+            train_step = True
             # 'every' is used to denote steps that should only occur at a certain integer factor rate. e.g. '2' occurs every 2 steps.
+            # Note that the injection points for the step might still be required, so address this by setting train_step=False
             if 'every' in s.step_opt.keys() and step % s.step_opt['every'] != 0:
-                continue
+                train_step = False
             # Steps can opt out of early (or late) training, make sure that happens here.
             if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
                 continue
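Distilled from the hunk above, the gating logic amounts to the following pattern (a minimal sketch; only the `step_opt` keys are taken from the diff, the function itself is illustrative):

def step_gating(step_opt, step):
    """Return (run, train) for a step at global iteration `step`.

    'after'/'before' gate execution entirely; 'every' merely downgrades
    the step to a non-training pass so its injection points still run.
    """
    train = True
    if 'every' in step_opt and step % step_opt['every'] != 0:
        train = False  # was an outright `continue` before this commit
    if 'after' in step_opt and step < step_opt['after']:
        return False, False
    if 'before' in step_opt and step > step_opt['before']:
        return False, False
    return True, train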
@@ -187,6 +189,7 @@ class ExtensibleTrainer(BaseModel):
                 if not requirements_met:
                     continue
 
-            # Only set requires_grad=True for the network being trained.
-            nets_to_train = s.get_networks_trained()
-            enabled = 0
+            if train_step:
+                # Only set requires_grad=True for the network being trained.
+                nets_to_train = s.get_networks_trained()
+                enabled = 0
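The `nets_to_train` block (truncated in this view) toggles `requires_grad` so that only the step's own networks accumulate gradients; a minimal sketch of that pattern, with illustrative names:

def set_trainable(all_networks, nets_to_train):
    """Enable gradients only for the networks a step actually trains."""
    enabled = 0
    for name, net in all_networks.items():
        trainable = name in nets_to_train
        if trainable:
            enabled += 1
        for p in net.parameters():
            p.requires_grad = trainable
    return enabled  # counter of enabled networks, as in the trainer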
@@ -213,7 +216,7 @@ class ExtensibleTrainer(BaseModel):
             # Now do a forward and backward pass for each gradient accumulation step.
             new_states = {}
             for m in range(self.mega_batch_factor):
-                ns = s.do_forward_backward(state, m, step_num)
+                ns = s.do_forward_backward(state, m, step_num, train=train_step)
                 for k, v in ns.items():
                     if k not in new_states.keys():
                         new_states[k] = [v]
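The diff threads `train=train_step` into `do_forward_backward`, whose body is not part of this commit view. One plausible shape for honoring that flag (purely illustrative; the attribute and loss names here are assumptions, not the repo's code):

import torch

def do_forward_backward(self, state, grad_accum_step, step_num, train=True):
    """Sketch of a train-aware step body (illustrative only)."""
    local_state = dict(state)
    # Injection points always run so later steps can consume their
    # outputs, even on iterations skipped by 'every'.
    with torch.set_grad_enabled(train):
        for inj in self.injectors:  # hypothetical attribute
            local_state.update(inj(local_state))
    if train:
        # Losses and the backward pass only apply on training passes.
        total_loss = sum(loss(local_state) for loss in self.losses)  # hypothetical
        (total_loss / self.mega_batch_factor).backward()
    # Return only newly injected keys; the trainer asserts returned
    # keys are not already present in the global state.
    return {k: v for k, v in local_state.items() if k not in state}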
@@ -226,6 +229,7 @@ class ExtensibleTrainer(BaseModel):
                 assert k not in state.keys()
                 state[k] = v
 
-            # And finally perform optimization.
-            [e.before_optimize(state) for e in self.experiments]
-            s.do_step(step)
+            if train_step:
+                # And finally perform optimization.
+                [e.before_optimize(state) for e in self.experiments]
+                s.do_step(step)
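Note the symmetry with the first hunk: a non-training pass still flows through the forward/state-collection loop above, but the `before_optimize` experiment hooks and `s.do_step(step)` never fire, so optimizer and experiment state are untouched on skipped iterations while the injected tensors remain in `state` for later steps.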


@@ -204,6 +204,7 @@ class GeneratorGanLoss(ConfigurableLoss):
             pred_d_real = pred_d_real.detach()
             pred_g_fake = netD(*fake)
             d_fake_diff = pred_g_fake - torch.mean(pred_d_real)
+            self.metrics.append(("d_fake", torch.mean(pred_g_fake)))
             self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
             loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.criterion(d_fake_diff, True)) / 2
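The surrounding loss is the relativistic-average (RaGAN) generator objective: the generator is pushed to make fake logits exceed the mean real logit, and real logits fall below the mean fake logit. The added metric logs the raw mean fake logit next to the existing relative difference. A self-contained rendering, assuming `criterion` behaves like BCE-with-logits keyed on a target-is-real flag (an assumption; the repo's criterion is not shown in this diff):

import torch
import torch.nn.functional as F

def ragan_generator_loss(pred_d_real, pred_g_fake):
    """Relativistic-average GAN generator loss, mirroring the hunk above."""
    def criterion(logits, target_is_real):
        # Modeled as BCE-with-logits against a constant 0/1 target.
        target = torch.full_like(logits, 1.0 if target_is_real else 0.0)
        return F.binary_cross_entropy_with_logits(logits, target)

    pred_d_real = pred_d_real.detach()  # no gradients through the real pass
    d_fake_diff = pred_g_fake - torch.mean(pred_d_real)
    metrics = {
        "d_fake": torch.mean(pred_g_fake),       # newly logged raw fake logit
        "d_fake_diff": torch.mean(d_fake_diff),  # fake logit relative to mean real
    }
    loss = (criterion(pred_d_real - torch.mean(pred_g_fake), False) +
            criterion(d_fake_diff, True)) / 2
    return loss, metrics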