m

commit 3fc0540f49
parent 6845c447c9
@@ -326,7 +326,7 @@ class Model:
 		if isinstance(self.size, dict):
 			if hasattr(self.size, "label") and self.size['label']:
 				name.append(f"{self.size['label']}")
-		elif isinstance(self.size, str) and self.size != "full":
+		elif isinstance(self.size, str) and self.size not in ["full","extended"]:
 			name.append(self.size)
 
 		if self.experts > 1:
@@ -392,6 +392,8 @@ class Model:
 
 		if self.size == "double":
 			return 24
+		if self.size == "extended":
+			return 16
 		return 12
 
 	@property
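For reference, the two Model hunks above make "extended" behave like "full" for naming purposes while mapping to a 16-layer configuration. Below is a minimal sketch of that size-to-layers logic; the class and property names are illustrative stand-ins, and only the "double"/"extended" values and the 24/16/12 layer counts come from the diff itself.

# Hedged sketch of the size handling after this change; ModelSketch and
# name_suffix are illustrative, not the repo's actual API.
class ModelSketch:
	def __init__(self, size="full"):
		self.size = size

	@property
	def layers(self):
		if self.size == "double":
			return 24
		if self.size == "extended":
			return 16
		return 12

	@property
	def name_suffix(self):
		# "full" and, after this commit, "extended" do not alter the model name
		if isinstance(self.size, str) and self.size not in ["full", "extended"]:
			return self.size
		return ""

Under these assumptions, ModelSketch("extended").layers is 16 while its name suffix stays empty, so an "extended" model keeps the same name as a "full" one.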
@@ -231,6 +231,49 @@ def load_engines(training=True, **model_kwargs):
 		del state['resps_emb.embeddings.8.weight']
 		"""
 
+		"""
+		if True:
+			remapped_dict = {}
+			remapped_indices = [
+				(0, 1),
+				(1, 2),
+				(2, 3),
+				(3, 5),
+				(4, 6),
+				(5, 7),
+				(6, 9),
+				(7, 10),
+				(8, 11),
+				(9, 13),
+				(10, 14),
+				(11, 15),
+			]
+
+			for src, dst in remapped_indices:
+				remapped_dict[f"model.layers.{dst}.input_layernorm.weight"] = state[f"model.layers.{src}.input_layernorm.weight"]
+				remapped_dict[f"model.layers.{dst}.self_attn.k_proj.weight"] = state[f"model.layers.{src}.self_attn.k_proj.weight"]
+				remapped_dict[f"model.layers.{dst}.self_attn.q_proj.weight"] = state[f"model.layers.{src}.self_attn.q_proj.weight"]
+				remapped_dict[f"model.layers.{dst}.self_attn.v_proj.weight"] = state[f"model.layers.{src}.self_attn.v_proj.weight"]
+				remapped_dict[f"model.layers.{dst}.self_attn.o_proj.weight"] = state[f"model.layers.{src}.self_attn.o_proj.weight"]
+				remapped_dict[f"model.layers.{dst}.post_attention_layernorm.weight"] = state[f"model.layers.{src}.post_attention_layernorm.weight"]
+				remapped_dict[f"model.layers.{dst}.mlp.down_proj.weight"] = state[f"model.layers.{src}.mlp.down_proj.weight"]
+				remapped_dict[f"model.layers.{dst}.mlp.gate_proj.weight"] = state[f"model.layers.{src}.mlp.gate_proj.weight"]
+				remapped_dict[f"model.layers.{dst}.mlp.up_proj.weight"] = state[f"model.layers.{src}.mlp.up_proj.weight"]
+
+				del state[f"model.layers.{src}.input_layernorm.weight"]
+				del state[f"model.layers.{src}.self_attn.k_proj.weight"]
+				del state[f"model.layers.{src}.self_attn.q_proj.weight"]
+				del state[f"model.layers.{src}.self_attn.v_proj.weight"]
+				del state[f"model.layers.{src}.self_attn.o_proj.weight"]
+				del state[f"model.layers.{src}.post_attention_layernorm.weight"]
+				del state[f"model.layers.{src}.mlp.down_proj.weight"]
+				del state[f"model.layers.{src}.mlp.gate_proj.weight"]
+				del state[f"model.layers.{src}.mlp.up_proj.weight"]
+
+			for k, v in remapped_dict.items():
+				state[k] = v
+		"""
+
 		model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
 		# load lora weights if exists
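The remapping block added above (kept disabled inside a triple-quoted string) spreads the 12 transformer layers of an existing checkpoint across the 16 slots of the larger model: destinations 1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14 and 15 are filled from the old layers, while slots 0, 4, 8 and 12 are left absent so they are freshly initialized when strict loading is off. A standalone sketch of the same idea follows, assuming a plain state-dict dictionary; the helper name remap_layers is illustrative.

# Hedged sketch of the layer remapping; the (src, dst) pairs and weight names
# mirror the disabled block in the diff, the function itself is illustrative.
REMAPPED_INDICES = [
	(0, 1), (1, 2), (2, 3),
	(3, 5), (4, 6), (5, 7),
	(6, 9), (7, 10), (8, 11),
	(9, 13), (10, 14), (11, 15),
]

WEIGHT_NAMES = [
	"input_layernorm.weight",
	"self_attn.k_proj.weight",
	"self_attn.q_proj.weight",
	"self_attn.v_proj.weight",
	"self_attn.o_proj.weight",
	"post_attention_layernorm.weight",
	"mlp.down_proj.weight",
	"mlp.gate_proj.weight",
	"mlp.up_proj.weight",
]

def remap_layers(state: dict) -> dict:
	# Read every weight from the original dict first, so overlapping
	# source/destination indices (e.g. layer 1) are not clobbered mid-loop.
	remapped = {}
	moved = set()
	for src, dst in REMAPPED_INDICES:
		for w in WEIGHT_NAMES:
			remapped[f"model.layers.{dst}.{w}"] = state[f"model.layers.{src}.{w}"]
			moved.add(f"model.layers.{src}.{w}")
	# Keep everything that was not a remap source (embeddings, heads, norms),
	# then overlay the shifted layers; slots 0, 4, 8 and 12 end up missing.
	out = {k: v for k, v in state.items() if k not in moved}
	out.update(remapped)
	return out

Loading the result with model.load_state_dict(remap_layers(state), strict=False) would reuse the original 12 layers in their new positions and initialize the four inserted layers from scratch, which appears to be the intent of the inline, currently disabled version.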