diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 8b933f4..97bd9fc 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -147,8 +147,8 @@ def load_engines(training=True): state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens] # resize text embedding - if "rvq_level_emb.weight" in state and model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]: - state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels] + if "rvq_l_emb.weight" in state and model.config.resp_levels != state["rvq_l_emb.weight"].shape[0]: + state["rvq_l_emb.weight"] = state["rvq_l_emb.weight"][:model.config.resp_levels] model.load_state_dict(state, strict=cfg.trainer.strict_loading)