fix unscaler
parent 5405ce4363
commit 963f0e9cee
@@ -317,6 +317,8 @@ class ExtensibleTrainer(BaseModel):
+            # Unscale gradients within the step. (This is admittedly pretty messy but the API contract between step & ET is pretty much broken at this point)
+            # This is needed to accurately log the grad norms.
             for opt in step.optimizers:
                 from torch.cuda.amp.grad_scaler import OptState
                 if step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
                     step.scaler.unscale_(opt)

             if return_grad_norms and train_step:
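For context, the pattern this commit depends on is torch.cuda.amp's GradScaler: gradients produced from a scaled loss are still multiplied by the loss scale, so they must be unscaled before their norms are meaningful, and unscale_() may only be called once per optimizer per step, which is why the diff guards on OptState.UNSCALED. Below is a minimal standalone sketch of that pattern; the model, optimizer, and data are placeholders for illustration, not ExtensibleTrainer's actual objects.

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

# Hypothetical model/optimizer/data, just to demonstrate the AMP pattern.
model = nn.Linear(16, 1).cuda()
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler()

x = torch.randn(8, 16, device='cuda')
y = torch.randn(8, 1, device='cuda')

with autocast():
    loss = nn.functional.mse_loss(model(x), y)

# backward() runs on the scaled loss, so p.grad is scaled here.
scaler.scale(loss).backward()

# Unscale before reading gradients so the logged norm is in real units.
# Calling unscale_() twice for the same optimizer in one step raises a
# RuntimeError, hence the OptState check in the commit above.
scaler.unscale_(opt)
grad_norm = torch.norm(torch.stack(
    [p.grad.detach().norm() for p in model.parameters() if p.grad is not None]))

# scaler.step() notices unscale_() already ran and does not re-unscale.
scaler.step(opt)
scaler.update()
opt.zero_grad()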