added export option to convert Llama to MixtralMoE for another dumb experiment
This commit is contained in:
parent
3a65cc4b22
commit
10aaf840e7
|
@ -9,7 +9,7 @@ from .config import cfg
|
|||
from .models.lora import lora_get_state_dict
|
||||
from .utils.io import torch_save, torch_load
|
||||
|
||||
# stitches embeddings into one embedding & classifier => lm_head
|
||||
# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
|
||||
def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||
n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
|
||||
token_dim = 1024
|
||||
|
@ -52,12 +52,15 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
|||
embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer]
|
||||
|
||||
state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
|
||||
# to-do: properly recreate the output head weights or something
|
||||
state_dict['module']['lm_head.weight'] = out_proj
|
||||
|
||||
del state_dict['module']['classifier.weight']
|
||||
del state_dict['module']['classifier.bias']
|
||||
|
||||
return state_dict
|
||||
|
||||
# yanks a LoRA from the training checkpoint
|
||||
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
|
||||
if dtype is None:
|
||||
dtype = cfg.inference.dtype
|
||||
|
@ -87,12 +90,12 @@ def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
|
|||
|
||||
return state_dict
|
||||
|
||||
# copies a single classifier head into multiple classifier heads per RVQ level
|
||||
def split_classifier_heads( state_dict, config = cfg.model, save_path = None, dtype = None):
|
||||
levels = config.max_levels
|
||||
|
||||
if "classifier.weight" not in state_dict['module']:
|
||||
return state_dict
|
||||
|
||||
# copy to new AudioClassifier
|
||||
for i in range(levels):
|
||||
tokens = 1025 if i == 0 else 1024
|
||||
|
@ -107,12 +110,31 @@ def split_classifier_heads( state_dict, config = cfg.model, save_path = None, dt
|
|||
|
||||
return state_dict
|
||||
|
||||
# converts a normal LLaMA model to a MoE model, as best as I can
|
||||
def moe_ify( state_dict, config = cfg.model, save_path = None, dtype = None ):
|
||||
# to-do: find a good way to pass in requested experts
|
||||
experts = 8
|
||||
for layer in range( config.layers ):
|
||||
#state_dict[f'model.layers.{layer}.block_sparse_moe.gate.weight'] = torch.randn((config.dim, experts))
|
||||
for expert in range( experts ):
|
||||
state_dict['module'][f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w1.weight'] = state_dict['module'][f'model.layers.{layer}.mlp.up_proj.weight'].clone()
|
||||
state_dict['module'][f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w2.weight'] = state_dict['module'][f'model.layers.{layer}.mlp.down_proj.weight'].clone()
|
||||
state_dict['module'][f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w3.weight'] = state_dict['module'][f'model.layers.{layer}.mlp.gate_proj.weight'].clone()
|
||||
|
||||
del state_dict['module'][f'model.layers.{layer}.mlp.up_proj.weight']
|
||||
del state_dict['module'][f'model.layers.{layer}.mlp.down_proj.weight']
|
||||
del state_dict['module'][f'model.layers.{layer}.mlp.gate_proj.weight']
|
||||
|
||||
return state_dict
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||
parser.add_argument("--module-only", action='store_true')
|
||||
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
|
||||
parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
|
||||
parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
|
||||
parser.add_argument("--moe-ify", action='store_true', default=None) # splits classifier heads
|
||||
parser.add_argument("--experts", type=int, default=8) # set target dtype to export to
|
||||
parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
|
||||
parser.add_argument("--format", type=str, default="pth") # set target format to export weights under
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
@ -123,13 +145,6 @@ def main():
|
|||
if args.module_only:
|
||||
cfg.trainer.load_module_only = True
|
||||
|
||||
callback = None
|
||||
if args.hf:
|
||||
callback = convert_to_hf
|
||||
elif args.lora:
|
||||
callback = extract_lora
|
||||
elif args.split_classifiers:
|
||||
callback = split_classifier_heads
|
||||
|
||||
if args.hf and args.lora:
|
||||
raise Exception("Requesting more than one callback")
|
||||
|
@ -141,6 +156,22 @@ def main():
|
|||
cfg.inference.backend = cfg.trainer.backend
|
||||
|
||||
engines = load_engines(training=False) # to ignore loading optimizer state
|
||||
|
||||
callback = None
|
||||
if args.hf:
|
||||
callback = convert_to_hf
|
||||
elif args.lora:
|
||||
callback = extract_lora
|
||||
elif args.split_classifiers:
|
||||
callback = split_classifier_heads
|
||||
elif args.moe_ify:
|
||||
callback = moe_ify
|
||||
# set it here after the model loads to not influence which model loads
|
||||
cfg.model.experts = args.experts
|
||||
for name, engine in engines.items():
|
||||
engine.module.config.experts = args.experts
|
||||
engine.hyper_config.experts = args.experts
|
||||
|
||||
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback, format=args.format)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -375,10 +375,10 @@ def example_usage():
|
|||
'n_text_tokens': 256,
|
||||
'n_audio_tokens': 1024,
|
||||
|
||||
'd_model': 1024, # 256, # 1024, # 1536
|
||||
'd_model': 256, # 256, # 1024, # 1536
|
||||
'n_heads': 16, # 4, # 16, # 24
|
||||
'n_layers': 12, # 32
|
||||
'n_experts': 1,
|
||||
'n_experts': 1 if not cfg.model else cfg.model.experts,
|
||||
|
||||
'p_dropout': 0.1,
|
||||
|
||||
|
@ -468,6 +468,8 @@ def example_usage():
|
|||
engine = Engine(model=model, optimizer=optimizer)
|
||||
engines = Engines({"ar+nar": engine})
|
||||
engines.setup()
|
||||
|
||||
print( model.state_dict().keys() )
|
||||
|
||||
"""
|
||||
if cfg.optimizations.model_offloading:
|
||||
|
|
|
@ -44,7 +44,7 @@ except Exception as e:
|
|||
pass
|
||||
|
||||
try:
|
||||
from .mixtral import MixtralModel, MixtralConfig
|
||||
from .mixtral import MixtralModel, MixtralConfig, load_balancing_loss_func
|
||||
AVAILABLE_ARCHES.append("mixtral")
|
||||
except Exception as e:
|
||||
ERROR_ARCHES["mixtral"] = e
|
||||
|
|
Loading…
Reference in New Issue
Block a user