Add gradient penalty visual debug

This commit is contained in:
James Betker 2020-12-30 09:51:59 -07:00
parent 63cf3d3126
commit 9c53314ea2
3 changed files with 34 additions and 9 deletions

View File

@ -230,7 +230,12 @@ class ExtensibleTrainer(BaseModel):
# Push the detached new state tensors into the state map for use with the next step.
for k, v in new_states.items():
# State is immutable to reduce complexity. Overwriting existing state keys is not supported.
assert k not in state.keys()
class OverwrittenStateError(Exception):
def __init__(self, k, keys):
super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered '
f'immutable and keys should not be overwritten. Current keys: {keys}')
if k in state.keys():
raise OverwrittenStateError(k, list(state.keys()))
state[k] = v
if train_step:

View File

@ -72,6 +72,14 @@ class ConfigurableLoss(nn.Module):
def forward(self, net, state):
raise NotImplementedError
def is_stateful(self) -> bool:
"""
Losses can inject into the state too. useful for when a loss computation can be used by another loss.
if this is true, the forward pass must return (loss, new_state). If false (the default), forward() only returns
the loss value.
"""
return False
def extra_metrics(self):
return self.metrics
@ -270,7 +278,9 @@ class DiscriminatorGanLoss(ConfigurableLoss):
if self.gradient_penalty:
[r.requires_grad_() for r in real]
fake = extract_params_from_state(self.opt['fake'], state)
new_state = {}
fake = [f.detach() for f in fake]
new_state = {}
if self.noise:
nreal = []
nfake = []
@ -313,12 +323,20 @@ class DiscriminatorGanLoss(ConfigurableLoss):
# Apply gradient penalty. TODO: migrate this elsewhere.
from models.stylegan.stylegan2_lucidrains import gradient_penalty
assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators.
gp = gradient_penalty(real[0], d_real)
gp, gp_structure = gradient_penalty(real[0], d_real, return_structured_grads=True)
self.metrics.append(("gradient_penalty", gp.clone().detach()))
loss = loss + gp
self.metrics.append(("gradient_penalty", gp))
# The gp_structure is a useful visual debugging tool to see what areas of the generated image the disc is paying attention to.
gpimg = (gp_structure / (torch.std(gp_structure, dim=(-1, -2), keepdim=True) * 2)) \
- torch.mean(gp_structure, dim=(-1, -2), keepdim=True) + .5
new_state['%s_%s_gp_structure_img' % (self.opt['fake'], self.opt['real'])] = gpimg
return loss
return loss, new_state
# This loss is stateful because it injects a debugging result from the GP term when enabled.
def is_stateful(self) -> bool:
return True
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an

View File

@ -123,13 +123,10 @@ class ConfigurableStep(Module):
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
# steps might use. These tensors are automatically detached and accumulated into chunks.
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
new_state = {}
# Prepare a de-chunked state dict which will be used for the injectors & losses.
local_state = {}
local_state = {} # <-- Will store the entire local state to be passed to injectors & losses.
new_state = {} # <-- Will store state values created by this step for returning to ExtensibleTrainer.
for k, v in state.items():
local_state[k] = v[grad_accum_step]
local_state.update(new_state)
local_state['train_nets'] = str(self.get_networks_trained())
# Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
@ -164,6 +161,11 @@ class ConfigurableStep(Module):
'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
continue
if loss.is_stateful():
l, lstate = loss(self.get_network_for_name(self.step_opt['training']), local_state)
local_state.update(lstate)
new_state.update(lstate)
else:
l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
total_loss += l * self.weights[loss_name]
# Record metrics.