Add custom clip_grad_norm that prints out the param names in error.

James Betker 2021-11-01 11:12:20 -06:00
parent f7d0901ce6
commit 87364b890f
3 changed files with 65 additions and 6 deletions

View File

@@ -284,7 +284,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
-parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_vocoder_clips.yml')
+parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@@ -7,7 +7,7 @@ from trainer.losses import create_loss
import torch
from collections import OrderedDict
from trainer.inject import create_injector
-from utils.util import recursively_detach, opt_get
+from utils.util import recursively_detach, opt_get, clip_grad_norm
logger = logging.getLogger('base')
@@ -111,13 +111,16 @@ class ConfigurableStep(Module):
else:
if self.env['rank'] <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
-params_notweights = [param_map[k] for k in sorted(list(param_names_notweights))]
-params_weights = [param_map[k] for k in sorted(list(all_param_names ^ param_names_notweights))]
+params_names_notweights = sorted(list(param_names_notweights))
+params_notweights = [param_map[k] for k in params_names_notweights]
+params_names_weights = sorted(list(all_param_names ^ param_names_notweights))
+params_weights = [param_map[k] for k in params_names_weights]
if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
opt = torch.optim.Adam(list(optim_params.values()), lr=opt_config['lr'],
weight_decay=opt_config['weight_decay'],
betas=(opt_config['beta1'], opt_config['beta2']))
+opt._group_names = sorted(list(all_param_names))
elif self.step_opt['optimizer'] == 'adamw':
groups = [
{ 'params': params_weights, 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
@@ -126,15 +129,18 @@ class ConfigurableStep(Module):
opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
+opt._group_names = [params_names_weights, params_names_notweights]
elif self.step_opt['optimizer'] == 'lars':
from trainer.optimizers.larc import LARC
from trainer.optimizers.sgd import SGDNoBiasMomentum
optSGD = SGDNoBiasMomentum(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'],
weight_decay=opt_config['weight_decay'])
opt = LARC(optSGD, trust_coefficient=opt_config['lars_coefficient'])
+opt._group_names = sorted(list(all_param_names))
elif self.step_opt['optimizer'] == 'sgd':
from torch.optim import SGD
opt = SGD(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'], weight_decay=opt_config['weight_decay'])
+opt._group_names = sorted(list(all_param_names))
opt._config = opt_config # This is a bit seedy, but we will need these configs later.
opt._config['network'] = net_name
self.optimizers.append(opt)
@@ -288,8 +294,8 @@ class ConfigurableStep(Module):
self.nan_counter = 0
if self.clip_grad_eps is not None:
-for pg in opt.param_groups:
-    grad_norm = torch.nn.utils.clip_grad_norm_(pg['params'], self.clip_grad_eps)
+for pgn, pg in zip(opt._group_names, opt.param_groups):
+    grad_norm = clip_grad_norm(pg['params'], pgn, self.clip_grad_eps)
if torch.isnan(grad_norm):
nan_found = True
self.nan_counter += 1

View File

@@ -15,6 +15,7 @@ from shutil import get_terminal_size
import scp
import paramiko
from torch.utils.checkpoint import checkpoint
+from torch._six import inf
import yaml
try:
@@ -417,3 +418,55 @@ def get_mask_from_lengths(lengths, max_len=None):
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device)
mask = (ids < lengths.unsqueeze(1)).bool()
return mask
+
+
+def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
+    r"""
+    Equivalent to torch.nn.utils.clip_grad_norm_() but with the following changes:
+    - Takes a list of parameter names, parallel to ``parameters``, so offending params can be reported by name.
+    - When a NaN or inf norm is encountered, the names of the offending parameters are printed.
+    - error_if_nonfinite removed.
+
+    Clips the gradient norm of an iterable of parameters. The norm is computed
+    over all gradients together, as if they were concatenated into a single
+    vector. Gradients are modified in-place.
+
+    Args:
+        parameters (list[Tensor]): the parameters whose gradients will be normalized
+        parameter_names (list[str]): names corresponding to ``parameters``, used for error reporting
+        max_norm (float or int): max norm of the gradients
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+    Returns:
+        Total norm of the parameters (viewed as a single vector).
+    """
+    # Only consider parameters that actually have gradients, keeping names aligned.
+    parameter_names = [pn for p, pn in zip(parameters, parameter_names) if p.grad is not None]
+    parameters = [p for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return torch.tensor(0.)
+    device = parameters[0].grad.device
+    if norm_type == inf:
+        norms = [p.grad.detach().abs().max().to(device) for p in parameters]
+        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+    else:
+        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+    if total_norm.isnan() or total_norm.isinf():
+        # Find the params with invalid gradients and report them by name.
+        invalid_params = []
+        for name, value in zip(parameter_names, parameters):
+            vnorm = torch.norm(value.grad.detach(), norm_type)
+            if vnorm.isnan() or vnorm.isinf():
+                invalid_params.append(name)
+        print(f'!!Non-finite norm encountered for gradients of these params: {invalid_params}. Norm: {total_norm}, norm_type={norm_type}')
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        for p in parameters:
+            p.grad.detach().mul_(clip_coef.to(p.grad.device))
+    return total_norm
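
For orientation, here is a minimal sketch (not part of the commit) of how the two pieces fit together outside of ConfigurableStep: the optimizer carries a parallel list of parameter names per param group, and the new clip_grad_norm is called once per group so that a non-finite norm reports which parameters are responsible. The toy model, optimizer settings, and loss below are placeholder assumptions; the import path mirrors the one used in the step code above and assumes the repo's source tree is on the Python path.

import torch
from torch import nn
from utils.util import clip_grad_norm  # the utility added by this commit

# Placeholder model and optimizer purely for illustration.
model = nn.Linear(16, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Mirror the step code: remember the parameter names for each param group.
opt._group_names = [[name for name, _ in model.named_parameters()]]

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

for group_names, pg in zip(opt._group_names, opt.param_groups):
    # Clips to max_norm and returns the total norm; if the norm is NaN/inf,
    # the names of the offending parameters are printed.
    grad_norm = clip_grad_norm(pg['params'], group_names, max_norm=1.0)
opt.step()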