diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index a3eb3a69..caab8373 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -230,7 +230,12 @@ class ExtensibleTrainer(BaseModel):
             # Push the detached new state tensors into the state map for use with the next step.
             for k, v in new_states.items():
                 # State is immutable to reduce complexity. Overwriting existing state keys is not supported.
-                assert k not in state.keys()
+                class OverwrittenStateError(Exception):
+                    def __init__(self, k, keys):
+                        super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered '
+                                         f'immutable and keys should not be overwritten. Current keys: {keys}')
+                if k in state.keys():
+                    raise OverwrittenStateError(k, list(state.keys()))
                 state[k] = v
 
         if train_step:
diff --git a/codes/trainer/losses.py b/codes/trainer/losses.py
index 4ef9e437..14995449 100644
--- a/codes/trainer/losses.py
+++ b/codes/trainer/losses.py
@@ -72,6 +72,14 @@ class ConfigurableLoss(nn.Module):
     def forward(self, net, state):
         raise NotImplementedError
 
+    def is_stateful(self) -> bool:
+        """
+        Losses can inject into the state too, which is useful when a loss computation can be reused by another
+        loss. If this returns True, the forward pass must return (loss, new_state). If False (the default),
+        forward() returns only the loss value.
+        """
+        return False
+
     def extra_metrics(self):
         return self.metrics
 
@@ -270,7 +278,8 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         if self.gradient_penalty:
             [r.requires_grad_() for r in real]
         fake = extract_params_from_state(self.opt['fake'], state)
         fake = [f.detach() for f in fake]
+        new_state = {}
         if self.noise:
             nreal = []
             nfake = []
@@ -313,12 +322,20 @@
             # Apply gradient penalty. TODO: migrate this elsewhere.
             from models.stylegan.stylegan2_lucidrains import gradient_penalty
             assert len(real) == 1  # Grad penalty doesn't currently support multi-input discriminators.
-            gp = gradient_penalty(real[0], d_real)
+            gp, gp_structure = gradient_penalty(real[0], d_real, return_structured_grads=True)
             self.metrics.append(("gradient_penalty", gp.clone().detach()))
             loss = loss + gp
-            self.metrics.append(("gradient_penalty", gp))
+            # gp_structure is a useful visual debugging tool for seeing which areas of the generated image the
+            # discriminator is paying attention to. Normalize it to mean .5, std .5 so it can be viewed as an image.
+            gpimg = ((gp_structure - torch.mean(gp_structure, dim=(-1, -2), keepdim=True))
+                     / (torch.std(gp_structure, dim=(-1, -2), keepdim=True) * 2)) + .5
+            new_state['%s_%s_gp_structure_img' % (self.opt['fake'], self.opt['real'])] = gpimg
 
-        return loss
+        return loss, new_state
+
+    # This loss is stateful because it injects a debugging result from the GP term when enabled.
+    def is_stateful(self) -> bool:
+        return True
 
 
 # Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index a0704d9f..0e491d3d 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -123,13 +123,10 @@ class ConfigurableStep(Module):
     # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
     # steps might use. These tensors are automatically detached and accumulated into chunks.
     def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
-        new_state = {}
-
-        # Prepare a de-chunked state dict which will be used for the injectors & losses.
-        local_state = {}
+        local_state = {}  # <-- Holds the de-chunked state that is passed to the injectors & losses.
+        new_state = {}  # <-- Holds state values created by this step, returned to ExtensibleTrainer.
         for k, v in state.items():
             local_state[k] = v[grad_accum_step]
-        local_state.update(new_state)
         local_state['train_nets'] = str(self.get_networks_trained())
 
         # Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
@@ -164,7 +161,12 @@ class ConfigurableStep(Module):
                'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
                'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
                 continue
-            l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
+            if loss.is_stateful():
+                l, lstate = loss(self.get_network_for_name(self.step_opt['training']), local_state)
+                local_state.update(lstate)
+                new_state.update(lstate)
+            else:
+                l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
             total_loss += l * self.weights[loss_name]
             # Record metrics.
             if isinstance(l, torch.Tensor):
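---

Note (not part of the diff): a minimal sketch of how a loss could opt into the new
stateful contract. The class name and state keys here are hypothetical, and it subclasses
nn.Module directly so the snippet is self-contained; a real loss in this repo would
subclass ConfigurableLoss instead.

    import torch
    import torch.nn as nn

    class SharedFeatureLoss(nn.Module):  # hypothetical example loss
        """Computes an L1 loss and publishes its intermediate tensor for reuse."""
        def __init__(self, opt, env):
            super().__init__()
            self.opt = opt
            self.metrics = []

        def is_stateful(self) -> bool:
            # Tells ConfigurableStep.do_forward_backward() to expect (loss, new_state).
            return True

        def forward(self, net, state):
            feats = net(state[self.opt['in']])  # potentially expensive forward pass
            loss = nn.functional.l1_loss(feats, state[self.opt['target']])
            # The returned dict is merged into local_state (visible to later losses in
            # the same step) and into new_state (returned to ExtensibleTrainer, where a
            # colliding key now raises OverwrittenStateError instead of assert-failing).
            return loss, {self.opt['in'] + '_feats': feats.detach()}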
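Note (not part of the diff): when gradient_penalty is enabled, the debug visualization
lands in the state under '<fake>_<real>_gp_structure_img'. A sketch of pulling it back
out downstream for logging; log_gp_structure and the output filename are hypothetical,
and torchvision is assumed to be available.

    from torchvision.utils import save_image

    def log_gp_structure(local_state, opt, step):
        # Key format mirrors DiscriminatorGanLoss: '%s_%s_gp_structure_img' % (fake, real).
        key = '%s_%s_gp_structure_img' % (opt['fake'], opt['real'])
        if key in local_state:
            # gpimg was normalized to roughly [0, 1]; clamp before writing to disk.
            save_image(local_state[key].clamp(0, 1), 'gp_structure_%06d.png' % step)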