From 3b65241b6bcb0635f5607b065afc67267fb0aac0 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 1 Nov 2021 18:44:05 -0600
Subject: [PATCH] Get rid of printing grad names (didn't work very well..)

---
 codes/utils/util.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/codes/utils/util.py b/codes/utils/util.py
index 024f9c1f..24961390 100644
--- a/codes/utils/util.py
+++ b/codes/utils/util.py
@@ -445,7 +445,6 @@ def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, nor
     Returns:
         Total norm of the parameters (viewed as a single vector).
     """
-    parameter_names = [pn for p, pn in zip(parameters, parameter_names) if p.grad is not None]
     parameters = [p for p in parameters if p.grad is not None]
     max_norm = float(max_norm)
     norm_type = float(norm_type)
@@ -457,14 +456,6 @@ def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, nor
         total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
     else:
         total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
-    if total_norm.isnan() or total_norm.isinf():
-        # Find the invalid param gradients
-        invalid_params = []
-        for name, value in zip(parameter_names, parameters):
-            vnorm = torch.norm(value.grad.detach(), norm_type)
-            if vnorm.isnan() or vnorm.isinf():
-                invalid_params.append(name)
-        print(f'!!Non-finite norm encountered for gradients of these params: {invalid_params} encountered. Norm: {total_norm}, norm_type={norm_type}')
     clip_coef = max_norm / (total_norm + 1e-6)
     if clip_coef < 1:
         for p in parameters:
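
Note: the removed block tried to name the parameters whose gradients went
non-finite before clipping, which required threading a parallel
parameter_names list through clip_grad_norm. A minimal sketch of one
alternative, assuming the caller has access to the torch.nn.Module;
find_nonfinite_grads is a hypothetical helper, not part of this repo. Names
come from named_parameters(), so no extra argument is needed:

import torch

def find_nonfinite_grads(model: torch.nn.Module, norm_type: float = 2.0) -> list:
    """Return names of parameters whose gradient norm is NaN or inf.

    Hypothetical helper, not part of this repo.
    """
    bad = []
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        # Same per-parameter norm the removed diagnostic computed.
        if not torch.isfinite(torch.norm(p.grad.detach(), norm_type)):
            bad.append(name)
    return bad

# Usage: call after backward(), before clipping.
model = torch.nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
invalid = find_nonfinite_grads(model)
if invalid:
    print(f'Non-finite gradient norms for params: {invalid}')

Newer PyTorch versions also accept error_if_nonfinite=True in
torch.nn.utils.clip_grad_norm_, which raises on a non-finite total norm
instead of printing.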