diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index 6d3a871..ae20ad8 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -493,10 +493,11 @@ class Engines(dict[str, Engine]): total_elapsed_time += elapsed_time grad_norm = engine.get_global_grad_norm() loss_scale = 1 - if hasattr(engine.optimizer, "loss_scale"): + if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None: loss_scale = engine.optimizer.loss_scale - grad_norm /= loss_scale + if grad_norm is not None: + grad_norm /= loss_scale stats.update( flatten_dict(