forked from mrq/DL-Art-School
GP adjustments for stylegan2
This commit is contained in:
parent
fc55bdb24e
commit
566b99ca75
|
@ -508,16 +508,15 @@ class StyleGan2DivergenceLoss(ConfigurableLoss):
|
||||||
real = D(real_input)
|
real = D(real_input)
|
||||||
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
|
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
|
||||||
|
|
||||||
gp = 0
|
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||||
if self.env['step'] % self.gp_frequency == 0:
|
if self.env['step'] % self.gp_frequency == 0:
|
||||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
|
||||||
from models.archs.stylegan2 import gradient_penalty
|
from models.archs.stylegan2 import gradient_penalty
|
||||||
gp = gradient_penalty(real_input, real)
|
gp = gradient_penalty(real_input, real)
|
||||||
self.last_gp_loss = gp.clone().detach().item()
|
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||||
self.metrics.append(("gradient_penalty", gp))
|
divergence_loss = divergence_loss + gp
|
||||||
|
|
||||||
real_input.requires_grad_(requires_grad=False)
|
real_input.requires_grad_(requires_grad=False)
|
||||||
return divergence_loss + gp
|
return divergence_loss
|
||||||
|
|
||||||
|
|
||||||
class StyleGan2PathLengthLoss(ConfigurableLoss):
|
class StyleGan2PathLengthLoss(ConfigurableLoss):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user