commit 07731d5491
parent a15970dd97
Author: James Betker
Date:   2022-03-24 21:20:22 -06:00

@@ -318,7 +318,7 @@ class ExtensibleTrainer(BaseModel):
 # This is needed to accurately log the grad norms.
 for opt in step.optimizers:
     from torch.cuda.amp.grad_scaler import OptState
-    if step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
+    if step.scaler.is_enabled() and step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
         step.scaler.unscale_(opt)
 if return_grad_norms and train_step:
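
For context, the one-line change above adds an is_enabled() guard before manually unscaling gradients for norm logging. The pattern can be sketched as a standalone helper; this is a minimal illustration, not code from the repository. unscale_once is a hypothetical name, and OptState / _per_optimizer_states are private torch.cuda.amp internals that the original code already touches. The is_enabled() check skips the bookkeeping entirely when AMP is off (gradients were never scaled in that case), and the stage check avoids calling unscale_() a second time before scaler.step()/update(), which raises a RuntimeError in PyTorch.

    import torch
    from torch.cuda.amp.grad_scaler import GradScaler, OptState

    def unscale_once(scaler: GradScaler, opt: torch.optim.Optimizer) -> None:
        # When AMP is disabled the gradients were never scaled, so there is
        # nothing to unscale and no per-optimizer state worth inspecting.
        if not scaler.is_enabled():
            return
        # "stage" moves READY -> UNSCALED once unscale_() runs; guarding on it
        # (a private PyTorch detail) ensures unscale_() is called at most once
        # per optimizer between scaler.step() and scaler.update().
        if scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
            scaler.unscale_(opt)

With a guard like this in place, the grad norms read off p.grad after the backward pass reflect true (unscaled) gradient magnitudes whether or not AMP is active, e.g.:

    scaler.scale(loss).backward()
    unscale_once(scaler, optimizer)
    # max_norm=inf computes the total norm without actually clipping.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float("inf"))
    scaler.step(optimizer)
    scaler.update()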