From 87364b890f8165998e9a93408fd38c7d8b0734ee Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 1 Nov 2021 11:12:20 -0600
Subject: [PATCH] Add custom clip_grad_norm that prints out the param names in
 error.

---
 codes/train.py         |  2 +-
 codes/trainer/steps.py | 16 +++++++++----
 codes/utils/util.py    | 53 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 6 deletions(-)

diff --git a/codes/train.py b/codes/train.py
index d3f16ffe..5efd01ec 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -284,7 +284,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_vocoder_clips.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 2052f7f8..35214cd5 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -7,7 +7,7 @@ from trainer.losses import create_loss
 import torch
 from collections import OrderedDict
 from trainer.inject import create_injector
-from utils.util import recursively_detach, opt_get
+from utils.util import recursively_detach, opt_get, clip_grad_norm
 
 logger = logging.getLogger('base')
 
@@ -111,13 +111,16 @@ class ConfigurableStep(Module):
                 else:
                     if self.env['rank'] <= 0:
                         logger.warning('Params [{:s}] will not optimize.'.format(k))
-            params_notweights = [param_map[k] for k in sorted(list(param_names_notweights))]
-            params_weights = [param_map[k] for k in sorted(list(all_param_names ^ param_names_notweights))]
+            params_names_notweights = sorted(list(param_names_notweights))
+            params_notweights = [param_map[k] for k in params_names_notweights]
+            params_names_weights = sorted(list(all_param_names ^ param_names_notweights))
+            params_weights = [param_map[k] for k in params_names_weights]
 
             if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
                 opt = torch.optim.Adam(list(optim_params.values()), lr=opt_config['lr'],
                                        weight_decay=opt_config['weight_decay'],
                                        betas=(opt_config['beta1'], opt_config['beta2']))
+                opt._group_names = sorted(list(all_param_names))
             elif self.step_opt['optimizer'] == 'adamw':
                 groups = [
                     { 'params': params_weights, 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
@@ -126,15 +129,18 @@ class ConfigurableStep(Module):
                 opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
                                         weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                         betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
+                opt._group_names = [params_names_weights, params_names_notweights]
             elif self.step_opt['optimizer'] == 'lars':
                 from trainer.optimizers.larc import LARC
                 from trainer.optimizers.sgd import SGDNoBiasMomentum
                 optSGD = SGDNoBiasMomentum(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'],
                                            weight_decay=opt_config['weight_decay'])
                 opt = LARC(optSGD, trust_coefficient=opt_config['lars_coefficient'])
+                opt._group_names = sorted(list(all_param_names))
             elif self.step_opt['optimizer'] == 'sgd':
                 from torch.optim import SGD
                 opt = SGD(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'], weight_decay=opt_config['weight_decay'])
+                opt._group_names = sorted(list(all_param_names))
             opt._config = opt_config  # This is a bit seedy, but we will need these configs later.
             opt._config['network'] = net_name
             self.optimizers.append(opt)
@@ -288,8 +294,8 @@ class ConfigurableStep(Module):
                         self.nan_counter = 0
 
             if self.clip_grad_eps is not None:
-                for pg in opt.param_groups:
-                    grad_norm = torch.nn.utils.clip_grad_norm_(pg['params'], self.clip_grad_eps)
+                for pgn, pg in zip(opt._group_names, opt.param_groups):
+                    grad_norm = clip_grad_norm(pg['params'], pgn, self.clip_grad_eps)
                     if torch.isnan(grad_norm):
                         nan_found = True
                         self.nan_counter += 1
diff --git a/codes/utils/util.py b/codes/utils/util.py
index bff74c10..024f9c1f 100644
--- a/codes/utils/util.py
+++ b/codes/utils/util.py
@@ -15,6 +15,7 @@ from shutil import get_terminal_size
 import scp
 import paramiko
 from torch.utils.checkpoint import checkpoint
+from torch._six import inf
 import yaml
 
 try:
@@ -417,3 +418,55 @@ def get_mask_from_lengths(lengths, max_len=None):
     ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device)
     mask = (ids < lengths.unsqueeze(1)).bool()
     return mask
+
+
+def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
+    r"""
+    Equivalent to torch.nn.utils.clip_grad_norm_() but with the following changes:
+    - Takes a parallel list of parameter names alongside the parameters.
+    - When NaN or inf norms are encountered, the offending parameter names are printed.
+    - The error_if_nonfinite argument is removed.
+
+    Clips gradient norm of an iterable of parameters.
+
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector. Gradients are modified in-place.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        parameter_names (list): names corresponding to each entry of
+            ``parameters``; used to report which parameters produced a
+            non-finite gradient norm.
+        max_norm (float or int): max norm of the gradients
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+
+    Returns:
+        Total norm of the parameters (viewed as a single vector).
+    """
+    parameter_names = [pn for p, pn in zip(parameters, parameter_names) if p.grad is not None]
+    parameters = [p for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return torch.tensor(0.)
+    device = parameters[0].grad.device
+    if norm_type == inf:
+        norms = [p.grad.detach().abs().max().to(device) for p in parameters]
+        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+    else:
+        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+    if total_norm.isnan() or total_norm.isinf():
+        # Find the parameters whose gradient norms are non-finite.
+        invalid_params = []
+        for name, value in zip(parameter_names, parameters):
+            vnorm = torch.norm(value.grad.detach(), norm_type)
+            if vnorm.isnan() or vnorm.isinf():
+                invalid_params.append(name)
+        print(f'!! Non-finite norm encountered for gradients of these params: {invalid_params}. Norm: {total_norm}, norm_type={norm_type}')
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        for p in parameters:
+            p.grad.detach().mul_(clip_coef.to(p.grad.device))
+    return total_norm
\ No newline at end of file
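
For orientation, below is a minimal, self-contained sketch of the pattern this patch introduces: per-group parameter-name lists are stashed on the optimizer as _group_names (following the adamw branch, where it is a list of per-group name lists), and the new clip_grad_norm() helper reports which parameters produced a non-finite gradient norm. The toy model, the deliberately injected NaN, and the plain "from utils.util import" path are illustrative assumptions, not part of the patch.

import torch
from utils.util import clip_grad_norm  # helper added by this patch; assumes codes/ is on sys.path

# Toy two-layer model standing in for a real network (illustrative only).
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
named = dict(model.named_parameters())

# Split params into weight-decayed and non-decayed groups, mirroring the adamw branch.
names_notweights = sorted(n for n, p in named.items() if p.ndim <= 1)
names_weights = sorted(n for n in named if n not in names_notweights)
opt = torch.optim.AdamW([
    {'params': [named[n] for n in names_weights], 'weight_decay': 1e-2},
    {'params': [named[n] for n in names_notweights], 'weight_decay': 0},
], lr=1e-3)
opt._group_names = [names_weights, names_notweights]

model(torch.randn(4, 8)).sum().backward()
named['0.weight'].grad[0, 0] = float('nan')  # force the non-finite reporting path

# Same loop shape as the updated do_step() clipping code in steps.py.
for pgn, pg in zip(opt._group_names, opt.param_groups):
    grad_norm = clip_grad_norm(pg['params'], pgn, 1.0)
    print(grad_norm)  # the weights group reports '0.weight' via the new print before returning nan

Because names and tensors are passed as parallel lists built from the same sorted name list, the helper can name the exact parameter whose gradient went non-finite, which is the whole point of replacing torch.nn.utils.clip_grad_norm_ here.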