Allow grad scaler to be enabled even in fp32 mode
parent 91b4b240ac
commit ce929a6b3f
@@ -25,7 +25,7 @@ class ConfigurableStep(Module):
         self.gen_outputs = opt_step['generator_outputs']
         self.loss_accumulator = LossAccumulator(buffer_sz=opt_get(opt_step, ['loss_log_buffer'], 50))
         self.optimizers = None
-        self.scaler = GradScaler(enabled=self.opt['fp16'])
+        self.scaler = GradScaler(enabled=self.opt['fp16'] or opt_get(self.opt, ['grad_scaler_enabled'], False))
         self.grads_generated = False
         self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999
         self.clip_grad_eps = opt_get(opt_step, ['clip_grad_eps'], None)
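The change widens the condition that activates PyTorch's GradScaler: previously it was tied solely to the 'fp16' option, while now a separate 'grad_scaler_enabled' option can switch it on even when training in fp32. Below is a minimal sketch of the resulting behavior; build_scaler and the plain-dict option lookup are illustrative stand-ins for the ConfigurableStep constructor and the repo's opt_get helper, not code from the repository.

from torch.cuda.amp import GradScaler

def build_scaler(opt):
    # Before this commit the scaler was active only when opt['fp16'] was true.
    # After it, an explicit 'grad_scaler_enabled' flag also enables it in fp32.
    fp16 = opt.get('fp16', False)
    explicitly_enabled = opt.get('grad_scaler_enabled', False)
    return GradScaler(enabled=fp16 or explicitly_enabled)

# fp32 training with gradient scaling requested via the new option:
scaler = build_scaler({'fp16': False, 'grad_scaler_enabled': True})
assert scaler.is_enabled()

Because the new option defaults to False, existing fp32 configurations keep the old behavior unless they opt in.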