86 lines
3.0 KiB
Python
86 lines
3.0 KiB
Python
# Copyright (c) 2022 Microsoft
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
|
|
import math
|
|
import warnings
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
|
|
|
|
|
|
def _dist_initialized() -> bool:
    """Return True when torch.distributed is built in and a process group exists."""
    return dist.is_available() and dist.is_initialized()


def _dense_grads_norm(grads, params) -> torch.Tensor:
    """Return the float32 L2 norm of the local (non-expert, non-sharded) grads.

    Args:
        grads: list of gradient tensors to combine (may be empty).
        params: the filtered parameter list; used only to create the zero
            result when ``grads`` is empty.

    Returns:
        A scalar tensor holding the L2 norm over all of ``grads``.
    """
    if not grads:
        # new_tensor places the zero on the first parameter's device/dtype;
        # fall back to a default CPU tensor when there are no parameters.
        return params[0].new_tensor(0.0) if params else torch.tensor(0.0)
    if len(grads) == 1:
        return torch.norm(grads[0], p=2, dtype=torch.float32)
    if multi_tensor_l2norm_available:
        return multi_tensor_total_norm(grads)

    # Fallback without the fused kernel: norm each grad separately, gather the
    # per-tensor norms on one device, and take the norm of those norms.
    if torch.cuda.is_available():
        warnings.warn(
            "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
            "you may get better performance by installing NVIDIA's apex library"
        )
        device = torch.cuda.current_device()
    elif grads[0].device.type == "xla":
        device = grads[0].device
    else:
        device = torch.device("cpu")
    return torch.norm(
        torch.stack(
            [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]
        )
    )


@torch.no_grad()
def clip_grad_norm_(
    params, max_norm, moe_expert_count, aggregate_norm_fn=None
) -> torch.Tensor:
    """Clip gradients in place by their combined L2 norm (MoE/shard aware).

    Parameters are bucketed by marker attributes:

    * ``expert``: the bucket's L2 norm is divided by
      ``sqrt(max(world_size, moe_expert_count))`` and then all-reduced
      (as a sum of squares) when a process group is initialized.
    * ``base_expert``: clipped, but NOT included in the norm.
    * ``_is_sharded``: the bucket's L2 norm is all-reduced (sum of squares)
      when a process group is initialized.
    * anything else: dense; its norm is computed locally and is not
      all-reduced here (``aggregate_norm_fn`` may do so).

    Args:
        params: a tensor or an iterable of tensors. Entries that are ``None``
            or have no ``.grad`` are ignored.
        max_norm: clipping threshold. Values ``<= 0`` disable clipping, but
            the norm is still computed and returned.
        moe_expert_count: number of experts; used in the expert norm scaling.
        aggregate_norm_fn: optional callable applied to the combined norm
            before it is used for clipping and returned.

    Returns:
        The combined gradient norm (computed before clipping) as a tensor.
    """

    def grad_exists(p):
        return p is not None and getattr(p, "grad", None) is not None

    if isinstance(params, torch.Tensor):
        params = [params]
    params = [p for p in params if grad_exists(p)]

    world_size = dist.get_world_size() if _dist_initialized() else 1
    expert_denom = math.sqrt(max(world_size, moe_expert_count))

    # Every list holds detach()-ed views of p.grad, so the in-place mul_ at
    # the end really clips the parameters' gradients. The expert scaling is
    # applied to the expert norm below, not to the tensors themselves;
    # dividing the tensors would create copies that mul_ can't reach.
    grads, expert_grads, base_expert_grads, sharded_grads = [], [], [], []
    for p in params:
        if hasattr(p, "expert"):
            expert_grads.append(p.grad.detach())
        elif hasattr(p, "base_expert"):
            base_expert_grads.append(p.grad.detach())
        elif hasattr(p, "_is_sharded"):
            sharded_grads.append(p.grad.detach())
        else:
            grads.append(p.grad.detach())

    total_norm = _dense_grads_norm(grads, params)

    # Compute each split bucket's norm and all-reduce it with other workers.
    # Squaring before all_reduce and taking sqrt after turns the default
    # (sum) reduction into a global L2 norm.
    norms = [total_norm]
    for split_grads, scale in ((expert_grads, expert_denom), (sharded_grads, 1.0)):
        if not split_grads:
            continue
        split_norm = (
            torch.norm(
                torch.stack(
                    [torch.norm(g, p=2, dtype=torch.float32) for g in split_grads]
                )
            )
            / scale
        )
        if _dist_initialized():
            split_norm.pow_(2)
            dist.all_reduce(split_norm)
            split_norm.sqrt_()
        norms.append(split_norm.to(total_norm.device))
    if len(norms) > 1:
        total_norm = torch.norm(torch.stack(norms))

    if aggregate_norm_fn is not None:
        total_norm = aggregate_norm_fn(total_norm)

    if max_norm > 0:
        max_norm = float(max_norm)
        # 1e-6 guards against division by zero when all gradients are zero.
        clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
        for g in grads + expert_grads + sharded_grads + base_expert_grads:
            g.mul_(clip_coef)
    return total_norm
|