diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 13ac185e..f399beea 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -1,4 +1,5 @@
 from torch.cuda.amp import GradScaler
+from torch.distributed.optim import ZeroRedundancyOptimizer
 
 from utils.loss_accumulator import LossAccumulator
 from torch.nn import Module
@@ -130,6 +131,21 @@ class ConfigurableStep(Module):
                                         weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                         betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
                 opt._group_names = [params_names_weights, params_names_notweights]
+            elif self.step_opt['optimizer'] == 'adamw_zero':
+                # The torch ZeRO implementation does not seem to support parameter groups, so do not shard the non-weighted
+                # parameters and just use a normal AdamW implementation. In a large network, these weights will normally
+                # be a tiny fraction of the total weights.
+                opt_unweighted = torch.optim.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
+                                                   betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
+                opt_unweighted._config = opt_config
+                opt_unweighted._config['network'] = net_name
+                self.optimizers.append(opt_unweighted)
+                # Not setting these means abnormal gradient detection below no longer works.
+                opt_unweighted._group_names = []
+                opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=torch.optim.AdamW, lr=opt_config['lr'],
+                                              weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
+                                              betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
+                opt._group_names = []
             elif self.step_opt['optimizer'] == 'lars':
                 from trainer.optimizers.larc import LARC
                 from trainer.optimizers.sgd import SGDNoBiasMomentum
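
For context, here is a minimal standalone sketch of the pattern the diff introduces: non-weight parameters (biases, norm scales) go into a plain, unsharded AdamW with weight decay disabled, while the weight parameters are handed to torch's ZeroRedundancyOptimizer so each rank stores only its shard of the optimizer state. Everything except the ZeroRedundancyOptimizer API itself is an assumption for illustration: the build_optimizers helper, the parameter-split heuristic, and the default hyperparameters are not the repo's params_weights/params_notweights logic, and the code presumes torch.distributed has already been initialized (e.g. via torchrun).

# Minimal standalone sketch (not from the repo) of the split-optimizer pattern above.
# Assumes torch.distributed is already initialized (e.g. launched with torchrun);
# build_optimizers, the parameter-split heuristic, and the defaults are illustrative.
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer


def build_optimizers(model, lr=1e-4, weight_decay=1e-2, betas=(.9, .999)):
    # Split parameters roughly the way the diff implies: multi-dimensional "weight"
    # tensors get weight decay and are sharded; biases/norm scales stay unsharded.
    weights, not_weights = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (not_weights if p.ndim <= 1 or name.endswith('.bias') else weights).append(p)

    # Plain AdamW for the (small) non-weight set, with weight decay disabled.
    opt_unweighted = torch.optim.AdamW(not_weights, lr=lr, weight_decay=0, betas=betas)

    # ZeRO-style sharded optimizer for the weights: each rank keeps AdamW state only
    # for its own shard, so optimizer memory scales down with the world size.
    opt_weights = ZeroRedundancyOptimizer(weights, optimizer_class=torch.optim.AdamW,
                                          lr=lr, weight_decay=weight_decay, betas=betas)
    return [opt_unweighted, opt_weights]

Both optimizers are stepped every iteration; the unsharded one forgoes the memory saving, which the diff's own comment argues is acceptable because non-weight parameters are a tiny fraction of a large network.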