forgot I renamed this

This commit is contained in:
mrq 2024-06-09 11:12:30 -05:00
parent 80f9530840
commit 8d92dac829

View File

@ -147,8 +147,8 @@ def load_engines(training=True):
state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
# resize text embedding
if "rvq_level_emb.weight" in state and model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels]
if "rvq_l_emb.weight" in state and model.config.resp_levels != state["rvq_l_emb.weight"].shape[0]:
state["rvq_l_emb.weight"] = state["rvq_l_emb.weight"][:model.config.resp_levels]
model.load_state_dict(state, strict=cfg.trainer.strict_loading)