From 3fc0540f492de59ba04c448a6d589b9f33666114 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 21 Nov 2024 15:07:46 -0600
Subject: [PATCH] m

---
 vall_e/config.py           |  4 +++-
 vall_e/engines/__init__.py | 43 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 9d3d91a..1fedcf9 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -326,7 +326,7 @@ class Model:
 		if isinstance(self.size, dict):
 			if hasattr(self.size, "label") and self.size['label']:
 				name.append(f"{self.size['label']}")
-		elif isinstance(self.size, str) and self.size != "full":
+		elif isinstance(self.size, str) and self.size not in ["full","extended"]:
 			name.append(self.size)
 
 		if self.experts > 1:
@@ -392,6 +392,8 @@ class Model:
 
 		if self.size == "double":
 			return 24
+		if self.size == "extended":
+			return 16
 		return 12
 
 	@property
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 2dde826..f6cbd99 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -231,6 +231,49 @@ def load_engines(training=True, **model_kwargs):
 		del state['resps_emb.embeddings.8.weight']
 	"""
 
+	"""
+	if True:
+		remapped_dict = {}
+		remapped_indices = [
+			(0, 1),
+			(1, 2),
+			(2, 3),
+			(3, 5),
+			(4, 6),
+			(5, 7),
+			(6, 9),
+			(7, 10),
+			(8, 11),
+			(9, 13),
+			(10, 14),
+			(11, 15),
+		]
+
+		for src, dst in remapped_indices:
+			remapped_dict[f"model.layers.{dst}.input_layernorm.weight"] = state[f"model.layers.{src}.input_layernorm.weight"]
+			remapped_dict[f"model.layers.{dst}.self_attn.k_proj.weight"] = state[f"model.layers.{src}.self_attn.k_proj.weight"]
+			remapped_dict[f"model.layers.{dst}.self_attn.q_proj.weight"] = state[f"model.layers.{src}.self_attn.q_proj.weight"]
+			remapped_dict[f"model.layers.{dst}.self_attn.v_proj.weight"] = state[f"model.layers.{src}.self_attn.v_proj.weight"]
+			remapped_dict[f"model.layers.{dst}.self_attn.o_proj.weight"] = state[f"model.layers.{src}.self_attn.o_proj.weight"]
+			remapped_dict[f"model.layers.{dst}.post_attention_layernorm.weight"] = state[f"model.layers.{src}.post_attention_layernorm.weight"]
+			remapped_dict[f"model.layers.{dst}.mlp.down_proj.weight"] = state[f"model.layers.{src}.mlp.down_proj.weight"]
+			remapped_dict[f"model.layers.{dst}.mlp.gate_proj.weight"] = state[f"model.layers.{src}.mlp.gate_proj.weight"]
+			remapped_dict[f"model.layers.{dst}.mlp.up_proj.weight"] = state[f"model.layers.{src}.mlp.up_proj.weight"]
+
+			del state[f"model.layers.{src}.input_layernorm.weight"]
+			del state[f"model.layers.{src}.self_attn.k_proj.weight"]
+			del state[f"model.layers.{src}.self_attn.q_proj.weight"]
+			del state[f"model.layers.{src}.self_attn.v_proj.weight"]
+			del state[f"model.layers.{src}.self_attn.o_proj.weight"]
+			del state[f"model.layers.{src}.post_attention_layernorm.weight"]
+			del state[f"model.layers.{src}.mlp.down_proj.weight"]
+			del state[f"model.layers.{src}.mlp.gate_proj.weight"]
+			del state[f"model.layers.{src}.mlp.up_proj.weight"]
+
+		for k, v in remapped_dict.items():
+			state[k] = v
+	"""
+
 	model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
 	# load lora weights if exists