forked from mrq/DL-Art-School
When a step is skipped via its "every" option, still run that step's non-trainable injection points
This commit is contained in:
parent
91d27372e4
commit
b742d1e5a5
|
@ -172,9 +172,11 @@ class ExtensibleTrainer(BaseModel):
|
||||||
# Iterate through the steps, performing them one at a time.
|
# Iterate through the steps, performing them one at a time.
|
||||||
state = self.dstate
|
state = self.dstate
|
||||||
for step_num, s in enumerate(self.steps):
|
for step_num, s in enumerate(self.steps):
|
||||||
|
train_step = True
|
||||||
# 'every' is used to denote steps that should only occur at a certain integer factor rate. e.g. '2' occurs every 2 steps.
|
# 'every' is used to denote steps that should only occur at a certain integer factor rate. e.g. '2' occurs every 2 steps.
|
||||||
|
# Note that the injection points for the step might still be required, so instead of skipping the step outright, mark it as non-training by setting train_step=False
|
||||||
if 'every' in s.step_opt.keys() and step % s.step_opt['every'] != 0:
|
if 'every' in s.step_opt.keys() and step % s.step_opt['every'] != 0:
|
||||||
continue
|
train_step = False
|
||||||
# Steps can opt out of early (or late) training, make sure that happens here.
|
# Steps can opt out of early (or late) training, make sure that happens here.
|
||||||
if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
|
if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
|
||||||
continue
|
continue
|
||||||
|
@ -187,33 +189,34 @@ class ExtensibleTrainer(BaseModel):
|
||||||
if not requirements_met:
|
if not requirements_met:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Only set requires_grad=True for the network being trained.
|
if train_step:
|
||||||
nets_to_train = s.get_networks_trained()
|
# Only set requires_grad=True for the network being trained.
|
||||||
enabled = 0
|
nets_to_train = s.get_networks_trained()
|
||||||
for name, net in self.networks.items():
|
enabled = 0
|
||||||
net_enabled = name in nets_to_train
|
for name, net in self.networks.items():
|
||||||
if net_enabled:
|
net_enabled = name in nets_to_train
|
||||||
enabled += 1
|
if net_enabled:
|
||||||
# Networks can opt out of training before a certain iteration by declaring 'after' in their definition.
|
enabled += 1
|
||||||
if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']:
|
# Networks can opt out of training before a certain iteration by declaring 'after' in their definition.
|
||||||
net_enabled = False
|
if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']:
|
||||||
for p in net.parameters():
|
net_enabled = False
|
||||||
if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"):
|
for p in net.parameters():
|
||||||
p.requires_grad = net_enabled
|
if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"):
|
||||||
else:
|
p.requires_grad = net_enabled
|
||||||
p.requires_grad = False
|
else:
|
||||||
assert enabled == len(nets_to_train)
|
p.requires_grad = False
|
||||||
|
assert enabled == len(nets_to_train)
|
||||||
|
|
||||||
# Update experiments
|
# Update experiments
|
||||||
[e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state) for e in self.experiments]
|
[e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state) for e in self.experiments]
|
||||||
|
|
||||||
for o in s.get_optimizers():
|
for o in s.get_optimizers():
|
||||||
o.zero_grad()
|
o.zero_grad()
|
||||||
|
|
||||||
# Now do a forward and backward pass for each gradient accumulation step.
|
# Now do a forward and backward pass for each gradient accumulation step.
|
||||||
new_states = {}
|
new_states = {}
|
||||||
for m in range(self.mega_batch_factor):
|
for m in range(self.mega_batch_factor):
|
||||||
ns = s.do_forward_backward(state, m, step_num)
|
ns = s.do_forward_backward(state, m, step_num, train=train_step)
|
||||||
for k, v in ns.items():
|
for k, v in ns.items():
|
||||||
if k not in new_states.keys():
|
if k not in new_states.keys():
|
||||||
new_states[k] = [v]
|
new_states[k] = [v]
|
||||||
|
@ -226,10 +229,11 @@ class ExtensibleTrainer(BaseModel):
|
||||||
assert k not in state.keys()
|
assert k not in state.keys()
|
||||||
state[k] = v
|
state[k] = v
|
||||||
|
|
||||||
# And finally perform optimization.
|
if train_step:
|
||||||
[e.before_optimize(state) for e in self.experiments]
|
# And finally perform optimization.
|
||||||
s.do_step(step)
|
[e.before_optimize(state) for e in self.experiments]
|
||||||
[e.after_optimize(state) for e in self.experiments]
|
s.do_step(step)
|
||||||
|
[e.after_optimize(state) for e in self.experiments]
|
||||||
|
|
||||||
# Record visual outputs for usage in debugging and testing.
|
# Record visual outputs for usage in debugging and testing.
|
||||||
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
|
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
|
||||||
|
|
|
@ -204,6 +204,7 @@ class GeneratorGanLoss(ConfigurableLoss):
|
||||||
pred_d_real = pred_d_real.detach()
|
pred_d_real = pred_d_real.detach()
|
||||||
pred_g_fake = netD(*fake)
|
pred_g_fake = netD(*fake)
|
||||||
d_fake_diff = pred_g_fake - torch.mean(pred_d_real)
|
d_fake_diff = pred_g_fake - torch.mean(pred_d_real)
|
||||||
|
self.metrics.append(("d_fake", torch.mean(pred_g_fake)))
|
||||||
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
||||||
loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
|
loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||||
self.criterion(d_fake_diff, True)) / 2
|
self.criterion(d_fake_diff, True)) / 2
|
||||||
|
|
Loading…
Reference in New Issue
Block a user