Fix likely defective nan grad recovery

James Betker 2022-01-08 18:24:58 -07:00
parent 438dd9ed33
commit 2a9a25e6e7


@@ -321,6 +321,10 @@ class ConfigurableStep(Module):
         if not nan_found:
             self.scaler.step(opt)
             self.scaler.update()
+        else:
+            for pg in opt.param_groups:
+                for p in pg['params']:
+                    p.grad = 0
 
     def get_metrics(self):
         return self.loss_accumulator.as_dict()
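
For context, the pattern this diff is aiming at is the usual AMP "skip on NaN" step: after unscaling, if any gradient is non-finite, the optimizer step is skipped and the stale gradients are discarded so they cannot leak into (or be accumulated into) the next iteration. Below is a minimal, self-contained sketch of that pattern around torch.cuda.amp.GradScaler; the model, optimizer, loss, and nan_found check are illustrative stand-ins, not the repository's code, and the sketch clears gradients with p.grad = None, the form PyTorch expects (a Tensor or None).

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 1).to(device)          # placeholder model
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

def training_step(batch, target):
    opt.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(enabled=(device == "cuda")):
        loss = F.mse_loss(model(batch), target)
    scaler.scale(loss).backward()
    scaler.unscale_(opt)  # bring gradients back to fp32 so they can be inspected

    nan_found = any(
        p.grad is not None and not torch.isfinite(p.grad).all()
        for pg in opt.param_groups
        for p in pg['params']
    )

    if not nan_found:
        scaler.step(opt)
        scaler.update()
    else:
        # Discard the corrupted gradients so they cannot be applied to, or
        # accumulated into, a later step. Clearing with None (or opt.zero_grad())
        # is the usual idiom, since p.grad must be a Tensor or None.
        for pg in opt.param_groups:
            for p in pg['params']:
                p.grad = None
        # Let the scaler observe the infs recorded during unscale_() so it
        # backs off the loss scale and resets its per-step state.
        scaler.update()
    return loss.detach()

# Usage (illustrative):
# training_step(torch.randn(4, 8, device=device), torch.randn(4, 1, device=device))
```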