forked from mrq/DL-Art-School
Support adamw_zero
This commit is contained in:
parent 776a7abfcc
commit 64cb4a92db
@@ -1,4 +1,5 @@
 from torch.cuda.amp import GradScaler
+from torch.distributed.optim import ZeroRedundancyOptimizer
 
 from utils.loss_accumulator import LossAccumulator
 from torch.nn import Module
@@ -130,6 +131,21 @@ class ConfigurableStep(Module):
                                         weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                         betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
             opt._group_names = [params_names_weights, params_names_notweights]
+        elif self.step_opt['optimizer'] == 'adamw_zero':
+            # The torch ZeRO implementation does not seem to support parameter groups, so do not shard the non-weighted
+            # parameters and just use a normal AdamW implementation. In a large network, these weights will normally
+            # be a tiny fraction of the total weights.
+            opt_unweighted = torch.optim.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
+                                               betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
+            opt_unweighted._config = opt_config
+            opt_unweighted._config['network'] = net_name
+            self.optimizers.append(opt_unweighted)
+            # Not setting these means abnormal gradient detection below no longer works.
+            opt_unweighted._group_names = []
+            opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=torch.optim.AdamW, lr=opt_config['lr'],
+                                          weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
+                                          betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
+            opt._group_names = []
         elif self.step_opt['optimizer'] == 'lars':
             from trainer.optimizers.larc import LARC
             from trainer.optimizers.sgd import SGDNoBiasMomentum
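
For context, a minimal sketch of what torch's ZeroRedundancyOptimizer does when wrapping AdamW, separate from this commit: each DDP rank keeps only its shard of the AdamW state (exp_avg, exp_avg_sq), and updated parameters are synchronized after step(). The model, hyperparameters, filename and torchrun launch below are illustrative assumptions, not code from this repository; the commit itself only applies the sharding to the weight parameter group and keeps a plain, unsharded AdamW for the non-weight parameters. Judging from the dispatch on self.step_opt['optimizer'], this branch would presumably be selected by setting optimizer to 'adamw_zero' in the step's optimizer config, alongside the usual lr, weight_decay, beta1 and beta2 keys.

    # Illustrative-only sketch; launch with e.g. `torchrun --nproc_per_node=2 zero_sketch.py`.
    # Nothing below comes from DL-Art-School -- the model and hyperparameters are made up.
    import torch
    import torch.distributed as dist
    from torch.distributed.optim import ZeroRedundancyOptimizer
    from torch.nn.parallel import DistributedDataParallel as DDP


    def main():
        # torchrun provides MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE for this call.
        dist.init_process_group(backend='gloo')
        model = DDP(torch.nn.Linear(128, 128))

        # Each rank stores only its shard of the AdamW state, which is the memory
        # saving the 'adamw_zero' branch above gets for the large weight group.
        opt = ZeroRedundancyOptimizer(model.parameters(),
                                      optimizer_class=torch.optim.AdamW,
                                      lr=1e-4, weight_decay=1e-2, betas=(.9, .999))

        out = model(torch.randn(4, 128))
        out.sum().backward()
        opt.step()       # gradients were already averaged by DDP; each rank updates its shard
        opt.zero_grad()

        dist.destroy_process_group()


    if __name__ == '__main__':
        main()

The two-optimizer split in the commit (a sharded ZeRO/AdamW for the weights plus a plain AdamW with weight_decay=0 for the non-weight parameters) works around the lack of parameter-group support noted in the comment while, as the comment says, only paying the unsharded cost on a tiny fraction of the total parameters.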