From 965f6e6b528e3467d5977fea34c011c2f15d46c5 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 31 Jul 2021 15:58:41 -0600
Subject: [PATCH] Fixes to weight_decay in adamw

---
 codes/trainer/steps.py | 61 ++++++++++++++++++++++++------------------
 1 file changed, 35 insertions(+), 26 deletions(-)

diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index e269e1c3..2d82f5fa 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -77,38 +77,47 @@ class ConfigurableStep(Module):
         for k, pg in opt_config['param_groups'].items():
             optim_params[k] = {'params': [], 'lr': pg['lr']}
 
-        for k, v in net.named_parameters():  # can optimize for a part of the model
-            # Make some inference about these parameters, which can be used by some optimizers to treat certain
-            # parameters differently. For example, it is considered good practice to not do weight decay on
-            # BN & bias parameters. TODO: process the module tree instead of the parameter tree to accomplish the
-            # same thing, but in a more effective way.
-            if k.endswith(".bias"):
-                v.is_bias = True
-            if k.endswith(".weight"):
-                v.is_weight = True
-            if ".bn" in k or '.batchnorm' in k or '.bnorm' in k:
-                v.is_bn = True
-            # Some models can specify some parameters to be in different groups.
-            param_group = "default"
-            if hasattr(v, 'PARAM_GROUP'):
-                if v.PARAM_GROUP in optim_params.keys():
-                    param_group = v.PARAM_GROUP
-                else:
-                    logger.warning(f'Model specifies a custom param group {v.PARAM_GROUP} which is not configured. '
-                                   f'The same LR will be used for all parameters.')
+        import torch.nn as nn
+        norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
+                        nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
+        emb_modules = (nn.Embedding, nn.EmbeddingBag)
+        params_notweights = set()
+        for mn, m in net.named_modules():
+            for k, v in m.named_parameters():
+                v.is_bias = k.endswith(".bias")
+                v.is_weight = k.endswith(".weight")
+                v.is_norm = isinstance(m, norm_modules)
+                v.is_emb = isinstance(m, emb_modules)
 
-            if v.requires_grad:
-                optim_params[param_group]['params'].append(v)
-            else:
-                if self.env['rank'] <= 0:
-                    logger.warning('Params [{:s}] will not optimize.'.format(k))
+                if v.is_bias or v.is_norm or v.is_emb:
+                    params_notweights.add(v)
+
+                # Some models can specify some parameters to be in different groups.
+                param_group = "default"
+                if hasattr(v, 'PARAM_GROUP'):
+                    if v.PARAM_GROUP in optim_params.keys():
+                        param_group = v.PARAM_GROUP
+                    else:
+                        logger.warning(f'Model specifies a custom param group {v.PARAM_GROUP} which is not configured. '
+                                       f'The same LR will be used for all parameters.')
+
+                if v.requires_grad:
+                    optim_params[param_group]['params'].append(v)
+                else:
+                    if self.env['rank'] <= 0:
+                        logger.warning('Params [{:s}] will not optimize.'.format(k))
+        params_weights = set(net.parameters()) ^ params_notweights
 
         if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
-            opt = torch.optim.Adam(list(optim_params.values()),
+            opt = torch.optim.Adam(list(optim_params.values()), lr=opt_config['lr'],
                                    weight_decay=opt_config['weight_decay'],
                                    betas=(opt_config['beta1'], opt_config['beta2']))
         elif self.step_opt['optimizer'] == 'adamw':
-            opt = torch.optim.AdamW(list(optim_params.values()),
+            groups = [
+                { 'params': list(params_weights), 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
+                { 'params': list(params_notweights), 'weight_decay': 0 }
+            ]
+            opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
                                     weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                     betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
         elif self.step_opt['optimizer'] == 'lars':
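
Note (not part of the patch): the change above routes bias, normalization, and embedding
parameters into an AdamW param group with weight_decay=0, while ordinary weights keep the
configured decay. The following is a minimal standalone sketch of that same grouping idea,
written against plain PyTorch; make_groups, the toy model, and the hyperparameter values
are illustrative and do not exist in this repository.

    import torch
    import torch.nn as nn

    def make_groups(model: nn.Module, wd: float):
        # Hypothetical helper: split parameters into "decay" and "no decay" groups,
        # mirroring the bias/norm/embedding exemption used in the patched step.
        norm_modules = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                        nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
                        nn.GroupNorm, nn.LayerNorm)
        emb_modules = (nn.Embedding, nn.EmbeddingBag)
        no_decay = set()
        for m in model.modules():
            for name, p in m.named_parameters(recurse=False):
                # Biases, norm affines, and embeddings are exempt from weight decay.
                if name.endswith('bias') or isinstance(m, norm_modules + emb_modules):
                    no_decay.add(p)
        decay = set(model.parameters()) - no_decay
        return [{'params': list(decay), 'weight_decay': wd},
                {'params': list(no_decay), 'weight_decay': 0.0}]

    # Usage sketch: pass the groups to AdamW instead of model.parameters().
    model = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))
    opt = torch.optim.AdamW(make_groups(model, wd=1e-2), lr=1e-3)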