Allow grad scaler to be enabled even in fp32 mode

This commit is contained in:
James Betker 2022-01-21 23:13:24 -07:00
parent 91b4b240ac
commit ce929a6b3f

View File

@ -25,7 +25,7 @@ class ConfigurableStep(Module):
self.gen_outputs = opt_step['generator_outputs']
self.loss_accumulator = LossAccumulator(buffer_sz=opt_get(opt_step, ['loss_log_buffer'], 50))
self.optimizers = None
self.scaler = GradScaler(enabled=self.opt['fp16'])
self.scaler = GradScaler(enabled=self.opt['fp16'] or opt_get(self.opt, ['grad_scaler_enabled'], False))
self.grads_generated = False
self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999
self.clip_grad_eps = opt_get(opt_step, ['clip_grad_eps'], None)