From 07731d5491d9e7696c996989309a8fd66468d229 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 24 Mar 2022 21:20:22 -0600
Subject: [PATCH] Fix ET

---
 codes/trainer/ExtensibleTrainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index f78f1f47..89d5e753 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -318,7 +318,7 @@ class ExtensibleTrainer(BaseModel):
                 # This is needed to accurately log the grad norms.
                 for opt in step.optimizers:
                     from torch.cuda.amp.grad_scaler import OptState
-                    if step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
+                    if step.scaler.is_enabled() and step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
                         step.scaler.unscale_(opt)
 
                 if return_grad_norms and train_step:
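
For context: in the PyTorch releases current when this patch was written (~1.10/1.11), `GradScaler.__init__` only allocates the private `_per_optimizer_states` dict when the scaler is constructed with `enabled=True`, so probing that dict on a disabled scaler (fp32 training, or CUDA unavailable) raises `AttributeError`. Putting `is_enabled()` first short-circuits the check before the private state is touched. Below is a minimal standalone sketch of the failure mode and the guarded call; `model` and `opt` are placeholder names for illustration, not names from the repo.

```python
import torch
from torch.cuda.amp import GradScaler
from torch.cuda.amp.grad_scaler import OptState

# Hypothetical setup: any model/optimizer pair works for the demonstration.
model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# A disabled scaler never creates `_per_optimizer_states`, so indexing it
# raises AttributeError.
scaler = GradScaler(enabled=False)

model(torch.randn(2, 4)).sum().backward()

# Pre-patch order of checks: reads `_per_optimizer_states` unconditionally
# and crashes when the scaler is disabled.
#   if scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
#       scaler.unscale_(opt)

# Post-patch: `is_enabled()` short-circuits, so the private state is only
# inspected on an active scaler. The second clause still prevents calling
# unscale_() twice for the same optimizer within one step, which GradScaler
# forbids.
if scaler.is_enabled() and scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
    scaler.unscale_(opt)
```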