forked from mrq/DL-Art-School
Add gradient penalty visual debug
This commit is contained in:
parent
63cf3d3126
commit
9c53314ea2
|
@ -230,7 +230,12 @@ class ExtensibleTrainer(BaseModel):
|
||||||
# Push the detached new state tensors into the state map for use with the next step.
|
# Push the detached new state tensors into the state map for use with the next step.
|
||||||
for k, v in new_states.items():
|
for k, v in new_states.items():
|
||||||
# State is immutable to reduce complexity. Overwriting existing state keys is not supported.
|
# State is immutable to reduce complexity. Overwriting existing state keys is not supported.
|
||||||
assert k not in state.keys()
|
class OverwrittenStateError(Exception):
|
||||||
|
def __init__(self, k, keys):
|
||||||
|
super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered '
|
||||||
|
f'immutable and keys should not be overwritten. Current keys: {keys}')
|
||||||
|
if k in state.keys():
|
||||||
|
raise OverwrittenStateError(k, list(state.keys()))
|
||||||
state[k] = v
|
state[k] = v
|
||||||
|
|
||||||
if train_step:
|
if train_step:
|
||||||
|
|
|
@ -72,6 +72,14 @@ class ConfigurableLoss(nn.Module):
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def is_stateful(self) -> bool:
|
||||||
|
"""
|
||||||
|
Losses can inject into the state too. useful for when a loss computation can be used by another loss.
|
||||||
|
if this is true, the forward pass must return (loss, new_state). If false (the default), forward() only returns
|
||||||
|
the loss value.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
def extra_metrics(self):
|
def extra_metrics(self):
|
||||||
return self.metrics
|
return self.metrics
|
||||||
|
|
||||||
|
@ -270,7 +278,9 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
if self.gradient_penalty:
|
if self.gradient_penalty:
|
||||||
[r.requires_grad_() for r in real]
|
[r.requires_grad_() for r in real]
|
||||||
fake = extract_params_from_state(self.opt['fake'], state)
|
fake = extract_params_from_state(self.opt['fake'], state)
|
||||||
|
new_state = {}
|
||||||
fake = [f.detach() for f in fake]
|
fake = [f.detach() for f in fake]
|
||||||
|
new_state = {}
|
||||||
if self.noise:
|
if self.noise:
|
||||||
nreal = []
|
nreal = []
|
||||||
nfake = []
|
nfake = []
|
||||||
|
@ -313,12 +323,20 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||||
from models.stylegan.stylegan2_lucidrains import gradient_penalty
|
from models.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||||
assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators.
|
assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators.
|
||||||
gp = gradient_penalty(real[0], d_real)
|
gp, gp_structure = gradient_penalty(real[0], d_real, return_structured_grads=True)
|
||||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||||
loss = loss + gp
|
loss = loss + gp
|
||||||
self.metrics.append(("gradient_penalty", gp))
|
self.metrics.append(("gradient_penalty", gp))
|
||||||
|
# The gp_structure is a useful visual debugging tool to see what areas of the generated image the disc is paying attention to.
|
||||||
|
gpimg = (gp_structure / (torch.std(gp_structure, dim=(-1, -2), keepdim=True) * 2)) \
|
||||||
|
- torch.mean(gp_structure, dim=(-1, -2), keepdim=True) + .5
|
||||||
|
new_state['%s_%s_gp_structure_img' % (self.opt['fake'], self.opt['real'])] = gpimg
|
||||||
|
|
||||||
return loss
|
return loss, new_state
|
||||||
|
|
||||||
|
# This loss is stateful because it injects a debugging result from the GP term when enabled.
|
||||||
|
def is_stateful(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
|
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
|
||||||
|
|
|
@ -123,13 +123,10 @@ class ConfigurableStep(Module):
|
||||||
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
|
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
|
||||||
# steps might use. These tensors are automatically detached and accumulated into chunks.
|
# steps might use. These tensors are automatically detached and accumulated into chunks.
|
||||||
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
|
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
|
||||||
new_state = {}
|
local_state = {} # <-- Will store the entire local state to be passed to injectors & losses.
|
||||||
|
new_state = {} # <-- Will store state values created by this step for returning to ExtensibleTrainer.
|
||||||
# Prepare a de-chunked state dict which will be used for the injectors & losses.
|
|
||||||
local_state = {}
|
|
||||||
for k, v in state.items():
|
for k, v in state.items():
|
||||||
local_state[k] = v[grad_accum_step]
|
local_state[k] = v[grad_accum_step]
|
||||||
local_state.update(new_state)
|
|
||||||
local_state['train_nets'] = str(self.get_networks_trained())
|
local_state['train_nets'] = str(self.get_networks_trained())
|
||||||
|
|
||||||
# Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
|
# Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
|
||||||
|
@ -164,7 +161,12 @@ class ConfigurableStep(Module):
|
||||||
'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
|
'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
|
||||||
'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
|
'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
|
||||||
continue
|
continue
|
||||||
l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
|
if loss.is_stateful():
|
||||||
|
l, lstate = loss(self.get_network_for_name(self.step_opt['training']), local_state)
|
||||||
|
local_state.update(lstate)
|
||||||
|
new_state.update(lstate)
|
||||||
|
else:
|
||||||
|
l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
|
||||||
total_loss += l * self.weights[loss_name]
|
total_loss += l * self.weights[loss_name]
|
||||||
# Record metrics.
|
# Record metrics.
|
||||||
if isinstance(l, torch.Tensor):
|
if isinstance(l, torch.Tensor):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user