added export option to convert Llama to MixtralMoE for another dumb experiment

mrq 2024-08-04 20:25:06 -05:00
parent 3a65cc4b22
commit 10aaf840e7
3 changed files with 45 additions and 12 deletions
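A minimal sketch of the MLP-to-expert key remapping this commit performs, with toy shapes (the dimensions, layer index, and expert count below are illustrative, not the repo's actual config) and assuming the HF convention where Mixtral's w1/w2/w3 correspond to LLaMA's gate_proj/down_proj/up_proj:

# toy illustration of what moe_ify() does to the checkpoint keys;
# every expert starts out as a clone of the original dense MLP
import torch

dim, ffn, experts = 8, 16, 2
llama = {
    "model.layers.0.mlp.gate_proj.weight": torch.randn(ffn, dim),
    "model.layers.0.mlp.up_proj.weight":   torch.randn(ffn, dim),
    "model.layers.0.mlp.down_proj.weight": torch.randn(dim, ffn),
}

moe = {}
for e in range(experts):
    prefix = f"model.layers.0.block_sparse_moe.experts.{e}"
    moe[f"{prefix}.w1.weight"] = llama["model.layers.0.mlp.gate_proj.weight"].clone()
    moe[f"{prefix}.w2.weight"] = llama["model.layers.0.mlp.down_proj.weight"].clone()
    moe[f"{prefix}.w3.weight"] = llama["model.layers.0.mlp.up_proj.weight"].clone()

print(sorted(moe.keys()))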

View File

@@ -9,7 +9,7 @@ from .config import cfg
 from .models.lora import lora_get_state_dict
 from .utils.io import torch_save, torch_load
-# stitches embeddings into one embedding & classifier => lm_head
+# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
 def convert_to_hf( state_dict, config = None, save_path = None ):
     n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
     token_dim = 1024
@@ -52,12 +52,15 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
             embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer]
     state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
+    # to-do: properly recreate the output head weights or something
     state_dict['module']['lm_head.weight'] = out_proj
+    del state_dict['module']['classifier.weight']
     del state_dict['module']['classifier.bias']
     return state_dict
+# yanks a LoRA from the training checkpoint
 def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
     if dtype is None:
         dtype = cfg.inference.dtype
@@ -87,12 +90,12 @@ def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
     return state_dict
+# copies a single classifier head into multiple classifier heads per RVQ level
 def split_classifier_heads( state_dict, config = cfg.model, save_path = None, dtype = None):
     levels = config.max_levels
     if "classifier.weight" not in state_dict['module']:
         return state_dict
     # copy to new AudioClassifier
     for i in range(levels):
         tokens = 1025 if i == 0 else 1024
@@ -107,12 +110,31 @@ def split_classifier_heads( state_dict, config = cfg.model, save_path = None, dtype = None):
     return state_dict
+# converts a normal LLaMA model to a MoE model, as best as I can
+def moe_ify( state_dict, config = cfg.model, save_path = None, dtype = None ):
+    # use the expert count from the model config when it was set (main() sets it for --moe-ify), otherwise default to 8
+    experts = config.experts if getattr( config, "experts", 0 ) > 1 else 8
+    for layer in range( config.layers ):
+        # the router itself is left uninitialized here, since it has no dense counterpart to copy from
+        #state_dict['module'][f'model.layers.{layer}.block_sparse_moe.gate.weight'] = torch.randn((experts, config.dim))
+        for expert in range( experts ):
+            # HF Mixtral convention: w1 <= gate_proj, w2 <= down_proj, w3 <= up_proj
+            state_dict['module'][f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w1.weight'] = state_dict['module'][f'model.layers.{layer}.mlp.gate_proj.weight'].clone()
+            state_dict['module'][f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w2.weight'] = state_dict['module'][f'model.layers.{layer}.mlp.down_proj.weight'].clone()
+            state_dict['module'][f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w3.weight'] = state_dict['module'][f'model.layers.{layer}.mlp.up_proj.weight'].clone()
+        del state_dict['module'][f'model.layers.{layer}.mlp.up_proj.weight']
+        del state_dict['module'][f'model.layers.{layer}.mlp.down_proj.weight']
+        del state_dict['module'][f'model.layers.{layer}.mlp.gate_proj.weight']
+    return state_dict
 def main():
     parser = argparse.ArgumentParser("Save trained model to path.")
     parser.add_argument("--module-only", action='store_true')
     parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
     parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
     parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
+    parser.add_argument("--moe-ify", action='store_true', default=None) # converts the model into a sparse MoE
+    parser.add_argument("--experts", type=int, default=8) # number of experts to create when MoE-ifying
     parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
     parser.add_argument("--format", type=str, default="pth") # set target format to export weights under
     args, unknown = parser.parse_known_args()
@@ -123,13 +145,6 @@ def main():
     if args.module_only:
         cfg.trainer.load_module_only = True
-    callback = None
-    if args.hf:
-        callback = convert_to_hf
-    elif args.lora:
-        callback = extract_lora
-    elif args.split_classifiers:
-        callback = split_classifier_heads
     if args.hf and args.lora:
         raise Exception("Requesting more than one callback")
@@ -141,6 +156,22 @@ def main():
     cfg.inference.backend = cfg.trainer.backend
     engines = load_engines(training=False) # to ignore loading optimizer state
+    callback = None
+    if args.hf:
+        callback = convert_to_hf
+    elif args.lora:
+        callback = extract_lora
+    elif args.split_classifiers:
+        callback = split_classifier_heads
+    elif args.moe_ify:
+        callback = moe_ify
+        # set this here, after the model loads, so it does not influence which model gets loaded
+        cfg.model.experts = args.experts
+        for name, engine in engines.items():
+            engine.module.config.experts = args.experts
+            engine.hyper_config.experts = args.experts
     engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback, format=args.format)
 if __name__ == "__main__":
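As a rough sanity check on the exported weights, the converted keys should line up with a HF-style Mixtral causal LM; the sketch below uses the upstream transformers classes as a stand-in for this repo's own .mixtral module, tiny placeholder sizes (the real values would have to match the trained model), and a hypothetical checkpoint path, so it is illustrative only:

# sketch: verify that the remapped keys load, and that only the router (gate) weights stay uninitialized
import torch
from transformers import MixtralConfig, MixtralForCausalLM

config = MixtralConfig(
    vocab_size=256 + (1024 * 8) + (1024 * 8) + 1,  # n_tokens from convert_to_hf above
    hidden_size=64, intermediate_size=128,         # placeholders, not the trained model's sizes
    num_hidden_layers=2, num_attention_heads=8,
    num_local_experts=2, num_experts_per_tok=2,
)
model = MixtralForCausalLM(config)

# state_dict = moe_ify(torch_load("ckpt.pth"))['module']        # hypothetical path
# missing, unexpected = model.load_state_dict(state_dict, strict=False)
# print(missing)  # expect only block_sparse_moe.gate.* here, since moe_ify leaves the routers untouched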

View File

@@ -375,10 +375,10 @@ def example_usage():
         'n_text_tokens': 256,
         'n_audio_tokens': 1024,
-        'd_model': 1024, # 256, # 1024, # 1536
+        'd_model': 256, # 256, # 1024, # 1536
         'n_heads': 16, # 4, # 16, # 24
         'n_layers': 12, # 32
-        'n_experts': 1,
+        'n_experts': 1 if not cfg.model else cfg.model.experts,
         'p_dropout': 0.1,
@@ -468,6 +468,8 @@ def example_usage():
     engine = Engine(model=model, optimizer=optimizer)
     engines = Engines({"ar+nar": engine})
     engines.setup()
+    print( model.state_dict().keys() )
     """
     if cfg.optimizations.model_offloading:

View File

@@ -44,7 +44,7 @@ except Exception as e:
     pass
 try:
-    from .mixtral import MixtralModel, MixtralConfig
+    from .mixtral import MixtralModel, MixtralConfig, load_balancing_loss_func
     AVAILABLE_ARCHES.append("mixtral")
 except Exception as e:
     ERROR_ARCHES["mixtral"] = e
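The newly exported load_balancing_loss_func is presumably meant for adding the Mixtral auxiliary router loss during MoE training. Assuming the vendored .mixtral module mirrors the upstream transformers helper of the same name (imported from transformers below purely for illustration, with made-up tensor shapes), typical usage looks like:

# sketch, assuming the helper matches transformers' Mixtral implementation:
# it expects one (batch * seq_len, num_experts) router-logit tensor per decoder layer
import torch
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func

num_experts, top_k, layers = 8, 2, 12
gate_logits = tuple(torch.randn(4 * 16, num_experts) for _ in range(layers))

aux_loss = load_balancing_loss_func(gate_logits, num_experts=num_experts, top_k=top_k)
print(aux_loss)  # scalar auxiliary loss, typically scaled by router_aux_loss_coef before being added to the LM loss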