import math
import warnings

import torch
import torch.distributed as dist
from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm


def _global_world_size() -> int:
    """Return the number of workers, or 1 when not running distributed."""
    # Public torch.distributed has no ``get_global_world_size``; honour it if
    # the runtime provides one, otherwise fall back to the standard API.
    getter = getattr(dist, "get_global_world_size", None)
    if getter is not None:
        return getter()
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


@torch.no_grad()
def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
    """Compute the total gradient L2 norm and clip gradients in place.

    Parameters are bucketed by marker attribute:

    * ``expert``      -- norm is divided by
      ``sqrt(max(world_size, moe_expert_count))`` and combined across workers.
    * ``base_expert`` -- not included in the norm; only scaled when clipping.
    * ``_is_sharded`` -- norm is combined across workers.
    * anything else   -- norm is computed locally.

    Args:
        params: a tensor or an iterable of tensors; entries that are ``None``
            or have no ``.grad`` are ignored.
        max_norm: clipping threshold; clipping is skipped unless it is > 0.
        moe_expert_count: expert count used in the expert-norm divisor.
        aggregate_norm_fn: optional callable applied to the total norm before
            the clipping coefficient is computed.

    Returns:
        The total norm (after ``aggregate_norm_fn``, if one was given).
    """
    if isinstance(params, torch.Tensor):
        params = [params]
    params = [
        p for p in params if p is not None and getattr(p, "grad", None) is not None
    ]

    denom = math.sqrt(max(_global_world_size(), moe_expert_count))

    # ``detach()`` returns views that share storage with ``p.grad``, so the
    # in-place ``mul_`` at the end really modifies the parameters' gradients.
    grads, expert_grads, base_expert_grads, sharded_grads = [], [], [], []
    for p in params:
        grad = p.grad.detach()
        if hasattr(p, "expert"):
            expert_grads.append(grad)
        elif hasattr(p, "base_expert"):
            base_expert_grads.append(grad)
        elif hasattr(p, "_is_sharded"):
            sharded_grads.append(grad)
        else:
            grads.append(grad)

    if not grads:
        total_norm = params[0].new_tensor(0.0) if params else torch.tensor(0.0)
    elif len(grads) == 1:
        total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
    elif multi_tensor_l2norm_available:
        total_norm = multi_tensor_total_norm(grads)
    else:
        if torch.cuda.is_available():
            warnings.warn(
                "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
                "you may get better performance by installing NVIDIA's apex library"
            )
            device = torch.cuda.current_device()
        elif grads[0].device.type == "xla":
            device = grads[0].device
        else:
            device = torch.device("cpu")
        total_norm = torch.norm(
            torch.stack(
                [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]
            )
        )

    # Calculate each split norm and all_reduce it with the other workers.
    distributed = dist.is_available() and dist.is_initialized()
    norms = [total_norm]
    for split_grads, divisor in ((expert_grads, denom), (sharded_grads, None)):
        if not split_grads:
            continue
        split_norm = torch.norm(
            torch.stack(
                [torch.norm(g, p=2, dtype=torch.float32) for g in split_grads]
            )
        )
        if divisor is not None:
            # Scaling the norm equals taking the norm of the scaled gradients
            # (||g / d|| == ||g|| / d) but leaves the gradients themselves
            # untouched and avoids copying them.
            split_norm = split_norm / divisor
        if distributed:
            # Sum squared norms across workers, then take the root.
            split_norm.pow_(2)
            dist.all_reduce(split_norm)
            split_norm.sqrt_()
        norms.append(split_norm)
    if len(norms) > 1:
        total_norm = torch.norm(torch.stack(norms))

    if aggregate_norm_fn is not None:
        total_norm = aggregate_norm_fn(total_norm)

    if max_norm > 0:
        max_norm = float(max_norm)
        clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
        for g in grads + expert_grads + sharded_grads + base_expert_grads:
            g.mul_(clip_coef)
    return total_norm