forked from mrq/DL-Art-School
Support injectors that run in eval only
This commit is contained in:
parent
17aa205e96
commit
0dfd8eaf3b
|
@ -41,6 +41,10 @@ class ExtensibleTrainer(BaseModel):
|
|||
for name, net in opt['networks'].items():
|
||||
if net['type'] == 'generator':
|
||||
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
|
||||
if 'trainable' not in net.keys():
|
||||
net['trainable'] = True
|
||||
if not net['trainable']:
|
||||
new_net.eval()
|
||||
self.netsG[name] = new_net
|
||||
elif net['type'] == 'discriminator':
|
||||
new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
|
||||
|
@ -213,7 +217,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
# Iterate through the steps, performing them one at a time.
|
||||
state = self.dstate
|
||||
for step_num, s in enumerate(self.steps):
|
||||
ns = s.do_forward_backward(state, 0, step_num, backward=False)
|
||||
ns = s.do_forward_backward(state, 0, step_num, train=False)
|
||||
for k, v in ns.items():
|
||||
state[k] = [v]
|
||||
|
||||
|
@ -260,7 +264,9 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
def save(self, iter_step):
|
||||
for name, net in self.networks.items():
|
||||
self.save_network(net, name, iter_step)
|
||||
# Don't save non-trainable networks.
|
||||
if self.opt['networks'][name]['trainable']:
|
||||
self.save_network(net, name, iter_step)
|
||||
|
||||
def force_restore_swapout(self):
|
||||
# Legacy method. Do nothing.
|
||||
|
|
|
@ -72,7 +72,7 @@ class ConfigurableStep(Module):
|
|||
# Performs all forward and backward passes for this step given an input state. All input states are lists of
|
||||
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
|
||||
# steps might use. These tensors are automatically detached and accumulated into chunks.
|
||||
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, backward=True):
|
||||
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
|
||||
new_state = {}
|
||||
|
||||
# Prepare a de-chunked state dict which will be used for the injectors & losses.
|
||||
|
@ -83,11 +83,14 @@ class ConfigurableStep(Module):
|
|||
|
||||
# Inject in any extra dependencies.
|
||||
for inj in self.injectors:
|
||||
# Skip injectors tagged with eval while in train mode; they only run when train is False.
|
||||
if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
|
||||
continue
|
||||
injected = inj(local_state)
|
||||
local_state.update(injected)
|
||||
new_state.update(injected)
|
||||
|
||||
if backward:
|
||||
if train:
|
||||
# Finally, compute the losses.
|
||||
total_loss = 0
|
||||
for loss_name, loss in self.losses.items():
|
||||
|
|
Loading…
Reference in New Issue
Block a user