Add gradient penalty visual debug

James Betker 2020-12-30 09:51:59 -07:00
parent 63cf3d3126
commit 9c53314ea2
3 changed files with 34 additions and 9 deletions


@@ -230,7 +230,12 @@ class ExtensibleTrainer(BaseModel):
             # Push the detached new state tensors into the state map for use with the next step.
             for k, v in new_states.items():
                 # State is immutable to reduce complexity. Overwriting existing state keys is not supported.
-                assert k not in state.keys()
+                class OverwrittenStateError(Exception):
+                    def __init__(self, k, keys):
+                        super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered '
+                                         f'immutable and keys should not be overwritten. Current keys: {keys}')
+                if k in state.keys():
+                    raise OverwrittenStateError(k, list(state.keys()))
                 state[k] = v

         if train_step:
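For reference, a self-contained sketch (not part of the commit) of the contract this enforces: once a key exists in the trainer's state dict, a later step may not replace it, and the failure now carries a descriptive message instead of a bare assert. The dict contents below are made up.

class OverwrittenStateError(Exception):
    def __init__(self, k, keys):
        super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered '
                         f'immutable and keys should not be overwritten. Current keys: {keys}')

state = {'gen_img': 'tensor produced by an earlier step'}    # hypothetical existing state
new_states = {'gen_img': 'tensor produced by a later step'}  # illegal: key already present

try:
    for k, v in new_states.items():
        if k in state.keys():
            raise OverwrittenStateError(k, list(state.keys()))
        state[k] = v
except OverwrittenStateError as e:
    print(e)  # descriptive error instead of a bare AssertionError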


@@ -72,6 +72,14 @@ class ConfigurableLoss(nn.Module):
     def forward(self, net, state):
         raise NotImplementedError

+    def is_stateful(self) -> bool:
+        """
+        Losses can inject into the state too. Useful for when a loss computation can be used by another loss.
+        If this is true, the forward pass must return (loss, new_state). If false (the default), forward() only returns
+        the loss value.
+        """
+        return False
+
     def extra_metrics(self):
         return self.metrics
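To illustrate the new contract, a minimal stateful loss might look like the sketch below. This is not from the commit: ExampleMagnitudeLoss and its 'fake' key are invented, and nn.Module stands in for the repo's ConfigurableLoss base class.

import torch
import torch.nn as nn

class ExampleMagnitudeLoss(nn.Module):
    def is_stateful(self) -> bool:
        return True  # opt in: forward() now returns (loss, new_state)

    def forward(self, net, state):
        fake = state['fake']
        loss = fake.abs().mean()
        # Injected keys become visible to later losses in the same step and are handed
        # back to the trainer for logging/visualization.
        return loss, {'example_fake_magnitude': loss.detach()}

# Usage sketch:
loss_fn = ExampleMagnitudeLoss()
l, lstate = loss_fn(None, {'fake': torch.randn(1, 3, 8, 8)})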
@@ -270,7 +278,9 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         if self.gradient_penalty:
             [r.requires_grad_() for r in real]
         fake = extract_params_from_state(self.opt['fake'], state)
+        new_state = {}
         fake = [f.detach() for f in fake]
+        new_state = {}
         if self.noise:
             nreal = []
             nfake = []
@@ -313,12 +323,20 @@ class DiscriminatorGanLoss(ConfigurableLoss):
             # Apply gradient penalty. TODO: migrate this elsewhere.
             from models.stylegan.stylegan2_lucidrains import gradient_penalty
             assert len(real) == 1  # Grad penalty doesn't currently support multi-input discriminators.
-            gp = gradient_penalty(real[0], d_real)
+            gp, gp_structure = gradient_penalty(real[0], d_real, return_structured_grads=True)
             self.metrics.append(("gradient_penalty", gp.clone().detach()))
             loss = loss + gp
             self.metrics.append(("gradient_penalty", gp))
-        return loss
+            # The gp_structure is a useful visual debugging tool to see what areas of the generated image the disc is paying attention to.
+            gpimg = (gp_structure / (torch.std(gp_structure, dim=(-1, -2), keepdim=True) * 2)) \
+                    - torch.mean(gp_structure, dim=(-1, -2), keepdim=True) + .5
+            new_state['%s_%s_gp_structure_img' % (self.opt['fake'], self.opt['real'])] = gpimg
+        return loss, new_state
+
+    # This loss is stateful because it injects a debugging result from the GP term when enabled.
+    def is_stateful(self) -> bool:
+        return True


 # Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
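A hedged sketch of how the injected debug image might be consumed downstream. The key name follows the '%s_%s_gp_structure_img' pattern above with 'fake' and 'real' standing in for the configured opt values; the tensor here is a placeholder.

import torch
import torchvision

state = {'fake_real_gp_structure_img': torch.rand(4, 3, 64, 64)}  # placeholder batch of debug maps

debug_img = state.get('fake_real_gp_structure_img')
if debug_img is not None:
    # The normalization above roughly centers values around 0.5; clamp into [0, 1] before saving.
    torchvision.utils.save_image(debug_img.clamp(0, 1), 'gp_structure_debug.png')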


@@ -123,13 +123,10 @@ class ConfigurableStep(Module):
     # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
     # steps might use. These tensors are automatically detached and accumulated into chunks.
     def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
-        new_state = {}
-
-        # Prepare a de-chunked state dict which will be used for the injectors & losses.
-        local_state = {}
+        local_state = {}  # <-- Will store the entire local state to be passed to injectors & losses.
+        new_state = {}  # <-- Will store state values created by this step for returning to ExtensibleTrainer.
         for k, v in state.items():
             local_state[k] = v[grad_accum_step]
-        local_state.update(new_state)
         local_state['train_nets'] = str(self.get_networks_trained())

         # Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
@@ -164,6 +161,11 @@
                     'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
                     'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
                 continue
-            l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
+            if loss.is_stateful():
+                l, lstate = loss(self.get_network_for_name(self.step_opt['training']), local_state)
+                local_state.update(lstate)
+                new_state.update(lstate)
+            else:
+                l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
             total_loss += l * self.weights[loss_name]
             # Record metrics.
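Finally, a condensed, self-contained sketch of the dispatch introduced above (simplified from the real ConfigurableStep; the loop structure and weighting are illustrative):

def run_losses(losses, weights, net, local_state):
    new_state, total_loss = {}, 0
    for name, loss in losses.items():
        if loss.is_stateful():
            l, lstate = loss(net, local_state)
            local_state.update(lstate)  # visible to subsequent losses within this step
            new_state.update(lstate)    # returned to ExtensibleTrainer, which forbids key overwrites
        else:
            l = loss(net, local_state)
        total_loss = total_loss + l * weights[name]
    return total_loss, new_state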