fix unscaler

James Betker 2022-03-22 11:40:02 -06:00
parent 5405ce4363
commit 963f0e9cee


@@ -317,7 +317,9 @@ class ExtensibleTrainer(BaseModel):
             # Unscale gradients within the step. (This is admittedly pretty messy but the API contract between step & ET is pretty much broken at this point)
             # This is needed to accurately log the grad norms.
             for opt in step.optimizers:
-                step.scaler.unscale_(opt)
+                from torch.cuda.amp.grad_scaler import OptState
+                if step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
+                    step.scaler.unscale_(opt)
 
             if return_grad_norms and train_step:
                 for name in nets_to_train:
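
For context, a minimal standalone sketch of why the guard is needed (the model, optimizer, and loss here are hypothetical, and CUDA is assumed): GradScaler.unscale_() raises a RuntimeError if called a second time on the same optimizer before update(), so the commit consults the scaler's internal per-optimizer state before unscaling again.

    import torch
    from torch.cuda.amp.grad_scaler import OptState

    # Hypothetical model/optimizer, just to drive the scaler through a step.
    model = torch.nn.Linear(4, 4).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    with torch.cuda.amp.autocast():
        loss = model(torch.randn(2, 4, device="cuda")).sum()
    scaler.scale(loss).backward()

    scaler.unscale_(opt)    # first call: grads are divided by the scale factor
    # scaler.unscale_(opt)  # a second call here would raise RuntimeError

    # The guard from this commit skips the redundant call instead of crashing.
    # Note: _per_optimizer_states is a private GradScaler attribute, so this
    # relies on PyTorch internals of the era rather than a public API.
    if scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
        scaler.unscale_(opt)

    scaler.step(opt)
    scaler.update()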