diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 955f50a..2a28f1d 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -146,7 +146,7 @@ def load_engines(training=True, **model_kwargs): elif cfg.hyperparameters.optimizer.lower() == "adagrad": optimizer_class = ml.Adagrad elif cfg.hyperparameters.optimizer.lower() == "muon": - optimizer = ml.Muon + optimizer_class = ml.Muon muon_params = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ] adamw_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ]