forked from mrq/DL-Art-School
Allow Novograd to be used as an optimizer
This commit is contained in:
parent
912a4d9fea
commit
21ae135f23
|
@ -6,6 +6,7 @@ import torch
|
|||
from apex import amp
|
||||
from collections import OrderedDict
|
||||
from .injectors import create_injector
|
||||
from apex.optimizers import FusedNovoGrad
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
@ -50,9 +51,13 @@ class ConfigurableStep(Module):
|
|||
else:
|
||||
if self.env['rank'] <= 0:
|
||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||
opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'],
|
||||
weight_decay=self.step_opt['weight_decay'],
|
||||
betas=(self.step_opt['beta1'], self.step_opt['beta2']))
|
||||
if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
|
||||
opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'],
|
||||
weight_decay=self.step_opt['weight_decay'],
|
||||
betas=(self.step_opt['beta1'], self.step_opt['beta2']))
|
||||
elif self.step_opt['optimizer'] == 'novograd':
|
||||
opt = FusedNovoGrad(optim_params, lr=self.step_opt['lr'], weight_decay=self.step_opt['weight_decay'],
|
||||
betas=(self.step_opt['beta1'], self.step_opt['beta2']))
|
||||
self.optimizers = [opt]
|
||||
|
||||
# Returns all optimizers used in this step.
|
||||
|
|
Loading…
Reference in New Issue
Block a user