forked from mrq/DL-Art-School
GP adjustments for stylegan2
This commit is contained in:
parent
fc55bdb24e
commit
566b99ca75
|
@ -508,16 +508,15 @@ class StyleGan2DivergenceLoss(ConfigurableLoss):
|
|||
real = D(real_input)
|
||||
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
|
||||
|
||||
gp = 0
|
||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||
if self.env['step'] % self.gp_frequency == 0:
|
||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||
from models.archs.stylegan2 import gradient_penalty
|
||||
gp = gradient_penalty(real_input, real)
|
||||
self.last_gp_loss = gp.clone().detach().item()
|
||||
self.metrics.append(("gradient_penalty", gp))
|
||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||
divergence_loss = divergence_loss + gp
|
||||
|
||||
real_input.requires_grad_(requires_grad=False)
|
||||
return divergence_loss + gp
|
||||
return divergence_loss
|
||||
|
||||
|
||||
class StyleGan2PathLengthLoss(ConfigurableLoss):
|
||||
|
|
Loading…
Reference in New Issue
Block a user