Add gp debug (fix)

This commit is contained in:
James Betker 2020-12-30 15:26:54 -07:00
parent 9c53314ea2
commit b1fb82476b

View File

@ -234,15 +234,18 @@ def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
loss.backward(**kwargs) loss.backward(**kwargs)
def gradient_penalty(images, output, weight=10): def gradient_penalty(images, output, weight=10, return_structured_grads=False):
batch_size = images.shape[0] batch_size = images.shape[0]
gradients = torch_grad(outputs=output, inputs=images, gradients = torch_grad(outputs=output, inputs=images,
grad_outputs=torch.ones(output.size(), device=images.device), grad_outputs=torch.ones(output.size(), device=images.device),
create_graph=True, retain_graph=True, only_inputs=True)[0] create_graph=True, retain_graph=True, only_inputs=True)[0]
gradients = gradients.reshape(batch_size, -1) flat_grad = gradients.reshape(batch_size, -1)
return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean() penalty = weight * ((flat_grad.norm(2, dim=1) - 1) ** 2).mean()
if return_structured_grads:
return penalty, gradients
else:
return penalty
def calc_pl_lengths(styles, images): def calc_pl_lengths(styles, images):
device = images.device device = images.device