This commit is contained in:
James Betker 2022-02-15 07:08:17 -07:00
parent 29e07913a8
commit 8f767b8b4f

View File

@@ -141,7 +141,7 @@ class ConfigurableStep(Module):
opt_unweighted._config['network'] = net_name
opt_unweighted._group_names = []
self.optimizers.append(opt_unweighted)
opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=torch.optim.AdamW, lr=opt_config['lr'],
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
@@ -166,7 +166,7 @@ class ConfigurableStep(Module):
opt = Lamb(params_weights, lr=opt_config['lr'],
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt._group_names = [params_names_weights, params_names_notweights]
opt._group_names = []
elif self.step_opt['optimizer'] == 'sgd':
from torch.optim import SGD
opt = SGD(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'], weight_decay=opt_config['weight_decay'])