From 8f767b8b4f70a11b081ef35e75045065017571ce Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 15 Feb 2022 07:08:17 -0700
Subject: [PATCH] ...

---
 codes/trainer/steps.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 040dee00..247155fa 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -141,7 +141,7 @@ class ConfigurableStep(Module):
                 opt_unweighted._config['network'] = net_name
                 opt_unweighted._group_names = []
                 self.optimizers.append(opt_unweighted)
-
+
                 opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=torch.optim.AdamW, lr=opt_config['lr'],
                                               weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                               betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
@@ -166,7 +166,7 @@ class ConfigurableStep(Module):
                 opt = Lamb(params_weights, lr=opt_config['lr'],
                            weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                            betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
-                opt._group_names = [params_names_weights, params_names_notweights]
+                opt._group_names = []
            elif self.step_opt['optimizer'] == 'sgd':
                from torch.optim import SGD
                opt = SGD(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'], weight_decay=opt_config['weight_decay'])
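
Context for the change above: the patch clears the Lamb optimizer's `_group_names` (instead of recording the weight/non-weight parameter name lists), matching what the ZeroRedundancyOptimizer branch already does. Below is a minimal, self-contained sketch of the construction pattern the diff relies on; the `opt_get` helper here is a hypothetical stand-in for the repo's own nested-config getter, and `opt_config`/`params_weights` are placeholder values, not code from the patch.

# Hedged sketch, assuming opt_get is a nested-dict lookup with a default.
import torch

def opt_get(opt, keys, default=None):
    # Walk a nested dict by `keys`; fall back to `default` on any missing key.
    cur = opt
    for k in keys:
        if not isinstance(cur, dict) or k not in cur:
            return default
        cur = cur[k]
    return cur

opt_config = {'lr': 1e-4, 'beta1': 0.9}                   # hypothetical step config
params_weights = [torch.nn.Parameter(torch.randn(8, 8))]  # placeholder parameter group

# Mirrors the AdamW-style construction in the diff: missing config keys fall back
# to the same defaults the patch passes explicitly (1e-2 weight decay, 0.9/0.999 betas).
opt = torch.optim.AdamW(
    params_weights,
    lr=opt_config['lr'],
    weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
    betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)),
)
opt._group_names = []  # as in the patch: group names cleared rather than tracked per group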