forked from mrq/DL-Art-School

commit 29e07913a8 ("Fix")
parent dd585df772
@@ -139,9 +139,9 @@ class ConfigurableStep(Module):
                                                betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
             opt_unweighted._config = opt_config
             opt_unweighted._config['network'] = net_name
-            self.optimizers.append(opt_unweighted)
-            # Not setting these means abnormal gradient detection below no longer works.
             opt_unweighted._group_names = []
+            self.optimizers.append(opt_unweighted)
+
             opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=torch.optim.AdamW, lr=opt_config['lr'],
                                           weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                           betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
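Both hunks in this commit read optional hyperparameters through the repository's `opt_get` helper, e.g. `opt_get(opt_config, ['beta1'], .9)`. The call shape suggests a nested-key dictionary lookup with a fallback default; the sketch below is written under that assumption and is not the repository's actual implementation:

def opt_get(opt, keys, default=None):
    # Walk a list of nested keys through a config dict,
    # returning `default` as soon as any level is missing.
    current = opt
    for key in keys:
        if not isinstance(current, dict) or key not in current:
            return default
        current = current[key]
    return current

# Usage mirroring the diff: absent keys fall back to AdamW-style defaults.
opt_config = {'lr': 1e-4, 'beta2': 0.96}
betas = (opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))
assert betas == (0.9, 0.96)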
@@ -160,6 +160,7 @@ class ConfigurableStep(Module):
                                                betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
             opt_unweighted._config = opt_config
             opt_unweighted._config['network'] = net_name
+            opt_unweighted._group_names = []
             self.optimizers.append(opt_unweighted)
 
             opt = Lamb(params_weights, lr=opt_config['lr'],
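The comment deleted in the first hunk ("Not setting these means abnormal gradient detection below no longer works.") records the rationale for both changes: later code evidently iterates `self.optimizers` and reads each optimizer's `_group_names` while inspecting gradients, so an optimizer appended without that attribute breaks the check with an AttributeError. The sketch below is a hypothetical illustration of that failure mode, not the repository's detection code; `check_abnormal_grads` and its NaN/inf test are assumptions:

import math

def check_abnormal_grads(optimizers):
    # Hypothetical detector: flag NaN/inf gradients per param group.
    # Reading `opt._group_names` on an optimizer that never set the
    # attribute raises AttributeError, which is why the diff sets
    # `_group_names = []` on every optimizer before appending it.
    for opt in optimizers:
        names = opt._group_names  # fails if the attribute was never set
        for i, group in enumerate(opt.param_groups):
            label = names[i] if i < len(names) else f'group_{i}'
            for p in group['params']:
                if p.grad is None:
                    continue
                total = float(p.grad.sum())
                if math.isnan(total) or math.isinf(total):
                    print(f'Abnormal gradient detected in {label}')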
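For context on the first hunk's wrapped optimizer: `ZeroRedundancyOptimizer` is PyTorch's `torch.distributed.optim` wrapper that shards optimizer state across distributed ranks, forwarding keyword defaults such as `lr`, `weight_decay`, and `betas` to the wrapped `optimizer_class`. A minimal standalone sketch of the same construction pattern, with a generic model standing in for the trainer's network:

import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

def build_sharded_optimizer(model: torch.nn.Module, opt_config: dict):
    # Mirrors the diff: each rank keeps only its shard of AdamW state.
    # Requires an initialized process group (e.g. torchrun together with
    # torch.distributed.init_process_group) before construction.
    return ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.AdamW,
        lr=opt_config['lr'],
        weight_decay=opt_config.get('weight_decay', 1e-2),
        betas=(opt_config.get('beta1', .9), opt_config.get('beta2', .999)),
    )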