forked from mrq/DL-Art-School
Add gradient penalty visual debug
This commit is contained in:
parent
63cf3d3126
commit
9c53314ea2
|
@ -230,7 +230,12 @@ class ExtensibleTrainer(BaseModel):
|
|||
# Push the detached new state tensors into the state map for use with the next step.
|
||||
for k, v in new_states.items():
|
||||
# State is immutable to reduce complexity. Overwriting existing state keys is not supported.
|
||||
assert k not in state.keys()
|
||||
class OverwrittenStateError(Exception):
|
||||
def __init__(self, k, keys):
|
||||
super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered '
|
||||
f'immutable and keys should not be overwritten. Current keys: {keys}')
|
||||
if k in state.keys():
|
||||
raise OverwrittenStateError(k, list(state.keys()))
|
||||
state[k] = v
|
||||
|
||||
if train_step:
|
||||
|
|
|
@ -72,6 +72,14 @@ class ConfigurableLoss(nn.Module):
|
|||
def forward(self, net, state):
|
||||
raise NotImplementedError
|
||||
|
||||
def is_stateful(self) -> bool:
|
||||
"""
|
||||
Losses can inject into the state too. useful for when a loss computation can be used by another loss.
|
||||
if this is true, the forward pass must return (loss, new_state). If false (the default), forward() only returns
|
||||
the loss value.
|
||||
"""
|
||||
return False
|
||||
|
||||
def extra_metrics(self):
|
||||
return self.metrics
|
||||
|
||||
|
@ -270,7 +278,9 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
if self.gradient_penalty:
|
||||
[r.requires_grad_() for r in real]
|
||||
fake = extract_params_from_state(self.opt['fake'], state)
|
||||
new_state = {}
|
||||
fake = [f.detach() for f in fake]
|
||||
new_state = {}
|
||||
if self.noise:
|
||||
nreal = []
|
||||
nfake = []
|
||||
|
@ -313,12 +323,20 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||
from models.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||
assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators.
|
||||
gp = gradient_penalty(real[0], d_real)
|
||||
gp, gp_structure = gradient_penalty(real[0], d_real, return_structured_grads=True)
|
||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||
loss = loss + gp
|
||||
self.metrics.append(("gradient_penalty", gp))
|
||||
# The gp_structure is a useful visual debugging tool to see what areas of the generated image the disc is paying attention to.
|
||||
gpimg = (gp_structure / (torch.std(gp_structure, dim=(-1, -2), keepdim=True) * 2)) \
|
||||
- torch.mean(gp_structure, dim=(-1, -2), keepdim=True) + .5
|
||||
new_state['%s_%s_gp_structure_img' % (self.opt['fake'], self.opt['real'])] = gpimg
|
||||
|
||||
return loss
|
||||
return loss, new_state
|
||||
|
||||
# This loss is stateful because it injects a debugging result from the GP term when enabled.
|
||||
def is_stateful(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
|
||||
|
|
|
@ -123,13 +123,10 @@ class ConfigurableStep(Module):
|
|||
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
|
||||
# steps might use. These tensors are automatically detached and accumulated into chunks.
|
||||
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
|
||||
new_state = {}
|
||||
|
||||
# Prepare a de-chunked state dict which will be used for the injectors & losses.
|
||||
local_state = {}
|
||||
local_state = {} # <-- Will store the entire local state to be passed to injectors & losses.
|
||||
new_state = {} # <-- Will store state values created by this step for returning to ExtensibleTrainer.
|
||||
for k, v in state.items():
|
||||
local_state[k] = v[grad_accum_step]
|
||||
local_state.update(new_state)
|
||||
local_state['train_nets'] = str(self.get_networks_trained())
|
||||
|
||||
# Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
|
||||
|
@ -164,7 +161,12 @@ class ConfigurableStep(Module):
|
|||
'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
|
||||
'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
|
||||
continue
|
||||
l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
|
||||
if loss.is_stateful():
|
||||
l, lstate = loss(self.get_network_for_name(self.step_opt['training']), local_state)
|
||||
local_state.update(lstate)
|
||||
new_state.update(lstate)
|
||||
else:
|
||||
l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
|
||||
total_loss += l * self.weights[loss_name]
|
||||
# Record metrics.
|
||||
if isinstance(l, torch.Tensor):
|
||||
|
|
Loading…
Reference in New Issue
Block a user