forgot I renamed this
This commit is contained in:
parent
80f9530840
commit
8d92dac829
|
@ -147,8 +147,8 @@ def load_engines(training=True):
|
|||
state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
|
||||
|
||||
# resize text embedding
|
||||
if "rvq_level_emb.weight" in state and model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
|
||||
state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels]
|
||||
if "rvq_l_emb.weight" in state and model.config.resp_levels != state["rvq_l_emb.weight"].shape[0]:
|
||||
state["rvq_l_emb.weight"] = state["rvq_l_emb.weight"][:model.config.resp_levels]
|
||||
|
||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user