Add custom clip_grad_norm that prints out the param names in error.
This commit is contained in:
parent
f7d0901ce6
commit
87364b890f
|
@ -284,7 +284,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_vocoder_clips.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -7,7 +7,7 @@ from trainer.losses import create_loss
|
|||
import torch
|
||||
from collections import OrderedDict
|
||||
from trainer.inject import create_injector
|
||||
from utils.util import recursively_detach, opt_get
|
||||
from utils.util import recursively_detach, opt_get, clip_grad_norm
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
@ -111,13 +111,16 @@ class ConfigurableStep(Module):
|
|||
else:
|
||||
if self.env['rank'] <= 0:
|
||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||
params_notweights = [param_map[k] for k in sorted(list(param_names_notweights))]
|
||||
params_weights = [param_map[k] for k in sorted(list(all_param_names ^ param_names_notweights))]
|
||||
params_names_notweights = sorted(list(param_names_notweights))
|
||||
params_notweights = [param_map[k] for k in params_names_notweights]
|
||||
params_names_weights = sorted(list(all_param_names ^ param_names_notweights))
|
||||
params_weights = [param_map[k] for k in params_names_weights]
|
||||
|
||||
if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
|
||||
opt = torch.optim.Adam(list(optim_params.values()), lr=opt_config['lr'],
|
||||
weight_decay=opt_config['weight_decay'],
|
||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||
opt._group_names = sorted(list(all_param_names))
|
||||
elif self.step_opt['optimizer'] == 'adamw':
|
||||
groups = [
|
||||
{ 'params': params_weights, 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
|
||||
|
@ -126,15 +129,18 @@ class ConfigurableStep(Module):
|
|||
opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
|
||||
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
|
||||
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
|
||||
opt._group_names = [params_names_weights, params_names_notweights]
|
||||
elif self.step_opt['optimizer'] == 'lars':
|
||||
from trainer.optimizers.larc import LARC
|
||||
from trainer.optimizers.sgd import SGDNoBiasMomentum
|
||||
optSGD = SGDNoBiasMomentum(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'],
|
||||
weight_decay=opt_config['weight_decay'])
|
||||
opt = LARC(optSGD, trust_coefficient=opt_config['lars_coefficient'])
|
||||
opt._group_names = sorted(list(all_param_names))
|
||||
elif self.step_opt['optimizer'] == 'sgd':
|
||||
from torch.optim import SGD
|
||||
opt = SGD(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'], weight_decay=opt_config['weight_decay'])
|
||||
opt._group_names = sorted(list(all_param_names))
|
||||
opt._config = opt_config # This is a bit seedy, but we will need these configs later.
|
||||
opt._config['network'] = net_name
|
||||
self.optimizers.append(opt)
|
||||
|
@ -288,8 +294,8 @@ class ConfigurableStep(Module):
|
|||
self.nan_counter = 0
|
||||
|
||||
if self.clip_grad_eps is not None:
|
||||
for pg in opt.param_groups:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(pg['params'], self.clip_grad_eps)
|
||||
for pgn, pg in zip(opt._group_names, opt.param_groups):
|
||||
grad_norm = clip_grad_norm(pg['params'], pgn, self.clip_grad_eps)
|
||||
if torch.isnan(grad_norm):
|
||||
nan_found = True
|
||||
self.nan_counter += 1
|
||||
|
|
|
@ -15,6 +15,7 @@ from shutil import get_terminal_size
|
|||
import scp
|
||||
import paramiko
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from torch._six import inf
|
||||
|
||||
import yaml
|
||||
try:
|
||||
|
@ -417,3 +418,55 @@ def get_mask_from_lengths(lengths, max_len=None):
|
|||
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device)
|
||||
mask = (ids < lengths.unsqueeze(1)).bool()
|
||||
return mask
|
||||
|
||||
|
||||
def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
|
||||
r"""
|
||||
Equivalent to torch.nn.utils.clip_grad_norm_() but with the following changes:
|
||||
- Takes in a dictionary of parameters (from get_named_parameters()) instead of a list of parameters.
|
||||
- When NaN or inf norms are encountered, the parameter name is printed.
|
||||
- error_if_nonfinite removed.
|
||||
|
||||
Clips gradient norm of an iterable of parameters.
|
||||
|
||||
The norm is computed over all gradients together, as if they were
|
||||
concatenated into a single vector. Gradients are modified in-place.
|
||||
|
||||
Args:
|
||||
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
|
||||
single Tensor that will have gradients normalized
|
||||
max_norm (float or int): max norm of the gradients
|
||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
error_if_nonfinite (bool): if True, an error is thrown if the total
|
||||
norm of the gradients from :attr:``parameters`` is ``nan``,
|
||||
``inf``, or ``-inf``. Default: False (will switch to True in the future)
|
||||
|
||||
Returns:
|
||||
Total norm of the parameters (viewed as a single vector).
|
||||
"""
|
||||
parameter_names = [pn for p, pn in zip(parameters, parameter_names) if p.grad is not None]
|
||||
parameters = [p for p in parameters if p.grad is not None]
|
||||
max_norm = float(max_norm)
|
||||
norm_type = float(norm_type)
|
||||
if len(parameters) == 0:
|
||||
return torch.tensor(0.)
|
||||
device = parameters[0].grad.device
|
||||
if norm_type == inf:
|
||||
norms = [p.grad.detach().abs().max().to(device) for p in parameters]
|
||||
total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
|
||||
else:
|
||||
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|
||||
if total_norm.isnan() or total_norm.isinf():
|
||||
# Find the invalid param gradients
|
||||
invalid_params = []
|
||||
for name, value in zip(parameter_names, parameters):
|
||||
vnorm = torch.norm(value.grad.detach(), norm_type)
|
||||
if vnorm.isnan() or vnorm.isinf():
|
||||
invalid_params.append(name)
|
||||
print(f'!!Non-finite norm encountered for gradients of these params: {invalid_params} encountered. Norm: {total_norm}, norm_type={norm_type}')
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
if clip_coef < 1:
|
||||
for p in parameters:
|
||||
p.grad.detach().mul_(clip_coef.to(p.grad.device))
|
||||
return total_norm
|
Loading…
Reference in New Issue
Block a user