forked from mrq/DL-Art-School
...
This commit is contained in:
parent
29e07913a8
commit
8f767b8b4f
|
@ -166,7 +166,7 @@ class ConfigurableStep(Module):
|
||||||
opt = Lamb(params_weights, lr=opt_config['lr'],
|
opt = Lamb(params_weights, lr=opt_config['lr'],
|
||||||
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
|
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
|
||||||
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
|
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
|
||||||
opt._group_names = [params_names_weights, params_names_notweights]
|
opt._group_names = []
|
||||||
elif self.step_opt['optimizer'] == 'sgd':
|
elif self.step_opt['optimizer'] == 'sgd':
|
||||||
from torch.optim import SGD
|
from torch.optim import SGD
|
||||||
opt = SGD(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'], weight_decay=opt_config['weight_decay'])
|
opt = SGD(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'], weight_decay=opt_config['weight_decay'])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user