Fix ET
This commit is contained in:
parent
a15970dd97
commit
07731d5491
|
@ -318,7 +318,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
# This is needed to accurately log the grad norms.
|
||||
for opt in step.optimizers:
|
||||
from torch.cuda.amp.grad_scaler import OptState
|
||||
if step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
|
||||
if step.scaler.is_enabled() and step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
|
||||
step.scaler.unscale_(opt)
|
||||
|
||||
if return_grad_norms and train_step:
|
||||
|
|
Loading…
Reference in New Issue
Block a user