GP adjustments for stylegan2

This commit is contained in:
James Betker 2020-11-12 16:44:51 -07:00
parent fc55bdb24e
commit 566b99ca75

View File

@ -508,16 +508,15 @@ class StyleGan2DivergenceLoss(ConfigurableLoss):
real = D(real_input)
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
gp = 0
# Apply gradient penalty. TODO: migrate this elsewhere.
if self.env['step'] % self.gp_frequency == 0:
# Apply gradient penalty. TODO: migrate this elsewhere.
from models.archs.stylegan2 import gradient_penalty
gp = gradient_penalty(real_input, real)
self.last_gp_loss = gp.clone().detach().item()
self.metrics.append(("gradient_penalty", gp))
self.metrics.append(("gradient_penalty", gp.clone().detach()))
divergence_loss = divergence_loss + gp
real_input.requires_grad_(requires_grad=False)
return divergence_loss + gp
return divergence_loss
class StyleGan2PathLengthLoss(ConfigurableLoss):