forked from mrq/DL-Art-School
Add novograd optimizer
This commit is contained in:
parent
a5c2388368
commit
e8613041c0
|
@ -565,4 +565,3 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
|
|||
val["switch_%i_specificity" % (i,)] = means[i]
|
||||
val["switch_%i_histogram" % (i,)] = hists[i]
|
||||
return val
|
||||
|
||||
|
|
71
codes/models/novograd.py
Normal file
71
codes/models/novograd.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
# Author Masashi Kimura (Convergence Lab.)
|
||||
import torch
|
||||
from torch import optim
|
||||
import math
|
||||
|
||||
class NovoGrad(optim.Optimizer):
|
||||
def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
super(NovoGrad, self).__init__(params, defaults)
|
||||
self._lr = lr
|
||||
self._beta1 = betas[0]
|
||||
self._beta2 = betas[1]
|
||||
self._eps = eps
|
||||
self._wd = weight_decay
|
||||
self._grad_averaging = grad_averaging
|
||||
|
||||
self._momentum_initialized = False
|
||||
|
||||
def step(self, closure=None):
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
if not self._momentum_initialized:
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
state = self.state[p]
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('NovoGrad does not support sparse gradients')
|
||||
|
||||
v = torch.norm(grad)**2
|
||||
m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data
|
||||
state['step'] = 0
|
||||
state['v'] = v
|
||||
state['m'] = m
|
||||
state['grad_ema'] = None
|
||||
self._momentum_initialized = True
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
state = self.state[p]
|
||||
state['step'] += 1
|
||||
|
||||
step, v, m = state['step'], state['v'], state['m']
|
||||
grad_ema = state['grad_ema']
|
||||
|
||||
grad = p.grad.data
|
||||
g2 = torch.norm(grad)**2
|
||||
grad_ema = g2 if grad_ema is None else grad_ema * \
|
||||
self._beta2 + g2*(1. - self._beta2)
|
||||
grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)
|
||||
|
||||
if self._grad_averaging:
|
||||
grad *= (1. - self._beta1)
|
||||
|
||||
g2 = torch.norm(grad)**2
|
||||
v = self._beta2*v + (1. - self._beta2)*g2
|
||||
m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd*p.data)
|
||||
bias_correction1 = 1 - self._beta1 ** step
|
||||
bias_correction2 = 1 - self._beta2 ** step
|
||||
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
state['v'], state['m'] = v, m
|
||||
state['grad_ema'] = grad_ema
|
||||
p.data.add_(-step_size, m)
|
||||
return loss
|
|
@ -6,7 +6,7 @@ import torch
|
|||
from apex import amp
|
||||
from collections import OrderedDict
|
||||
from .injectors import create_injector
|
||||
from apex.optimizers import FusedNovoGrad
|
||||
from models.novograd import NovoGrad
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
@ -56,7 +56,7 @@ class ConfigurableStep(Module):
|
|||
weight_decay=self.step_opt['weight_decay'],
|
||||
betas=(self.step_opt['beta1'], self.step_opt['beta2']))
|
||||
elif self.step_opt['optimizer'] == 'novograd':
|
||||
opt = FusedNovoGrad(optim_params, lr=self.step_opt['lr'], weight_decay=self.step_opt['weight_decay'],
|
||||
opt = NovoGrad(optim_params, lr=self.step_opt['lr'], weight_decay=self.step_opt['weight_decay'],
|
||||
betas=(self.step_opt['beta1'], self.step_opt['beta2']))
|
||||
self.optimizers = [opt]
|
||||
|
||||
|
|
|
@ -89,7 +89,7 @@ if __name__ == "__main__":
|
|||
want_just_images = True
|
||||
srg_analyze = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/analyze_srg.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_corrupt_imgset_rrdb.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_fullimgref_gan_no_branch.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user