From 29e07913a8a05f002d1a3082933f59077305f685 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 15 Feb 2022 06:58:11 -0700
Subject: [PATCH] Fix

---
 codes/trainer/steps.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 22243030..040dee00 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -139,9 +139,9 @@ class ConfigurableStep(Module):
                                                          betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
                 opt_unweighted._config = opt_config
                 opt_unweighted._config['network'] = net_name
-                self.optimizers.append(opt_unweighted)
-                # Not setting these means abnormal gradient detection below no longer works.
                 opt_unweighted._group_names = []
+                self.optimizers.append(opt_unweighted)
+
                 opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=torch.optim.AdamW, lr=opt_config['lr'],
                                               weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                               betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
@@ -160,6 +160,7 @@ class ConfigurableStep(Module):
                                       betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
                 opt_unweighted._config = opt_config
                 opt_unweighted._config['network'] = net_name
+                opt_unweighted._group_names = []
                 self.optimizers.append(opt_unweighted)
                 opt = Lamb(params_weights, lr=opt_config['lr'],
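
Note on the change: both hunks ensure that opt_unweighted._group_names = [] is assigned
on the unweighted optimizer that gets appended to self.optimizers. Per the comment removed
in the first hunk, the abnormal-gradient detection further down in steps.py reads that
attribute from every registered optimizer, so the lamb branch (second hunk), which never
set it, would fail there. A minimal sketch of that failure mode, assuming only that the
check iterates opt._group_names; detect_abnormal_grads below is a hypothetical stand-in,
not code from this repository:

    import torch

    def detect_abnormal_grads(optimizers):
        # Hypothetical stand-in for the abnormal-gradient check in steps.py;
        # it illustrates only the attribute access that the patch makes safe.
        for opt in optimizers:
            for name in opt._group_names:  # AttributeError if the attribute was never set
                print('checking gradient group:', name)

    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
    # Before the patch (lamb branch), _group_names was never assigned, so the
    # check below raised AttributeError. The patch's one-line fix:
    opt._group_names = []
    detect_abnormal_grads([opt])  # now a safe no-op

As reconstructed, the first hunk only reorders the adamw_zero branch (the attribute was
already set there, just after the append); the second hunk adds the missing assignment.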