From b1fb82476b9f8efe40c31ad6b32a0babdbcf91ec Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 30 Dec 2020 15:26:54 -0700 Subject: [PATCH] Add gp debug (fix) --- codes/models/stylegan/stylegan2_lucidrains.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/codes/models/stylegan/stylegan2_lucidrains.py b/codes/models/stylegan/stylegan2_lucidrains.py index fa4040a5..1310cb58 100644 --- a/codes/models/stylegan/stylegan2_lucidrains.py +++ b/codes/models/stylegan/stylegan2_lucidrains.py @@ -234,15 +234,18 @@ def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs): loss.backward(**kwargs) -def gradient_penalty(images, output, weight=10): +def gradient_penalty(images, output, weight=10, return_structured_grads=False): batch_size = images.shape[0] gradients = torch_grad(outputs=output, inputs=images, grad_outputs=torch.ones(output.size(), device=images.device), create_graph=True, retain_graph=True, only_inputs=True)[0] - gradients = gradients.reshape(batch_size, -1) - return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean() - + flat_grad = gradients.reshape(batch_size, -1) + penalty = weight * ((flat_grad.norm(2, dim=1) - 1) ** 2).mean() + if return_structured_grads: + return penalty, gradients + else: + return penalty def calc_pl_lengths(styles, images): device = images.device