forked from mrq/DL-Art-School

Add custom clip_grad_norm that prints out the param names in error.

parent: f7d0901ce6
commit: 87364b890f
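In short, each optimizer now carries a `_group_names` attribute listing the parameter names behind each of its param groups, and gradient clipping goes through the new `clip_grad_norm` so a non-finite norm can be attributed to specific parameters. A minimal sketch of the intended call pattern (the model and names here are illustrative; `clip_grad_norm` is the helper added to utils/util.py by this commit):

import torch
from utils.util import clip_grad_norm  # added by this commit

# Illustrative model; in the trainer this wiring happens inside ConfigurableStep.
model = torch.nn.Linear(8, 8)
names, params = zip(*model.named_parameters())

opt = torch.optim.Adam(params, lr=1e-4)
# One name list per param group, mirroring how the adamw path stores them.
opt._group_names = [list(names)]

loss = model(torch.randn(2, 8)).sum()
loss.backward()

for pgn, pg in zip(opt._group_names, opt.param_groups):
    # Clips like torch.nn.utils.clip_grad_norm_, but prints the names of any
    # params whose gradients come back NaN/inf.
    grad_norm = clip_grad_norm(pg['params'], pgn, max_norm=1.0)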
@@ -284,7 +284,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_vocoder_clips.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -7,7 +7,7 @@ from trainer.losses import create_loss
 import torch
 from collections import OrderedDict
 from trainer.inject import create_injector
-from utils.util import recursively_detach, opt_get
+from utils.util import recursively_detach, opt_get, clip_grad_norm

 logger = logging.getLogger('base')

@@ -111,13 +111,16 @@ class ConfigurableStep(Module):
                 else:
                     if self.env['rank'] <= 0:
                         logger.warning('Params [{:s}] will not optimize.'.format(k))
-            params_notweights = [param_map[k] for k in sorted(list(param_names_notweights))]
-            params_weights = [param_map[k] for k in sorted(list(all_param_names ^ param_names_notweights))]
+            params_names_notweights = sorted(list(param_names_notweights))
+            params_notweights = [param_map[k] for k in params_names_notweights]
+            params_names_weights = sorted(list(all_param_names ^ param_names_notweights))
+            params_weights = [param_map[k] for k in params_names_weights]

             if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
                 opt = torch.optim.Adam(list(optim_params.values()), lr=opt_config['lr'],
                                        weight_decay=opt_config['weight_decay'],
                                        betas=(opt_config['beta1'], opt_config['beta2']))
+                opt._group_names = sorted(list(all_param_names))
             elif self.step_opt['optimizer'] == 'adamw':
                 groups = [
                     { 'params': params_weights, 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
@@ -126,15 +129,18 @@ class ConfigurableStep(Module):
                 opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
                                         weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                         betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
+                opt._group_names = [params_names_weights, params_names_notweights]
             elif self.step_opt['optimizer'] == 'lars':
                 from trainer.optimizers.larc import LARC
                 from trainer.optimizers.sgd import SGDNoBiasMomentum
                 optSGD = SGDNoBiasMomentum(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'],
                                            weight_decay=opt_config['weight_decay'])
                 opt = LARC(optSGD, trust_coefficient=opt_config['lars_coefficient'])
+                opt._group_names = sorted(list(all_param_names))
             elif self.step_opt['optimizer'] == 'sgd':
                 from torch.optim import SGD
                 opt = SGD(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'], weight_decay=opt_config['weight_decay'])
+                opt._group_names = sorted(list(all_param_names))
             opt._config = opt_config  # This is a bit seedy, but we will need these configs later.
             opt._config['network'] = net_name
             self.optimizers.append(opt)
@@ -288,8 +294,8 @@ class ConfigurableStep(Module):
                         self.nan_counter = 0

             if self.clip_grad_eps is not None:
-                for pg in opt.param_groups:
-                    grad_norm = torch.nn.utils.clip_grad_norm_(pg['params'], self.clip_grad_eps)
+                for pgn, pg in zip(opt._group_names, opt.param_groups):
+                    grad_norm = clip_grad_norm(pg['params'], pgn, self.clip_grad_eps)
                     if torch.isnan(grad_norm):
                         nan_found = True
                         self.nan_counter += 1
@@ -15,6 +15,7 @@ from shutil import get_terminal_size
 import scp
 import paramiko
 from torch.utils.checkpoint import checkpoint
+from torch._six import inf

 import yaml
 try:
@@ -417,3 +418,55 @@ def get_mask_from_lengths(lengths, max_len=None):
     ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device)
     mask = (ids < lengths.unsqueeze(1)).bool()
     return mask
+
+
+def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
+    r"""
+    Equivalent to torch.nn.utils.clip_grad_norm_() but with the following changes:
+    - Takes a list of parameters plus a matching list of parameter names.
+    - When NaN or inf norms are encountered, the offending parameter names are printed.
+    - error_if_nonfinite removed.
+
+    Clips gradient norm of an iterable of parameters.
+
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector. Gradients are modified in-place.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        parameter_names (Iterable[str]): names corresponding to each entry in
+            parameters; used to report which gradients were non-finite when
+            that occurs
+        max_norm (float or int): max norm of the gradients
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+
+    Returns:
+        Total norm of the parameters (viewed as a single vector).
+    """
+    parameter_names = [pn for p, pn in zip(parameters, parameter_names) if p.grad is not None]
+    parameters = [p for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return torch.tensor(0.)
+    device = parameters[0].grad.device
+    if norm_type == inf:
+        norms = [p.grad.detach().abs().max().to(device) for p in parameters]
+        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+    else:
+        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+    if total_norm.isnan() or total_norm.isinf():
+        # Find the invalid param gradients and report them by name.
+        invalid_params = []
+        for name, value in zip(parameter_names, parameters):
+            vnorm = torch.norm(value.grad.detach(), norm_type)
+            if vnorm.isnan() or vnorm.isinf():
+                invalid_params.append(name)
+        print(f'!!Non-finite norm encountered for gradients of these params: {invalid_params}. Norm: {total_norm}, norm_type={norm_type}')
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        for p in parameters:
+            p.grad.detach().mul_(clip_coef.to(p.grad.device))
+    return total_norm
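A quick way to sanity-check the new helper (illustrative; run somewhere utils.util is importable): poison one gradient with NaN and confirm that only that parameter is named in the printed warning.

import torch
from utils.util import clip_grad_norm  # the helper added above

# Two dummy parameters with gradients; poison one to trigger the report.
w = torch.nn.Parameter(torch.randn(4, 4))
b = torch.nn.Parameter(torch.randn(4))
(w.sum() + b.sum()).backward()
b.grad.fill_(float('nan'))

total = clip_grad_norm([w, b], ['layer.weight', 'layer.bias'], max_norm=1.0)
# Expect a "!!Non-finite norm encountered..." line naming only 'layer.bias';
# the returned total norm is NaN and no clipping is applied (NaN < 1 is False).
print(total)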