diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index 3ac97f25..abd8c6c8 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -25,7 +25,7 @@ class ConfigurableStep(Module): self.gen_outputs = opt_step['generator_outputs'] self.loss_accumulator = LossAccumulator(buffer_sz=opt_get(opt_step, ['loss_log_buffer'], 50)) self.optimizers = None - self.scaler = GradScaler(enabled=self.opt['fp16']) + self.scaler = GradScaler(enabled=self.opt['fp16'] or opt_get(self.opt, ['grad_scaler_enabled'], False)) self.grads_generated = False self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999 self.clip_grad_eps = opt_get(opt_step, ['clip_grad_eps'], None)