Support injectors that run in eval only

James Betker 2020-09-05 07:59:45 -06:00
parent 17aa205e96
commit 0dfd8eaf3b
2 changed files with 13 additions and 4 deletions
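
In plain terms, the commit lets an individual injector in a training step be tagged so that it only executes during evaluation. The tag is read from the injector's option dict; a rough sketch of what such an entry might look like once the config is parsed (every key except 'eval' is illustrative, not taken from this commit):

# Hypothetical injector options after config parsing; only the 'eval' key is
# what this commit actually reads. The other keys are placeholders.
injector_opt = {
    'type': 'example_injector',  # which injector implementation to build (illustrative)
    'in': 'gen',                 # state key the injector reads (illustrative)
    'out': 'eval_metric',        # state key the injector writes (illustrative)
    'eval': True,                # new flag: skip this injector while training
}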

@@ -41,6 +41,10 @@ class ExtensibleTrainer(BaseModel):
         for name, net in opt['networks'].items():
             if net['type'] == 'generator':
                 new_net = networks.define_G(net, None, opt['scale']).to(self.device)
+                if 'trainable' not in net.keys():
+                    net['trainable'] = True
+                if not net['trainable']:
+                    new_net.eval()
                 self.netsG[name] = new_net
             elif net['type'] == 'discriminator':
                 new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
@@ -213,7 +217,7 @@ class ExtensibleTrainer(BaseModel):
             # Iterate through the steps, performing them one at a time.
             state = self.dstate
             for step_num, s in enumerate(self.steps):
-                ns = s.do_forward_backward(state, 0, step_num, backward=False)
+                ns = s.do_forward_backward(state, 0, step_num, train=False)
                 for k, v in ns.items():
                     state[k] = [v]
@@ -260,6 +264,8 @@ class ExtensibleTrainer(BaseModel):
     def save(self, iter_step):
         for name, net in self.networks.items():
+            # Don't save non-trainable networks.
+            if self.opt['networks'][name]['trainable']:
+                self.save_network(net, name, iter_step)

     def force_restore_swapout(self):
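
Taken together, the ExtensibleTrainer changes amount to: default 'trainable' to True so existing configs keep working, put explicitly non-trainable networks into eval mode when they are built, and skip them in save(). A condensed sketch of that pattern, assuming generic build_fn/save_fn callables standing in for the repo's networks.define_G and save_network:

def build_networks(net_opts, build_fn):
    # net_opts: dict of name -> option dict, as parsed from a config (assumed shape).
    nets = {}
    for name, opt in net_opts.items():
        opt.setdefault('trainable', True)   # same default this commit introduces
        net = build_fn(opt)
        if not opt['trainable']:
            net.eval()                      # frozen networks stay in eval mode
        nets[name] = net
    return nets

def save_networks(nets, net_opts, save_fn, iter_step):
    # Mirrors the new save() filter: non-trainable networks are never written out.
    for name, net in nets.items():
        if net_opts[name]['trainable']:
            save_fn(net, name, iter_step)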

@@ -72,7 +72,7 @@ class ConfigurableStep(Module):
     # Performs all forward and backward passes for this step given an input state. All input states are lists of
    # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
    # steps might use. These tensors are automatically detached and accumulated into chunks.
-    def do_forward_backward(self, state, grad_accum_step, amp_loss_id, backward=True):
+    def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
         new_state = {}
         # Prepare a de-chunked state dict which will be used for the injectors & losses.
@@ -83,11 +83,14 @@ class ConfigurableStep(Module):
         # Inject in any extra dependencies.
         for inj in self.injectors:
+            # Don't do injections tagged with eval unless we are not in train mode.
+            if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
+                continue
             injected = inj(local_state)
             local_state.update(injected)
             new_state.update(injected)

-        if backward:
+        if train:
             # Finally, compute the losses.
             total_loss = 0
             for loss_name, loss in self.losses.items():
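
The train flag (formerly backward) now gates two things: whether eval-tagged injectors are skipped, and whether losses and the backward pass run at all. A self-contained sketch of the skip behaviour, with ToyInjector and the state keys invented for illustration (real injectors come from the steps module and carry their opt dict the same way):

class ToyInjector:
    # Stand-in for a real injector: callable on a state dict, carries an opt dict.
    def __init__(self, opt):
        self.opt = opt

    def __call__(self, state):
        return {self.opt['out']: state[self.opt['in']] * 2}

def run_injectors(injectors, state, train=True):
    # Mirrors the loop above: injectors tagged with eval only run when train=False.
    out = dict(state)
    for inj in injectors:
        if train and inj.opt.get('eval', False):
            continue
        out.update(inj(out))
    return out

injectors = [ToyInjector({'in': 'x', 'out': 'y'}),
             ToyInjector({'in': 'x', 'out': 'eval_only', 'eval': True})]
assert 'eval_only' not in run_injectors(injectors, {'x': 3}, train=True)
assert 'eval_only' in run_injectors(injectors, {'x': 3}, train=False)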