diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 070e8913..8cdc77f1 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -41,6 +41,10 @@ class ExtensibleTrainer(BaseModel):
         for name, net in opt['networks'].items():
             if net['type'] == 'generator':
                 new_net = networks.define_G(net, None, opt['scale']).to(self.device)
+                if 'trainable' not in net.keys():
+                    net['trainable'] = True
+                if not net['trainable']:
+                    new_net.eval()
                 self.netsG[name] = new_net
             elif net['type'] == 'discriminator':
                 new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
@@ -213,7 +217,7 @@ class ExtensibleTrainer(BaseModel):
         # Iterate through the steps, performing them one at a time.
         state = self.dstate
         for step_num, s in enumerate(self.steps):
-            ns = s.do_forward_backward(state, 0, step_num, backward=False)
+            ns = s.do_forward_backward(state, 0, step_num, train=False)
             for k, v in ns.items():
                 state[k] = [v]

@@ -260,7 +264,9 @@ class ExtensibleTrainer(BaseModel):

     def save(self, iter_step):
         for name, net in self.networks.items():
-            self.save_network(net, name, iter_step)
+            # Don't save non-trainable networks.
+            if self.opt['networks'][name]['trainable']:
+                self.save_network(net, name, iter_step)

     def force_restore_swapout(self):
         # Legacy method. Do nothing.
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index 15331e5a..7af98b31 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -72,7 +72,7 @@ class ConfigurableStep(Module):
     # Performs all forward and backward passes for this step given an input state. All input states are lists of
     # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
     # steps might use. These tensors are automatically detached and accumulated into chunks.
-    def do_forward_backward(self, state, grad_accum_step, amp_loss_id, backward=True):
+    def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
         new_state = {}

         # Prepare a de-chunked state dict which will be used for the injectors & losses.
@@ -83,11 +83,14 @@ class ConfigurableStep(Module):

         # Inject in any extra dependencies.
         for inj in self.injectors:
+            # Skip injectors tagged with 'eval' when in train mode.
+            if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
+                continue
             injected = inj(local_state)
             local_state.update(injected)
             new_state.update(injected)

-        if backward:
+        if train:
             # Finally, compute the losses.
             total_loss = 0
             for loss_name, loss in self.losses.items():
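
For reference, a minimal sketch of a parsed training config that would exercise both new options. Everything here other than the 'trainable' and 'eval' keys is hypothetical: the network and injector names, and the remaining keys, are illustrative and not taken from the diff.

# Hypothetical parsed opt dict exercising both new options.
opt = {
    'networks': {
        # Frozen teacher network: switched to eval() at construction time and
        # skipped by ExtensibleTrainer.save() because trainable is False.
        'teacher_gen': {'type': 'generator', 'trainable': False},
        # 'trainable' defaults to True when omitted, so this network
        # trains and checkpoints exactly as before this change.
        'student_gen': {'type': 'generator'},
    },
    'steps': {
        'generator_step': {
            'injectors': {
                # Tagged eval=True: do_forward_backward() skips this injector
                # whenever it is called with train=True, so it only runs
                # during evaluation passes.
                'validation_metric': {'eval': True},
            },
        },
    },
}

Note that save() indexes self.opt['networks'][name]['trainable'] directly, while the True default is only backfilled in the generator branch of the constructor, so a non-generator network would apparently need 'trainable' set explicitly in its config to be saved.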