From e8613041c08c8e5f0c7d073a63dc5f8731a35b17 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 6 Sep 2020 17:27:08 -0600
Subject: [PATCH] Add novograd optimizer

---
 .../archs/SwitchedResidualGenerator_arch.py |  1 -
 codes/models/novograd.py                    | 71 +++++++++++++++++++
 codes/models/steps/steps.py                 |  4 +-
 codes/test.py                               |  2 +-
 codes/train.py                              |  2 +-
 5 files changed, 75 insertions(+), 5 deletions(-)
 create mode 100644 codes/models/novograd.py

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index b3d62ec0..4f0bac56 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -565,4 +565,3 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
         return val
-
diff --git a/codes/models/novograd.py b/codes/models/novograd.py
new file mode 100644
index 00000000..374479ad
--- /dev/null
+++ b/codes/models/novograd.py
@@ -0,0 +1,71 @@
+# Author Masashi Kimura (Convergence Lab.)
+import torch
+from torch import optim
+import math
+
+class NovoGrad(optim.Optimizer):
+    def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        super(NovoGrad, self).__init__(params, defaults)
+        self._lr = lr
+        self._beta1 = betas[0]
+        self._beta2 = betas[1]
+        self._eps = eps
+        self._wd = weight_decay
+        self._grad_averaging = grad_averaging
+
+        self._momentum_initialized = False
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        if not self._momentum_initialized:
+            for group in self.param_groups:
+                for p in group['params']:
+                    if p.grad is None:
+                        continue
+                    state = self.state[p]
+                    grad = p.grad.data
+                    if grad.is_sparse:
+                        raise RuntimeError('NovoGrad does not support sparse gradients')
+
+                    v = torch.norm(grad)**2
+                    m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data
+                    state['step'] = 0
+                    state['v'] = v
+                    state['m'] = m
+                    state['grad_ema'] = None
+            self._momentum_initialized = True
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                state = self.state[p]
+                state['step'] += 1
+
+                step, v, m = state['step'], state['v'], state['m']
+                grad_ema = state['grad_ema']
+
+                grad = p.grad.data
+                g2 = torch.norm(grad)**2
+                grad_ema = g2 if grad_ema is None else grad_ema * \
+                    self._beta2 + g2*(1. - self._beta2)
+                grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)
+
+                if self._grad_averaging:
+                    grad *= (1. - self._beta1)
+
+                g2 = torch.norm(grad)**2
+                v = self._beta2*v + (1. - self._beta2)*g2
+                m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd*p.data)
+                bias_correction1 = 1 - self._beta1 ** step
+                bias_correction2 = 1 - self._beta2 ** step
+                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+                state['v'], state['m'] = v, m
+                state['grad_ema'] = grad_ema
+                p.data.add_(-step_size, m)
+        return loss
\ No newline at end of file
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index 8671e30c..7d845029 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -6,7 +6,7 @@ import torch
 from apex import amp
 from collections import OrderedDict
 from .injectors import create_injector
-from apex.optimizers import FusedNovoGrad
+from models.novograd import NovoGrad
 
 logger = logging.getLogger('base')
 
@@ -56,7 +56,7 @@ class ConfigurableStep(Module):
                                    weight_decay=self.step_opt['weight_decay'],
                                    betas=(self.step_opt['beta1'], self.step_opt['beta2']))
         elif self.step_opt['optimizer'] == 'novograd':
-            opt = FusedNovoGrad(optim_params, lr=self.step_opt['lr'], weight_decay=self.step_opt['weight_decay'],
+            opt = NovoGrad(optim_params, lr=self.step_opt['lr'], weight_decay=self.step_opt['weight_decay'],
                            betas=(self.step_opt['beta1'], self.step_opt['beta2']))
         self.optimizers = [opt]
 
diff --git a/codes/test.py b/codes/test.py
index 4c96f14c..bf1061ef 100644
--- a/codes/test.py
+++ b/codes/test.py
@@ -89,7 +89,7 @@ if __name__ == "__main__":
     want_just_images = True
     srg_analyze = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/analyze_srg.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
 
diff --git a/codes/train.py b/codes/train.py
index 6f63b7d6..76fd9a8e 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_corrupt_imgset_rrdb.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_fullimgref_gan_no_branch.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
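
Usage note (not part of the patch): a minimal sketch of how the new pure-PyTorch NovoGrad class is expected to be driven, mirroring the construction in codes/models/steps/steps.py. The hyperparameter values, the `model`, and the `compute_loss` helper below are illustrative placeholders, not values taken from the repository's option files.

    # Illustrative only: construct and step the NovoGrad optimizer added by this patch.
    # lr/betas/weight_decay are placeholder values, not repo defaults.
    from models.novograd import NovoGrad

    opt = NovoGrad(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0)
    loss = compute_loss(model, batch)  # hypothetical loss computation
    loss.backward()
    opt.step()       # applies the layer-wise normalized update
    opt.zero_grad()  # clear gradients before the next iteration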