diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 951f06c5..8d85bd05 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -172,9 +172,11 @@ class ExtensibleTrainer(BaseModel): # Iterate through the steps, performing them one at a time. state = self.dstate for step_num, s in enumerate(self.steps): + train_step = True # 'every' is used to denote steps that should only occur at a certain integer factor rate. e.g. '2' occurs every 2 steps. + # Note that the injection points for the step might still be required, so address this by setting train_step=False if 'every' in s.step_opt.keys() and step % s.step_opt['every'] != 0: - continue + train_step = False # Steps can opt out of early (or late) training, make sure that happens here. if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']: continue @@ -187,33 +189,34 @@ class ExtensibleTrainer(BaseModel): if not requirements_met: continue - # Only set requires_grad=True for the network being trained. - nets_to_train = s.get_networks_trained() - enabled = 0 - for name, net in self.networks.items(): - net_enabled = name in nets_to_train - if net_enabled: - enabled += 1 - # Networks can opt out of training before a certain iteration by declaring 'after' in their definition. - if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']: - net_enabled = False - for p in net.parameters(): - if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"): - p.requires_grad = net_enabled - else: - p.requires_grad = False - assert enabled == len(nets_to_train) + if train_step: + # Only set requires_grad=True for the network being trained. + nets_to_train = s.get_networks_trained() + enabled = 0 + for name, net in self.networks.items(): + net_enabled = name in nets_to_train + if net_enabled: + enabled += 1 + # Networks can opt out of training before a certain iteration by declaring 'after' in their definition. + if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']: + net_enabled = False + for p in net.parameters(): + if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"): + p.requires_grad = net_enabled + else: + p.requires_grad = False + assert enabled == len(nets_to_train) - # Update experiments - [e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state) for e in self.experiments] + # Update experiments + [e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state) for e in self.experiments] - for o in s.get_optimizers(): - o.zero_grad() + for o in s.get_optimizers(): + o.zero_grad() # Now do a forward and backward pass for each gradient accumulation step. new_states = {} for m in range(self.mega_batch_factor): - ns = s.do_forward_backward(state, m, step_num) + ns = s.do_forward_backward(state, m, step_num, train=train_step) for k, v in ns.items(): if k not in new_states.keys(): new_states[k] = [v] @@ -226,10 +229,11 @@ class ExtensibleTrainer(BaseModel): assert k not in state.keys() state[k] = v - # And finally perform optimization. - [e.before_optimize(state) for e in self.experiments] - s.do_step(step) - [e.after_optimize(state) for e in self.experiments] + if train_step: + # And finally perform optimization. + [e.before_optimize(state) for e in self.experiments] + s.do_step(step) + [e.after_optimize(state) for e in self.experiments] # Record visual outputs for usage in debugging and testing. if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0: diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index e8cb5743..4ef123b3 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -204,6 +204,7 @@ class GeneratorGanLoss(ConfigurableLoss): pred_d_real = pred_d_real.detach() pred_g_fake = netD(*fake) d_fake_diff = pred_g_fake - torch.mean(pred_d_real) + self.metrics.append(("d_fake", torch.mean(pred_g_fake))) self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) + self.criterion(d_fake_diff, True)) / 2