Fixes to weight_decay in adamw

James Betker 2021-07-31 15:58:41 -06:00
parent 0c9e75bc69
commit 965f6e6b52


@@ -77,38 +77,47 @@ class ConfigurableStep(Module):
                 for k, pg in opt_config['param_groups'].items():
                     optim_params[k] = {'params': [], 'lr': pg['lr']}
 
-            for k, v in net.named_parameters():  # can optimize for a part of the model
-                # Make some inference about these parameters, which can be used by some optimizers to treat certain
-                # parameters differently. For example, it is considered good practice to not do weight decay on
-                # BN & bias parameters. TODO: process the module tree instead of the parameter tree to accomplish the
-                # same thing, but in a more effective way.
-                if k.endswith(".bias"):
-                    v.is_bias = True
-                if k.endswith(".weight"):
-                    v.is_weight = True
-                if ".bn" in k or '.batchnorm' in k or '.bnorm' in k:
-                    v.is_bn = True
-                # Some models can specify some parameters to be in different groups.
-                param_group = "default"
-                if hasattr(v, 'PARAM_GROUP'):
-                    if v.PARAM_GROUP in optim_params.keys():
-                        param_group = v.PARAM_GROUP
-                    else:
-                        logger.warning(f'Model specifies a custom param group {v.PARAM_GROUP} which is not configured. '
-                                       f'The same LR will be used for all parameters.')
+            import torch.nn as nn
+            norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
+                            nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
+            emb_modules = (nn.Embedding, nn.EmbeddingBag)
+            params_notweights = set()
+            for mn, m in net.named_modules():
+                for k, v in m.named_parameters():
+                    v.is_bias = k.endswith(".bias")
+                    v.is_weight = k.endswith(".weight")
+                    v.is_norm = isinstance(m, norm_modules)
+                    v.is_emb = isinstance(m, emb_modules)
 
-                if v.requires_grad:
-                    optim_params[param_group]['params'].append(v)
-                else:
-                    if self.env['rank'] <= 0:
-                        logger.warning('Params [{:s}] will not optimize.'.format(k))
+                    if v.is_bias or v.is_norm or v.is_emb:
+                        params_notweights.add(v)
+
+                    # Some models can specify some parameters to be in different groups.
+                    param_group = "default"
+                    if hasattr(v, 'PARAM_GROUP'):
+                        if v.PARAM_GROUP in optim_params.keys():
+                            param_group = v.PARAM_GROUP
+                        else:
+                            logger.warning(f'Model specifies a custom param group {v.PARAM_GROUP} which is not configured. '
+                                           f'The same LR will be used for all parameters.')
+
+                    if v.requires_grad:
+                        optim_params[param_group]['params'].append(v)
+                    else:
+                        if self.env['rank'] <= 0:
+                            logger.warning('Params [{:s}] will not optimize.'.format(k))
+            params_weights = set(net.parameters()) ^ params_notweights
 
             if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
-                opt = torch.optim.Adam(list(optim_params.values()),
+                opt = torch.optim.Adam(list(optim_params.values()), lr=opt_config['lr'],
                                        weight_decay=opt_config['weight_decay'],
                                        betas=(opt_config['beta1'], opt_config['beta2']))
             elif self.step_opt['optimizer'] == 'adamw':
-                opt = torch.optim.AdamW(list(optim_params.values()),
+                groups = [
+                    { 'params': list(params_weights), 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
+                    { 'params': list(params_notweights), 'weight_decay': 0 }
+                ]
+                opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
                                         weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                         betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
             elif self.step_opt['optimizer'] == 'lars':
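
The AdamW change above splits the network's parameters into a weight group that receives weight decay and a no-decay group (biases, normalization parameters, embeddings), which is the usual convention for decoupled weight decay. Below is a minimal, self-contained sketch of that pattern against stock PyTorch; it does not use this repository's opt_config/opt_get plumbing, and build_adamw, toy_model and the hyperparameter defaults are illustrative assumptions rather than values from the commit.

# Standalone sketch of the same pattern, assuming plain PyTorch (no opt_config/opt_get
# plumbing from this repository); build_adamw and toy_model are illustrative names.
import torch
import torch.nn as nn

NORM_MODULES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
                nn.GroupNorm, nn.LayerNorm)
EMB_MODULES = (nn.Embedding, nn.EmbeddingBag)

def build_adamw(net, lr=1e-4, weight_decay=1e-2):
    # Walk the module tree so each parameter is classified by the module that owns it,
    # rather than by name matching on the flattened parameter tree.
    decay, no_decay = set(), set()
    for module in net.modules():
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            if name.endswith("bias") or isinstance(module, NORM_MODULES + EMB_MODULES):
                no_decay.add(param)
            else:
                decay.add(param)
    decay -= no_decay  # if a tied parameter was classified both ways, keep it decay-free
    groups = [
        {'params': list(decay), 'weight_decay': weight_decay},
        {'params': list(no_decay), 'weight_decay': 0.0},
    ]
    return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.999))

# Usage: the Linear weight receives weight decay; its bias and the LayerNorm/Embedding
# parameters land in the zero-decay group.
toy_model = nn.Sequential(nn.Embedding(10, 8), nn.LayerNorm(8), nn.Linear(8, 2))
optimizer = build_adamw(toy_model)
for group in optimizer.param_groups:
    print(len(group['params']), group['weight_decay'])

Because torch.optim only uses constructor defaults to fill in keys a group does not specify, the per-group weight_decay values above take precedence, which is how the commit routes zero decay to the non-weight parameters while the weight group keeps the configured value.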