Add gp debug (fix)
parent 9c53314ea2
commit b1fb82476b
@@ -234,15 +234,18 @@ def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
     loss.backward(**kwargs)
 
 
-def gradient_penalty(images, output, weight=10):
+def gradient_penalty(images, output, weight=10, return_structured_grads=False):
     batch_size = images.shape[0]
     gradients = torch_grad(outputs=output, inputs=images,
                            grad_outputs=torch.ones(output.size(), device=images.device),
                            create_graph=True, retain_graph=True, only_inputs=True)[0]
 
-    gradients = gradients.reshape(batch_size, -1)
-    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
-
+    flat_grad = gradients.reshape(batch_size, -1)
+    penalty = weight * ((flat_grad.norm(2, dim=1) - 1) ** 2).mean()
+    if return_structured_grads:
+        return penalty, gradients
+    else:
+        return penalty
 
 def calc_pl_lengths(styles, images):
     device = images.device
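Judging by the commit title ("Add gp debug"), the new return_structured_grads flag lets a caller recover the gradients in their original (N, C, H, W) image shape for inspection, rather than only the scalar WGAN-GP penalty built from the flattened per-sample norms. Below is a minimal sketch of one possible debug call site, assuming gradient_penalty as patched above is in scope; the toy discriminator D and the tensor shapes are illustrative stand-ins, not part of this commit:

import torch
import torch.nn as nn

# Toy discriminator, purely for illustration -- not part of the commit.
D = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Flatten(), nn.Linear(8 * 16 * 16, 1))

images = torch.randn(4, 3, 16, 16, requires_grad=True)
output = D(images)

# Existing call sites are unaffected: the flag defaults to False, so a bare scalar comes back.
gp = gradient_penalty(images, output)

# Debug path: also recover the gradients in the original (N, C, H, W) shape.
gp, grads = gradient_penalty(images, output, return_structured_grads=True)

per_pixel = grads.norm(2, dim=1)                               # (N, H, W) map, e.g. for heatmaps
per_sample = grads.reshape(grads.shape[0], -1).norm(2, dim=1)  # (N,) norms the penalty drives toward 1
print(gp.item(), per_sample.tolist())

Because the function keeps retain_graph=True, calling it a second time on the same output for debugging is safe.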