diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 0a470e0f..f78f1f47 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -317,7 +317,9 @@ class ExtensibleTrainer(BaseModel):
             # Unscale gradients within the step. (This is admittedly pretty messy but the API contract between step & ET is pretty much broken at this point)
             # This is needed to accurately log the grad norms.
             for opt in step.optimizers:
-                step.scaler.unscale_(opt)
+                from torch.cuda.amp.grad_scaler import OptState
+                if step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
+                    step.scaler.unscale_(opt)
 
             if return_grad_norms and train_step:
                 for name in nets_to_train:
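
Below is a minimal standalone sketch of the guard this patch introduces, assuming it is acceptable to read PyTorch's private GradScaler bookkeeping (_per_optimizer_states and OptState); the helper name unscale_once is hypothetical. It skips unscale_ when the scaler has already recorded an unscale for that optimizer, which would otherwise raise "unscale_() has already been called on this optimizer since the last update()".

    import torch
    from torch.cuda.amp import GradScaler
    from torch.cuda.amp.grad_scaler import OptState

    def unscale_once(scaler: GradScaler, optimizer: torch.optim.Optimizer) -> None:
        # GradScaler keeps per-optimizer state keyed by id(optimizer); the
        # "stage" entry moves READY -> UNSCALED -> STEPPED within one update.
        state = scaler._per_optimizer_states[id(optimizer)]
        if state["stage"] is not OptState.UNSCALED:
            # Only unscale if it has not already been done for this optimizer,
            # e.g. because the step itself unscaled before clipping gradients.
            scaler.unscale_(optimizer)

This relies on private internals, so it may break across PyTorch versions; the alternative is to track "already unscaled" flags on the caller's side.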