Add gp debug (fix)
This commit is contained in:
parent
9c53314ea2
commit
b1fb82476b
|
@ -234,15 +234,18 @@ def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
|
|||
loss.backward(**kwargs)
|
||||
|
||||
|
||||
def gradient_penalty(images, output, weight=10):
|
||||
def gradient_penalty(images, output, weight=10, return_structured_grads=False):
|
||||
batch_size = images.shape[0]
|
||||
gradients = torch_grad(outputs=output, inputs=images,
|
||||
grad_outputs=torch.ones(output.size(), device=images.device),
|
||||
create_graph=True, retain_graph=True, only_inputs=True)[0]
|
||||
|
||||
gradients = gradients.reshape(batch_size, -1)
|
||||
return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
|
||||
|
||||
flat_grad = gradients.reshape(batch_size, -1)
|
||||
penalty = weight * ((flat_grad.norm(2, dim=1) - 1) ** 2).mean()
|
||||
if return_structured_grads:
|
||||
return penalty, gradients
|
||||
else:
|
||||
return penalty
|
||||
|
||||
def calc_pl_lengths(styles, images):
|
||||
device = images.device
|
||||
|
|
Loading…
Reference in New Issue
Block a user