From 64cb4a92dbc9acdf6ce1d892d270a34690e66bf9 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 25 Dec 2021 21:32:01 -0700
Subject: [PATCH] Support adamw_zero

---
 codes/trainer/steps.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 13ac185e..f399beea 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -1,4 +1,5 @@
 from torch.cuda.amp import GradScaler
+from torch.distributed.optim import ZeroRedundancyOptimizer
 
 from utils.loss_accumulator import LossAccumulator
 from torch.nn import Module
@@ -130,6 +131,21 @@ class ConfigurableStep(Module):
                                         weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                         betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
                 opt._group_names = [params_names_weights, params_names_notweights]
+            elif self.step_opt['optimizer'] == 'adamw_zero':
+                # The torch ZeRO implementation does not seem to support parameter groups, so do not shard the
+                # non-weighted parameters and just use a normal AdamW implementation. In a large network, these
+                # weights will normally be a tiny fraction of the total weights.
+                opt_unweighted = torch.optim.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
+                                                   betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
+                opt_unweighted._config = opt_config
+                opt_unweighted._config['network'] = net_name
+                self.optimizers.append(opt_unweighted)
+                # Not setting these means abnormal gradient detection below no longer works.
+                opt_unweighted._group_names = []
+                opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=torch.optim.AdamW, lr=opt_config['lr'],
+                                              weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
+                                              betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
+                opt._group_names = []
             elif self.step_opt['optimizer'] == 'lars':
                 from trainer.optimizers.larc import LARC
                 from trainer.optimizers.sgd import SGDNoBiasMomentum
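
For reference, below is a minimal standalone sketch of the same arrangement outside the trainer: a ZeRO-sharded AdamW for the decayed weights plus a plain AdamW for the remaining parameters, since torch's ZeroRedundancyOptimizer takes a flat parameter list rather than parameter groups. The toy Linear module, hyperparameters, and single-rank gloo process group here are illustrative assumptions, not part of the patch.

import os
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer


def main():
    # ZeroRedundancyOptimizer requires torch.distributed to be initialized.
    # A single-rank gloo group is used here purely so the sketch runs standalone;
    # real training would launch multiple ranks.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = torch.nn.Linear(16, 16)

    # Shard only the weight-decayed parameters, mirroring the patch: the torch ZeRO
    # implementation does not accept parameter groups, so they are passed as a flat list.
    opt = ZeroRedundancyOptimizer(
        [model.weight],
        optimizer_class=torch.optim.AdamW,
        lr=1e-4,
        weight_decay=1e-2,
        betas=(0.9, 0.999),
    )
    # Non-decayed parameters (a tiny fraction of a large network) go to an unsharded AdamW.
    opt_unweighted = torch.optim.AdamW([model.bias], lr=1e-4, weight_decay=0,
                                       betas=(0.9, 0.999))

    # One toy training step to show both optimizers being driven together.
    loss = model(torch.randn(4, 16)).sum()
    loss.backward()
    opt.step()
    opt_unweighted.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Splitting the parameters this way keeps the memory savings of sharding where it matters (the bulk of the weights) while sidestepping ZeRO's lack of parameter-group support for the small set of non-decayed parameters.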