diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 558e18f5..2b696834 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -508,16 +508,15 @@ class StyleGan2DivergenceLoss(ConfigurableLoss): real = D(real_input) divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean() - gp = 0 + # Apply gradient penalty. TODO: migrate this elsewhere. if self.env['step'] % self.gp_frequency == 0: - # Apply gradient penalty. TODO: migrate this elsewhere. from models.archs.stylegan2 import gradient_penalty gp = gradient_penalty(real_input, real) - self.last_gp_loss = gp.clone().detach().item() - self.metrics.append(("gradient_penalty", gp)) + self.metrics.append(("gradient_penalty", gp.clone().detach())) + divergence_loss = divergence_loss + gp real_input.requires_grad_(requires_grad=False) - return divergence_loss + gp + return divergence_loss class StyleGan2PathLengthLoss(ConfigurableLoss):