From 963f0e9cee9f67542a5b37fe4893b1b3c4028a85 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 22 Mar 2022 11:40:02 -0600 Subject: [PATCH] fix unscaler --- codes/trainer/ExtensibleTrainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 0a470e0f..f78f1f47 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -317,7 +317,9 @@ class ExtensibleTrainer(BaseModel): # Unscale gradients within the step. (This is admittedly pretty messy but the API contract between step & ET is pretty much broken at this point) # This is needed to accurately log the grad norms. for opt in step.optimizers: - step.scaler.unscale_(opt) + from torch.cuda.amp.grad_scaler import OptState + if step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED: + step.scaler.unscale_(opt) if return_grad_norms and train_step: for name in nets_to_train: