mrq 2024-11-21 15:07:46 -06:00
parent 6845c447c9
commit 3fc0540f49
2 changed files with 46 additions and 1 deletion


@@ -326,7 +326,7 @@ class Model:
if isinstance(self.size, dict):
	if hasattr(self.size, "label") and self.size['label']:
		name.append(f"{self.size['label']}")
-elif isinstance(self.size, str) and self.size != "full":
+elif isinstance(self.size, str) and self.size not in ["full","extended"]:
	name.append(self.size)

if self.experts > 1:
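
With this change, an "extended" model no longer contributes a size suffix to the assembled model name, so it resolves to the same name (and weight path) as a "full" model; only explicitly non-default sizes such as "double" are still appended. A minimal sketch of the rule in isolation, with a hypothetical size_suffix helper standing in for the real naming code:

# Hedged sketch of the size-to-name rule after this change; size_suffix is a
# stand-in helper for illustration, not the project's actual API.
def size_suffix(size) -> list[str]:
	parts = []
	if isinstance(size, str) and size not in ["full", "extended"]:
		parts.append(size)
	return parts

assert size_suffix("full") == []            # "full" was already omitted
assert size_suffix("extended") == []        # new: "extended" is now omitted too
assert size_suffix("double") == ["double"]  # other sizes still suffix the name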
@@ -392,6 +392,8 @@ class Model:
	if self.size == "double":
		return 24
+	if self.size == "extended":
+		return 16
	return 12

@property
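
Read together with the existing branches, the (presumably layer-count) property now resolves to 12 layers by default, 16 for "extended", and 24 for "double". A self-contained sketch, where ModelSketch and the layers property name are assumed stand-ins for the real config class:

# Sketch only: the size-to-depth mapping is taken from the diff; the class and
# property names are assumptions.
from dataclasses import dataclass

@dataclass
class ModelSketch:
	size: str = "full"

	@property
	def layers(self) -> int:
		if self.size == "double":
			return 24   # double-depth variant
		if self.size == "extended":
			return 16   # new: four extra layers over the default
		return 12       # default depth ("full" and anything else)

assert ModelSketch("extended").layers == 16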


@@ -231,6 +231,49 @@ def load_engines(training=True, **model_kwargs):
del state['resps_emb.embeddings.8.weight']
"""
"""
if True:
remapped_dict = {}
remapped_indices = [
(0, 1),
(1, 2),
(2, 3),
(3, 5),
(4, 6),
(5, 7),
(6, 9),
(7, 10),
(8, 11),
(9, 13),
(10, 14),
(11, 15),
]
for src, dst in remapped_indices:
remapped_dict[f"model.layers.{dst}.input_layernorm.weight"] = state[f"model.layers.{src}.input_layernorm.weight"]
remapped_dict[f"model.layers.{dst}.self_attn.k_proj.weight"] = state[f"model.layers.{src}.self_attn.k_proj.weight"]
remapped_dict[f"model.layers.{dst}.self_attn.q_proj.weight"] = state[f"model.layers.{src}.self_attn.q_proj.weight"]
remapped_dict[f"model.layers.{dst}.self_attn.v_proj.weight"] = state[f"model.layers.{src}.self_attn.v_proj.weight"]
remapped_dict[f"model.layers.{dst}.self_attn.o_proj.weight"] = state[f"model.layers.{src}.self_attn.o_proj.weight"]
remapped_dict[f"model.layers.{dst}.post_attention_layernorm.weight"] = state[f"model.layers.{src}.post_attention_layernorm.weight"]
remapped_dict[f"model.layers.{dst}.mlp.down_proj.weight"] = state[f"model.layers.{src}.mlp.down_proj.weight"]
remapped_dict[f"model.layers.{dst}.mlp.gate_proj.weight"] = state[f"model.layers.{src}.mlp.gate_proj.weight"]
remapped_dict[f"model.layers.{dst}.mlp.up_proj.weight"] = state[f"model.layers.{src}.mlp.up_proj.weight"]
del state[f"model.layers.{src}.input_layernorm.weight"]
del state[f"model.layers.{src}.self_attn.k_proj.weight"]
del state[f"model.layers.{src}.self_attn.q_proj.weight"]
del state[f"model.layers.{src}.self_attn.v_proj.weight"]
del state[f"model.layers.{src}.self_attn.o_proj.weight"]
del state[f"model.layers.{src}.post_attention_layernorm.weight"]
del state[f"model.layers.{src}.mlp.down_proj.weight"]
del state[f"model.layers.{src}.mlp.gate_proj.weight"]
del state[f"model.layers.{src}.mlp.up_proj.weight"]
for k, v in remapped_dict.items():
state[k] = v
"""
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
# load lora weights if exists
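
The commented-out block above is a layer remap, presumably for growing a trained 12-layer ("full") checkpoint into the new 16-layer "extended" depth: every source layer is copied to a higher slot, and destination indices 0, 4, 8 and 12 are left without entries so those layers keep their freshly initialized weights. A sketch that reproduces the same index table and key moves generically; the helper name, the defaults, and the assumption of Hugging Face-style LLaMA keys in state are illustrative only:

# Illustrative only: rebuilds the hard-coded remap table and moves the same
# per-layer tensors. Assumes keys like "model.layers.{i}.self_attn.q_proj.weight".
PER_LAYER_KEYS = [
	"input_layernorm.weight",
	"self_attn.q_proj.weight", "self_attn.k_proj.weight",
	"self_attn.v_proj.weight", "self_attn.o_proj.weight",
	"post_attention_layernorm.weight",
	"mlp.gate_proj.weight", "mlp.up_proj.weight", "mlp.down_proj.weight",
]

def remap_layers(state, old_layers=12, new_layers=16):
	# every fourth destination slot (0, 4, 8, 12) is skipped, so the wider
	# model keeps its newly initialized weights there
	keep = [dst for dst in range(new_layers) if dst % 4 != 0]
	remapped = {}
	for src, dst in zip(range(old_layers), keep):
		for key in PER_LAYER_KEYS:
			remapped[f"model.layers.{dst}.{key}"] = state.pop(f"model.layers.{src}.{key}")
	# merge only after the loop, so a source layer that is also another layer's
	# destination is read before its slot is overwritten
	state.update(remapped)
	return state

Such a partial state dict carries no entries for destination layers 0, 4, 8 and 12, so the preceding load_state_dict call would presumably only succeed with cfg.trainer.strict_loading disabled.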