Get rid of printing grad names (didn't work very well..)

James Betker 2021-11-01 18:44:05 -06:00
parent 993bd52d42
commit 3b65241b6b


@@ -445,7 +445,6 @@ def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, nor
     Returns:
         Total norm of the parameters (viewed as a single vector).
     """
-    parameter_names = [pn for p, pn in zip(parameters, parameter_names) if p.grad is not None]
     parameters = [p for p in parameters if p.grad is not None]
     max_norm = float(max_norm)
     norm_type = float(norm_type)
@@ -457,14 +456,6 @@ def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, nor
         total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
     else:
         total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
-    if total_norm.isnan() or total_norm.isinf():
-        # Find the invalid param gradients
-        invalid_params = []
-        for name, value in zip(parameter_names, parameters):
-            vnorm = torch.norm(value.grad.detach(), norm_type)
-            if vnorm.isnan() or vnorm.isinf():
-                invalid_params.append(name)
-        print(f'!!Non-finite norm encountered for gradients of these params: {invalid_params} encountered. Norm: {total_norm}, norm_type={norm_type}')
     clip_coef = max_norm / (total_norm + 1e-6)
     if clip_coef < 1:
         for p in parameters:
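
For reference, below is a minimal, self-contained sketch of what the function looks like after this commit. It is modeled on torch.nn.utils.clip_grad_norm_, which this code closely follows; the lines not visible in the hunks above (the empty-parameter guard, the device lookup, the in-place scaling, and the norm_type default) are assumptions rather than something the diff confirms.

import torch

def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, norm_type: float = 2.0):
    """Clip the gradients of `parameters` in place so their total norm is at most `max_norm`.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    # parameter_names is unused after this commit; the hunk headers above suggest
    # the signature still carries it.
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)  # assumed guard; not visible in the diff
    device = parameters[0].grad.device  # assumed; not visible in the diff
    if norm_type == float('inf'):
        norms = [p.grad.detach().abs().max().to(device) for p in parameters]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device)
                                             for p in parameters]), norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.detach().mul_(clip_coef.to(p.grad.device))
    return total_norm

Callers presumably invoke it as before, e.g. clip_grad_norm(list(model.parameters()), names, max_norm=1.0), with the returned total_norm available for logging.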