diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index 7af98b31..8671e30c 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -6,6 +6,7 @@ import torch
 from apex import amp
 from collections import OrderedDict
 from .injectors import create_injector
+from apex.optimizers import FusedNovoGrad
 
 logger = logging.getLogger('base')
 
@@ -50,9 +51,13 @@ class ConfigurableStep(Module):
             else:
                 if self.env['rank'] <= 0:
                     logger.warning('Params [{:s}] will not optimize.'.format(k))
-        opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'],
-                               weight_decay=self.step_opt['weight_decay'],
-                               betas=(self.step_opt['beta1'], self.step_opt['beta2']))
+        if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
+            opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'],
+                                   weight_decay=self.step_opt['weight_decay'],
+                                   betas=(self.step_opt['beta1'], self.step_opt['beta2']))
+        elif self.step_opt['optimizer'] == 'novograd':
+            opt = FusedNovoGrad(optim_params, lr=self.step_opt['lr'], weight_decay=self.step_opt['weight_decay'],
+                                betas=(self.step_opt['beta1'], self.step_opt['beta2']))
         self.optimizers = [opt]
 
     # Returns all optimizers used in this step.
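
For reference, a minimal sketch of the step options that would exercise the new FusedNovoGrad branch. Only the key names ('optimizer', 'lr', 'weight_decay', 'beta1', 'beta2') are taken from the diff above; the values are placeholders, not defaults from this PR.

    # Hypothetical step_opt dict; omit 'optimizer' (or set it to 'adam') to keep torch.optim.Adam.
    step_opt = {
        'optimizer': 'novograd',   # selects apex.optimizers.FusedNovoGrad
        'lr': 1e-4,                # placeholder learning rate
        'weight_decay': 0,         # placeholder weight decay
        'beta1': 0.9,              # placeholder betas passed to either optimizer
        'beta2': 0.99,
    }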