forked from mrq/DL-Art-School
Fixes to weight_decay in adamw
parent 0c9e75bc69
commit 965f6e6b52
@@ -77,38 +77,47 @@ class ConfigurableStep(Module):
         for k, pg in opt_config['param_groups'].items():
             optim_params[k] = {'params': [], 'lr': pg['lr']}
 
-        for k, v in net.named_parameters():  # can optimize for a part of the model
-            # Make some inference about these parameters, which can be used by some optimizers to treat certain
-            # parameters differently. For example, it is considered good practice to not do weight decay on
-            # BN & bias parameters. TODO: process the module tree instead of the parameter tree to accomplish the
-            # same thing, but in a more effective way.
-            if k.endswith(".bias"):
-                v.is_bias = True
-            if k.endswith(".weight"):
-                v.is_weight = True
-            if ".bn" in k or '.batchnorm' in k or '.bnorm' in k:
-                v.is_bn = True
-            # Some models can specify some parameters to be in different groups.
-            param_group = "default"
-            if hasattr(v, 'PARAM_GROUP'):
-                if v.PARAM_GROUP in optim_params.keys():
-                    param_group = v.PARAM_GROUP
-                else:
-                    logger.warning(f'Model specifies a custom param group {v.PARAM_GROUP} which is not configured. '
-                                   f'The same LR will be used for all parameters.')
+        import torch.nn as nn
+        norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
+                        nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
+        emb_modules = (nn.Embedding, nn.EmbeddingBag)
+        params_notweights = set()
+        for mn, m in net.named_modules():
+            for k, v in m.named_parameters():
+                v.is_bias = k.endswith(".bias")
+                v.is_weight = k.endswith(".weight")
+                v.is_norm = isinstance(m, norm_modules)
+                v.is_emb = isinstance(m, emb_modules)
+
+                if v.is_bias or v.is_norm or v.is_emb:
+                    params_notweights.add(v)
+
+                # Some models can specify some parameters to be in different groups.
+                param_group = "default"
+                if hasattr(v, 'PARAM_GROUP'):
+                    if v.PARAM_GROUP in optim_params.keys():
+                        param_group = v.PARAM_GROUP
+                    else:
+                        logger.warning(f'Model specifies a custom param group {v.PARAM_GROUP} which is not configured. '
+                                       f'The same LR will be used for all parameters.')
 
-            if v.requires_grad:
-                optim_params[param_group]['params'].append(v)
-            else:
-                if self.env['rank'] <= 0:
-                    logger.warning('Params [{:s}] will not optimize.'.format(k))
+                if v.requires_grad:
+                    optim_params[param_group]['params'].append(v)
+                else:
+                    if self.env['rank'] <= 0:
+                        logger.warning('Params [{:s}] will not optimize.'.format(k))
+        params_weights = set(net.parameters()) ^ params_notweights
 
         if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
-            opt = torch.optim.Adam(list(optim_params.values()),
+            opt = torch.optim.Adam(list(optim_params.values()), lr=opt_config['lr'],
                                    weight_decay=opt_config['weight_decay'],
                                    betas=(opt_config['beta1'], opt_config['beta2']))
         elif self.step_opt['optimizer'] == 'adamw':
-            opt = torch.optim.AdamW(list(optim_params.values()),
+            groups = [
+                { 'params': list(params_weights), 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
+                { 'params': list(params_notweights), 'weight_decay': 0 }
+            ]
+            opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
                                     weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                     betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
         elif self.step_opt['optimizer'] == 'lars':
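The change above applies the standard trick of exempting bias, normalization, and embedding parameters from weight decay: those parameters are collected into params_notweights, and AdamW receives two parameter groups with different weight_decay values (a per-group setting overrides the optimizer-level default). Below is a minimal, self-contained sketch of the same pattern; the toy model, the hyperparameter values, and names such as decay/no_decay are illustrative only and not taken from this repository.

import torch
import torch.nn as nn

# Toy model containing both "weight" parameters and parameters that are
# conventionally excluded from weight decay (BatchNorm affine params, biases).
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 8 * 8, 10),
)

norm_modules = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
                nn.GroupNorm, nn.LayerNorm)
emb_modules = (nn.Embedding, nn.EmbeddingBag)

decay, no_decay = [], []
for module in model.modules():
    # recurse=False so every parameter is visited exactly once, via its owning module.
    for name, param in module.named_parameters(recurse=False):
        if not param.requires_grad:
            continue
        if name.endswith("bias") or isinstance(module, norm_modules + emb_modules):
            no_decay.append(param)
        else:
            decay.append(param)

# Per-group weight_decay overrides the optimizer-level default for that group.
opt = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 1e-2},    # regular weights: decayed
        {"params": no_decay, "weight_decay": 0.0},  # biases / norm / embeddings: not decayed
    ],
    lr=1e-4,
    betas=(0.9, 0.999),
)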