GP adjustments for stylegan2

James Betker 2020-11-12 16:44:51 -07:00
parent fc55bdb24e
commit 566b99ca75


@@ -508,16 +508,15 @@ class StyleGan2DivergenceLoss(ConfigurableLoss):
         real = D(real_input)
         divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
-        gp = 0
+        # Apply gradient penalty. TODO: migrate this elsewhere.
         if self.env['step'] % self.gp_frequency == 0:
-            # Apply gradient penalty. TODO: migrate this elsewhere.
             from models.archs.stylegan2 import gradient_penalty
             gp = gradient_penalty(real_input, real)
-            self.last_gp_loss = gp.clone().detach().item()
-            self.metrics.append(("gradient_penalty", gp))
+            self.metrics.append(("gradient_penalty", gp.clone().detach()))
+            divergence_loss = divergence_loss + gp
         real_input.requires_grad_(requires_grad=False)
-        return divergence_loss + gp
+        return divergence_loss
 
 class StyleGan2PathLengthLoss(ConfigurableLoss):
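
In short: the old code initialized gp = 0 unconditionally, returned divergence_loss + gp, and logged the penalty through self.last_gp_loss; the new code only touches the loss on steps where self.env['step'] % self.gp_frequency == 0, folds the penalty into divergence_loss inside that branch, and records a detached copy via self.metrics instead.

The gradient_penalty helper imported from models.archs.stylegan2 is not part of this hunk. As a rough sketch only, a StyleGAN2-style gradient penalty on real images commonly looks like the following; the function name and argument order match the call site above, but the default weight and the (norm - 1)^2 reduction are assumptions (an R1 penalty would use the squared gradient norm instead), not something this diff confirms.

    import torch
    from torch.autograd import grad as torch_grad

    def gradient_penalty(images, output, weight=10):
        # Sketch of a gradient penalty on real images; weight and reduction are assumed.
        # 'output' is the discriminator score D(images); 'images' must have requires_grad=True.
        gradients = torch_grad(outputs=output, inputs=images,
                               grad_outputs=torch.ones_like(output),
                               create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.reshape(images.shape[0], -1)
        # Penalize deviation of the per-sample gradient norm from 1, averaged over the batch.
        return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

Because create_graph=True keeps the computation graph, the penalty remains differentiable, so adding it to divergence_loss inside the conditional (as the diff now does) lets it backpropagate only on the scheduled steps; the requires_grad_(requires_grad=False) call afterwards simply resets the flag on real_input once the loss has been computed.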