diff --git a/codes/trainer/optimizers/lamb.py b/codes/trainer/optimizers/lamb.py
new file mode 100644
index 00000000..9a1cbf3e
--- /dev/null
+++ b/codes/trainer/optimizers/lamb.py
@@ -0,0 +1,131 @@
+"""
+Lamb optimizer.
+
+Adapted from original source: https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py
+"""
+
+import collections
+import math
+
+import torch
+from torch.optim import Optimizer
+
+
+class Lamb(Optimizer):
+    r"""Implements Lamb algorithm.
+
+    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-6)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        adam (bool, optional): always use trust ratio = 1, which turns this into
+            Adam. Useful for comparison purposes.
+
+    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
+        https://arxiv.org/abs/1904.00962
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
+                 weight_decay=0, adam=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay)
+        self.adam = adam
+        super(Lamb, self).__init__(params, defaults)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                # Decay the first and second moment running average coefficient
+                # m_t
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                # v_t
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+                # Paper v3 does not use debiasing.
+                # bias_correction1 = 1 - beta1 ** state['step']
+                # bias_correction2 = 1 - beta2 ** state['step']
+                # Apply bias to lr to avoid broadcast.
+                step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
+
+                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
+
+                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
+                if group['weight_decay'] != 0:
+                    adam_step.add_(p.data, alpha=group['weight_decay'])
+
+                adam_norm = adam_step.pow(2).sum().sqrt()
+                if weight_norm == 0 or adam_norm == 0:
+                    trust_ratio = 1
+                else:
+                    trust_ratio = weight_norm / adam_norm
+                state['weight_norm'] = weight_norm
+                state['adam_norm'] = adam_norm
+                state['trust_ratio'] = trust_ratio
+                if self.adam:
+                    trust_ratio = 1
+
+                p.data.add_(adam_step, alpha=-step_size * trust_ratio)
+
+        return loss
+
+
+    def debug(self):
+        """Returns a histogram dict for recording various norms and the trust ratio."""
+        results = collections.defaultdict(list)
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
+                    if i in state:
+                        results[i].append(state[i])
+
+        res = {}
+        for k, v in results.items():
+            res[f'histogram_lamb_{k}'] = torch.tensor(v)
+        return res
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 4a537fd0..22243030 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -154,6 +154,18 @@ class ConfigurableStep(Module):
                                        weight_decay=opt_config['weight_decay'])
             opt = LARC(optSGD, trust_coefficient=opt_config['lars_coefficient'])
             opt._group_names = sorted(list(all_param_names))
+        elif self.step_opt['optimizer'] == 'lamb':
+            from trainer.optimizers.lamb import Lamb
+            opt_unweighted = torch.optim.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
+                                               betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
+            opt_unweighted._config = opt_config
+            opt_unweighted._config['network'] = net_name
+            self.optimizers.append(opt_unweighted)
+
+            opt = Lamb(params_weights, lr=opt_config['lr'],
+                       weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
+                       betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
+            opt._group_names = [params_names_weights, params_names_notweights]
         elif self.step_opt['optimizer'] == 'sgd':
             from torch.optim import SGD
             opt = SGD(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'], weight_decay=opt_config['weight_decay'])
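
For reference, a minimal usage sketch (not part of the diff) of the Lamb class added above. It assumes codes/ is on the import path, as the import in the steps.py hunk already does; the toy nn.Linear model, tensor shapes, and step count are illustrative assumptions, while the constructor arguments, step(), and debug() keys come from the code in this diff.

    # Illustrative only: exercise Lamb on a toy model and read back its trust-ratio stats.
    import torch
    from torch import nn

    from trainer.optimizers.lamb import Lamb  # module path introduced by this diff

    model = nn.Linear(16, 4)  # hypothetical toy model, not from the repo
    opt = Lamb(model.parameters(), lr=1e-3, weight_decay=1e-2, betas=(0.9, 0.999))

    x = torch.randn(8, 16)
    y = torch.randn(8, 4)

    for _ in range(3):
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()

    # debug() returns tensors keyed 'histogram_lamb_weight_norm', 'histogram_lamb_adam_norm'
    # and 'histogram_lamb_trust_ratio', with one entry per parameter that has taken a step.
    for name, values in opt.debug().items():
        print(name, values)

In the trainer itself the same code is reached by setting optimizer: lamb in a step's opt_config: weight parameters are handed to Lamb (weight_decay defaulting to 1e-2), while non-weight parameters go to a separate AdamW instance with weight_decay=0, as wired in the steps.py hunk above.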