From 10aaf840e76bba605002ffdc1ae598a791bb5939 Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 4 Aug 2024 20:25:06 -0500 Subject: [PATCH] added export option to convert Llama to MixtralMoE for another dumb experiment --- vall_e/export.py | 49 +++++++++++++++++++++++++++------- vall_e/models/ar_nar.py | 6 +++-- vall_e/models/arch/__init__.py | 2 +- 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/vall_e/export.py b/vall_e/export.py index 934393a..2e68a56 100755 --- a/vall_e/export.py +++ b/vall_e/export.py @@ -9,7 +9,7 @@ from .config import cfg from .models.lora import lora_get_state_dict from .utils.io import torch_save, torch_load -# stitches embeddings into one embedding & classifier => lm_head +# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight def convert_to_hf( state_dict, config = None, save_path = None ): n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1 token_dim = 1024 @@ -52,12 +52,15 @@ def convert_to_hf( state_dict, config = None, save_path = None ): embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer] state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict() + # to-do: properly recreate the output head weights or something state_dict['module']['lm_head.weight'] = out_proj + del state_dict['module']['classifier.weight'] del state_dict['module']['classifier.bias'] return state_dict +# yanks a LoRA from the training checkpoint def extract_lora( state_dict, config = None, save_path = None, dtype = None ): if dtype is None: dtype = cfg.inference.dtype @@ -87,12 +90,12 @@ def extract_lora( state_dict, config = None, save_path = None, dtype = None ): return state_dict +# copies a single classifier head into multiple classifier heads per RVQ level def split_classifier_heads( state_dict, config = cfg.model, save_path = None, dtype = None): levels = config.max_levels if "classifier.weight" not in state_dict['module']: return state_dict - # copy to new AudioClassifier for i in range(levels): tokens = 1025 if i == 0 else 1024 @@ -107,12 +110,31 @@ def split_classifier_heads( state_dict, config = cfg.model, save_path = None, dt return state_dict +# converts a normal LLaMA model to a MoE model, as best as I can +def moe_ify( state_dict, config = cfg.model, save_path = None, dtype = None ): + # to-do: find a good way to pass in requested experts + experts = 8 + for layer in range( config.layers ): + #state_dict[f'model.layers.{layer}.block_sparse_moe.gate.weight'] = torch.randn((config.dim, experts)) + for expert in range( experts ): + state_dict['module'][f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w1.weight'] = state_dict['module'][f'model.layers.{layer}.mlp.up_proj.weight'].clone() + state_dict['module'][f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w2.weight'] = state_dict['module'][f'model.layers.{layer}.mlp.down_proj.weight'].clone() + state_dict['module'][f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w3.weight'] = state_dict['module'][f'model.layers.{layer}.mlp.gate_proj.weight'].clone() + + del state_dict['module'][f'model.layers.{layer}.mlp.up_proj.weight'] + del state_dict['module'][f'model.layers.{layer}.mlp.down_proj.weight'] + del state_dict['module'][f'model.layers.{layer}.mlp.gate_proj.weight'] + + return state_dict + def main(): parser = argparse.ArgumentParser("Save trained model to path.") parser.add_argument("--module-only", action='store_true') parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style parser.add_argument("--lora", action='store_true', default=None) # exports LoRA parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads + parser.add_argument("--moe-ify", action='store_true', default=None) # splits classifier heads + parser.add_argument("--experts", type=int, default=8) # set target dtype to export to parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to parser.add_argument("--format", type=str, default="pth") # set target format to export weights under args, unknown = parser.parse_known_args() @@ -123,13 +145,6 @@ def main(): if args.module_only: cfg.trainer.load_module_only = True - callback = None - if args.hf: - callback = convert_to_hf - elif args.lora: - callback = extract_lora - elif args.split_classifiers: - callback = split_classifier_heads if args.hf and args.lora: raise Exception("Requesting more than one callback") @@ -141,6 +156,22 @@ def main(): cfg.inference.backend = cfg.trainer.backend engines = load_engines(training=False) # to ignore loading optimizer state + + callback = None + if args.hf: + callback = convert_to_hf + elif args.lora: + callback = extract_lora + elif args.split_classifiers: + callback = split_classifier_heads + elif args.moe_ify: + callback = moe_ify + # set it here after the model loads to not influence which model loads + cfg.model.experts = args.experts + for name, engine in engines.items(): + engine.module.config.experts = args.experts + engine.hyper_config.experts = args.experts + engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback, format=args.format) if __name__ == "__main__": diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 4b5c404..3193ab8 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -375,10 +375,10 @@ def example_usage(): 'n_text_tokens': 256, 'n_audio_tokens': 1024, - 'd_model': 1024, # 256, # 1024, # 1536 + 'd_model': 256, # 256, # 1024, # 1536 'n_heads': 16, # 4, # 16, # 24 'n_layers': 12, # 32 - 'n_experts': 1, + 'n_experts': 1 if not cfg.model else cfg.model.experts, 'p_dropout': 0.1, @@ -468,6 +468,8 @@ def example_usage(): engine = Engine(model=model, optimizer=optimizer) engines = Engines({"ar+nar": engine}) engines.setup() + + print( model.state_dict().keys() ) """ if cfg.optimizations.model_offloading: diff --git a/vall_e/models/arch/__init__.py b/vall_e/models/arch/__init__.py index 4ce99e9..d7774a7 100755 --- a/vall_e/models/arch/__init__.py +++ b/vall_e/models/arch/__init__.py @@ -44,7 +44,7 @@ except Exception as e: pass try: - from .mixtral import MixtralModel, MixtralConfig + from .mixtral import MixtralModel, MixtralConfig, load_balancing_loss_func AVAILABLE_ARCHES.append("mixtral") except Exception as e: ERROR_ARCHES["mixtral"] = e