diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index f78f1f47..89d5e753 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -318,7 +318,7 @@ class ExtensibleTrainer(BaseModel):
                 # This is needed to accurately log the grad norms.
                 for opt in step.optimizers:
                     from torch.cuda.amp.grad_scaler import OptState
-                    if step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
+                    if step.scaler.is_enabled() and step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
                         step.scaler.unscale_(opt)
 
                 if return_grad_norms and train_step:
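
A minimal sketch of the guarded pattern the patch adds, outside the trainer. It assumes (not confirmed by the diff) that a `GradScaler` constructed with `enabled=False`, e.g. on fp32/CPU-only runs, does not build its per-optimizer bookkeeping, so probing `_per_optimizer_states` must be gated on `is_enabled()`. The model and optimizer names below are hypothetical stand-ins.

```python
import torch
from torch.cuda.amp.grad_scaler import GradScaler, OptState

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 4).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=torch.cuda.is_available())  # disabled on CPU-only runs

loss = model(torch.randn(2, 4, device=device)).sum()
scaler.scale(loss).backward()

# Mirrors the patched condition: only inspect the scaler's internal state and
# unscale when the scaler is actually active and the grads are still scaled.
if scaler.is_enabled() and \
        scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
    scaler.unscale_(opt)  # grad norms logged after this reflect unscaled magnitudes

scaler.step(opt)
scaler.update()
```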