forked from mrq/DL-Art-School

Allow Novograd to be used as an optimizer

parent 912a4d9fea
commit 21ae135f23
@@ -6,6 +6,7 @@ import torch
 from apex import amp
 from collections import OrderedDict
 from .injectors import create_injector
+from apex.optimizers import FusedNovoGrad

 logger = logging.getLogger('base')

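The only change in this hunk is the new import: FusedNovoGrad comes from NVIDIA Apex, the same optional, CUDA-only package that already supplies `amp` above, so as written the import makes Apex a hard dependency of this module. A minimal sketch (not part of this commit) of how the import could be guarded so the module still loads without Apex:

# Sketch only, not in this commit: tolerate a missing Apex install.
try:
    from apex.optimizers import FusedNovoGrad
except ImportError:
    # Without Apex, only the Adam branch below remains usable.
    FusedNovoGrad = None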
@@ -50,9 +51,13 @@ class ConfigurableStep(Module):
             else:
                 if self.env['rank'] <= 0:
                     logger.warning('Params [{:s}] will not optimize.'.format(k))
-        opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'],
-                               weight_decay=self.step_opt['weight_decay'],
-                               betas=(self.step_opt['beta1'], self.step_opt['beta2']))
+        if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
+            opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'],
+                                   weight_decay=self.step_opt['weight_decay'],
+                                   betas=(self.step_opt['beta1'], self.step_opt['beta2']))
+        elif self.step_opt['optimizer'] == 'novograd':
+            opt = FusedNovoGrad(optim_params, lr=self.step_opt['lr'], weight_decay=self.step_opt['weight_decay'],
+                                betas=(self.step_opt['beta1'], self.step_opt['beta2']))
         self.optimizers = [opt]

     # Returns all optimizers used in this step.
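The optimizer is now chosen from the step's options: a config without an `optimizer` key (or with it set to `adam`) keeps the previous Adam behavior, while `novograd` selects Apex's FusedNovoGrad. A hypothetical `step_opt` illustrating the keys the code reads; the key names come straight from the diff, but every value below is an invented example, not a default from the repo:

# Hypothetical step options; key names match what the code reads from
# self.step_opt, but the values are invented for illustration.
step_opt = {
    'optimizer': 'novograd',  # omit this key, or use 'adam', for torch.optim.Adam
    'lr': 1e-3,
    'weight_decay': 1e-2,
    'beta1': 0.95,
    'beta2': 0.98,
}

Both branches pass the same four hyperparameters, so a config can switch optimizers without any other changes; any Apex-specific FusedNovoGrad arguments are left at their defaults.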