diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 955f50a..144e474 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -146,7 +146,7 @@ def load_engines(training=True, **model_kwargs): elif cfg.hyperparameters.optimizer.lower() == "adagrad": optimizer_class = ml.Adagrad elif cfg.hyperparameters.optimizer.lower() == "muon": - optimizer = ml.Muon + optimizer_class = ml.Muon muon_params = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ] adamw_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ] @@ -221,9 +221,6 @@ def load_engines(training=True, **model_kwargs): continue state[k] = ml.resize_weight( state[k], tokens ) - for k, v in last_embedding_keys.items(): - state[k][-1] = v - model.load_state_dict(state, strict=cfg.trainer.strict_loading) # load lora weights if exists