Get rid of printing grad names (didn't work very well..)
This commit is contained in:
parent
993bd52d42
commit
3b65241b6b
|
@ -445,7 +445,6 @@ def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, nor
|
|||
Returns:
|
||||
Total norm of the parameters (viewed as a single vector).
|
||||
"""
|
||||
parameter_names = [pn for p, pn in zip(parameters, parameter_names) if p.grad is not None]
|
||||
parameters = [p for p in parameters if p.grad is not None]
|
||||
max_norm = float(max_norm)
|
||||
norm_type = float(norm_type)
|
||||
|
@ -457,14 +456,6 @@ def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, nor
|
|||
total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
|
||||
else:
|
||||
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|
||||
if total_norm.isnan() or total_norm.isinf():
|
||||
# Find the invalid param gradients
|
||||
invalid_params = []
|
||||
for name, value in zip(parameter_names, parameters):
|
||||
vnorm = torch.norm(value.grad.detach(), norm_type)
|
||||
if vnorm.isnan() or vnorm.isinf():
|
||||
invalid_params.append(name)
|
||||
print(f'!!Non-finite norm encountered for gradients of these params: {invalid_params} encountered. Norm: {total_norm}, norm_type={norm_type}')
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
if clip_coef < 1:
|
||||
for p in parameters:
|
||||
|
|
Loading…
Reference in New Issue
Block a user