# torchscale/examples/fairseq/utils/sparse_clip.py
#
# 86 lines
# 3.0 KiB
# Python

# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import math
import warnings
import torch
import torch.distributed as dist
from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
@torch.no_grad()
def clip_grad_norm_(
    params, max_norm, moe_expert_count, aggregate_norm_fn=None
) -> torch.Tensor:
    """Clip gradients in place by their combined L2 norm, with MoE buckets.

    Parameters that carry a gradient are sorted into four buckets according
    to marker attributes on the parameter object:

    * ``expert``      -- norm is divided by
      ``sqrt(max(world_size, moe_expert_count))`` and, when a process group
      is initialized, sum-of-squares all-reduced across workers.
    * ``base_expert`` -- excluded from the norm, but still scaled by the
      clip coefficient.
    * ``_is_sharded`` -- norm is sum-of-squares all-reduced across workers
      when a process group is initialized.
    * everything else -- ordinary local L2 norm.

    Args:
        params: A tensor or an iterable of tensors/parameters. Entries that
            are ``None`` or have no ``.grad`` are ignored.
        max_norm: Clipping threshold. Gradients are only rescaled when this
            is greater than zero.
        moe_expert_count: Expert count used (together with the world size)
            in the scaling denominator for the expert-gradient norm.
        aggregate_norm_fn: Optional callable applied to the total norm
            before the clip coefficient is computed.

    Returns:
        The total gradient norm computed before clipping.
    """
    if isinstance(params, torch.Tensor):
        params = [params]
    params = [
        p for p in params if p is not None and getattr(p, "grad", None) is not None
    ]

    # torch.distributed has no get_global_world_size(); use the default
    # group's world size when one is initialized and fall back to 1 otherwise.
    distributed = dist.is_available() and dist.is_initialized()
    world_size = dist.get_world_size() if distributed else 1
    denom = math.sqrt(max(world_size, moe_expert_count))

    grads, expert_grads, base_expert_grads, sharded_grads = [], [], [], []
    for p in params:
        grad = p.grad.detach()
        if hasattr(p, "expert"):
            # Keep the real gradient view (not a divided copy) so the
            # in-place clipping below actually reaches p.grad; the 1/denom
            # factor is applied to the bucket norm instead.
            expert_grads.append(grad)
        elif hasattr(p, "base_expert"):
            base_expert_grads.append(grad)
        elif hasattr(p, "_is_sharded"):
            sharded_grads.append(grad)
        else:
            grads.append(grad)

    # Norm of the ordinary (non-expert, non-sharded) gradients.
    if not grads:
        total_norm = params[0].new_tensor(0.0) if params else torch.tensor(0.0)
    elif len(grads) == 1:
        total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
    elif multi_tensor_l2norm_available:
        total_norm = multi_tensor_total_norm(grads)
    else:
        if torch.cuda.is_available():
            warnings.warn(
                "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
                "you may get better performance by installing NVIDIA's apex library"
            )
            device = torch.cuda.current_device()
        elif grads[0].device.type == "xla":
            device = grads[0].device
        else:
            device = torch.device("cpu")
        total_norm = torch.norm(
            torch.stack(
                [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]
            )
        )

    # calculate split_norm and all_reduce with other workers
    norms = [total_norm]
    for split_grads, scale in ((expert_grads, denom), (sharded_grads, 1.0)):
        if not split_grads:
            continue
        split_norm = torch.norm(
            torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads])
        )
        if scale != 1.0:
            # ||g / scale|| == ||g|| / scale, so scaling the norm matches
            # scaling every gradient element without materializing a copy.
            split_norm.div_(scale)
        if distributed:
            # Sum squared norms across workers, then take the root.
            split_norm.pow_(2)
            dist.all_reduce(split_norm)
            split_norm.sqrt_()
        norms.append(split_norm)
    if len(norms) > 1:
        total_norm = torch.norm(torch.stack(norms))

    if aggregate_norm_fn is not None:
        total_norm = aggregate_norm_fn(total_norm)

    if max_norm > 0:
        max_norm = float(max_norm)
        # Never scale gradients up: the coefficient is capped at 1.
        clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
        for g in grads + expert_grads + sharded_grads + base_expert_grads:
            g.mul_(clip_coef)
    return total_norm