diff --git a/vall_e/data.py b/vall_e/data.py
index 79889bf..4637ba6 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -192,7 +192,7 @@ def normalize_text(text, language="auto", full=True):
 	return text
 
 @cache
-def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/tongue_twisters.txt") ):
+def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/harvard_sentences.txt") ):
 	duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
 	sentences = [
 		"The birch canoe slid on the smooth planks.",
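A side note on the `@cache` decorator visible above: `functools.cache` memoizes per argument tuple, and `Path` arguments are hashable and compared by value, so the prompt file is parsed at most once per distinct `source_path` per process. A minimal sketch of that pattern (illustrative only, not the project's loader; the file name just mirrors the new default):

```python
from functools import cache
from pathlib import Path

@cache
def load_sentences( source_path=Path("./data/harvard_sentences.txt") ):
	# one candidate sentence per line, blank lines skipped;
	# repeat calls with the same path return the cached list
	return [ line.strip() for line in source_path.read_text( encoding="utf-8" ).splitlines() if line.strip() ]
```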
diff --git a/vall_e/export.py b/vall_e/export.py
index d8e43e0..9b046f3 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -7,7 +7,7 @@ from .data import get_phone_symmap
 from .engines import load_engines
 from .config import cfg
 from .models.lora import lora_get_state_dict
-from .utils.io import torch_save, torch_load
+from .utils.io import torch_save, torch_load, json_read, json_write, Path
 
 # stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
 # *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
@@ -22,7 +22,20 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 	n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]
 
 	# the new tokenizer to use
-	tokenizer_append = {}
+	tokenizer = {}
+	tokenizer_vocab = {}
+
+	tokenizer_path = cfg.rel_path / cfg.tokenizer_path
+	if not tokenizer_path.exists():
+		tokenizer_path = Path("./data/") / cfg.tokenizer_path
+	if tokenizer_path.exists():
+		tokenizer = json_read( tokenizer_path )
+	else:
+		tokenizer = {
+			"model": {
+				"vocab": get_phone_symmap()
+			}
+		}
 
 	l_tokens = [
 		n_text_tokens, # text
@@ -77,14 +90,14 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 	token_start = token_end
 	token_end += l_tokens[1]
 	for l in range(n_resp_levels):
-		start = token_start + (l * n_resp_levels)
+		start = token_start + (l * n_audio_tokens)
 		end = start + n_audio_tokens
 		embedding.weight[start:end] = state_dict['module'][f'proms_emb.embeddings.{l}.weight']
 		# there's no corresponding classifier
 		#classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
 		#classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
 		for t in range(n_audio_tokens):
-			tokenizer_append[f'<|P|{l}:{t}|>'] = start + t
+			tokenizer_vocab[f'<|P|{l}:{t}|>'] = start + t
 
 	# inject AR
 	token_start = token_end
@@ -93,8 +106,8 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 	classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
 	classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
 	for t in range(n_audio_tokens):
-		tokenizer_append[f'<|AR|0:0|{t}|>'] = token_start + t
-	tokenizer_append[f'<|AR|0:0|STOP|>'] = token_start + 1024
+		tokenizer_vocab[f'<|AR|0:0|{t}|>'] = token_start + t
+	tokenizer_vocab[f'<|AR|0:0|STOP|>'] = token_start + 1024
 
 	# inject NAR-len
 	token_start = token_end
@@ -103,20 +116,20 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 	classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
 	classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
 	for t in range(n_audio_tokens):
-		tokenizer_append[f'<|NAR|0:0|{t}|>'] = token_start + t
-	tokenizer_append[f'<|NAR|0:0|STOP|>'] = token_start + 1024
+		tokenizer_vocab[f'<|NAR|0:0|{t}|>'] = token_start + t
+	tokenizer_vocab[f'<|NAR|0:0|STOP|>'] = token_start + 1024
 
 	# inject NAR
 	token_start = token_end
 	token_end += l_tokens[3]
 	for l in range(1, n_resp_levels):
-		start = token_start + ((l-1) * n_resp_levels)
+		start = token_start + ((l-1) * n_audio_tokens)
 		end = start + n_audio_tokens
 		embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight']
 		classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
 		classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
 		for t in range(n_audio_tokens):
-			tokenizer_append[f'<|NAR|{l-1}:{l}|{t}|>'] = start + t
+			tokenizer_vocab[f'<|NAR|{l-1}:{l}|{t}|>'] = start + t
 
 	# inject RVQ level
 	token_start = token_end
@@ -124,7 +137,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 	embedding.weight[token_start:token_end] = state_dict['module'][f'rvq_l_emb.weight']
 	# there is no corresponding classifier
 	for l in range(n_resp_levels):
-		tokenizer_append[f'<|RVQ:{l}|>'] = token_start + l
+		tokenizer_vocab[f'<|RVQ:{l}|>'] = token_start + l
 
 	# inject len
 	token_start = token_end
@@ -133,13 +146,13 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 	classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
 	classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
 	for t in range(n_len_tokens):
-		tokenizer_append[f'<|len:{t}|>'] = token_start + t
+		tokenizer_vocab[f'<|len:{t}|>'] = token_start + t
 
 	# inject sep
 	token_start = token_end
 	token_end += l_tokens[6]
 	embedding.weight[token_start:token_end] = state_dict['module']['sep']
-	tokenizer_append['<|sep|>'] = token_start
+	tokenizer_vocab['<|sep|>'] = token_start
 	# there is no corresponding classifier
 
 	# inject langs
@@ -148,7 +161,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 	embedding.weight[token_start:token_end] = state_dict['module']['langs_emb.weight']
 	for l in range(n_lang_tokens):
 		lang = lang_map[l]
-		tokenizer_append[f'<|lang:{lang}|>'] = token_start + l
+		tokenizer_vocab[f'<|lang:{lang}|>'] = token_start + l
 	# there is no corresponding classifier
 
 	# inject tasks
@@ -157,7 +170,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 	embedding.weight[token_start:token_end] = state_dict['module']['tasks_emb.weight']
 	for l in range(n_task_tokens):
 		task = task_map[l]
-		tokenizer_append[f'<|task:{task}|>'] = token_start + l
+		tokenizer_vocab[f'<|task:{task}|>'] = token_start + l
 	# there is no corresponding classifier
 
 
@@ -167,7 +180,6 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 		if not k.startswith('model.'):
 			continue
 		model_dict[k] = state_dict['module'][k].clone()
-	del state_dict['module']
 
 	embedding_dict = embedding.state_dict()
 	classifier_dict = classifier.state_dict()
@@ -175,62 +187,46 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 	model_dict['lm_head.weight'] = classifier_dict['weight']
 	model_dict['lm_head.bias'] = classifier_dict['bias']
 
-	state_dict['module'] = model_dict
-	state_dict['vocab'] = tokenizer_append
+	# write files in an HF compatible way
+	out_dir = cfg.rel_path / "hf"
+	out_dir.mkdir(parents=True, exist_ok=True)
+	# write weights
+	torch_save( model_dict, out_dir / "model.safetensors" )
+	# write tokenizer.json
+	tokenizer['model']['vocab'] |= tokenizer_vocab
+	json_write(tokenizer, out_dir / "tokenizer.json", pretty=True)
+	# write config.json
+	json_write({
+		"architectures": [
+			"LlamaForCausalLM"
+		],
+		"attention_bias": False,
+		"attention_dropout": 0.0,
+		"bos_token_id": 1,
+		"eos_token_id": 2,
+		"hidden_act": "gelu",
+		"hidden_size": model_dim,
+		"initializer_range": 0.02,
+		"intermediate_size": model_dim * 4,
+		"max_position_embeddings": 75 * 60 * 5,
+		"model_type": "llama",
+		"num_attention_heads": 16,
+		"num_hidden_layers": 12,
+		"num_key_value_heads": 16,
+		"pretraining_tp": 1,
+		"rms_norm_eps": 1e-06,
+		"rope_scaling": None,
+		"rope_theta": 10000.0,
+		"tie_word_embeddings": False,
+		"torch_dtype": "bfloat16",
+		"transformers_version": "4.40.0",
+		"use_cache": False,
+		"vocab_size": n_tokens
+	}, out_dir / "config.json", pretty=True )
+
+	return state_dict
 
-	"""
-	n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
-	token_dim = 1024
-	embedding = torch.nn.Embedding(n_tokens, token_dim)
-	embedding.weight.requires_grad = False
-
-	def move_value(k):
-		v = state_dict['module'][k]
-		del state_dict['module'][k]
-		return v
-
-	separator = move_value('sep')
-	out_proj = move_value('classifier.weight')
-	text_emb = move_value('text_emb.weight')
-	langs_emb = move_value('langs_emb.weight')
-	tasks_emb = move_value('tasks_emb.weight')
-	tones_emb = move_value('tones_emb.weight')
-
-	proms_emb_weight = [ move_value(f'proms_emb.weight.{i}').item() for i in range(8) ] if "proms_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
-	resps_emb_weight = [ move_value(f'resps_emb.weight.{i}').item() for i in range(8) ] if "resps_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
-
-	proms_emb = [ move_value(f'proms_emb.embeddings.{i}.weight') for i in range(8) ]
-	resps_emb = [ move_value(f'resps_emb.embeddings.{i}.weight') for i in range(8) ]
-
-
-	start = 0
-	for i in range(256):
-		embedding.weight[start + i] = text_emb[i]
-
-	start = 256
-	for layer in range(8):
-		for i in range(1024):
-			offset = start + 1024 * layer
-			embedding.weight[i + offset] = proms_emb[layer][i] * proms_emb_weight[layer]
-
-	start = 256 + 1024 * 8
-	for layer in range(8):
-		for i in range(1024):
-			offset = start + 1024 * layer
-			embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer]
-
-	state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
-	# to-do: properly recreate the output head weights or something
-	state_dict['module']['lm_head.weight'] = out_proj
-
-	del state_dict['module']['classifier.weight']
-	del state_dict['module']['classifier.bias']
-
-	return state_dict
-	"""
-
 # yanks a LoRA from the training checkpoint
 def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
 	if dtype is None:
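The `l * n_resp_levels` to `l * n_audio_tokens` changes above are the substantive fix in this file: the stride between per-level blocks has to be the number of audio tokens per level, not the number of levels, otherwise every level after the first is written on top of its predecessor. A toy check under assumed sizes (1024 tokens per level and 8 RVQ levels, consistent with the `token_start + 1024` stop-token offsets in the diff; `token_start = 256` is hypothetical):

```python
n_audio_tokens = 1024  # tokens per RVQ level (assumed)
n_resp_levels  = 8     # number of RVQ levels (assumed)
token_start    = 256   # hypothetical start of the prompt block

# old, buggy stride: level blocks begin only 8 rows apart and overlap
buggy = [ token_start + (l * n_resp_levels) for l in range(n_resp_levels) ]
# fixed stride: each level owns a disjoint 1024-row block
fixed = [ token_start + (l * n_audio_tokens) for l in range(n_resp_levels) ]

assert buggy[1] - buggy[0] == 8     # overlapping embedding.weight writes
assert fixed[1] - fixed[0] == 1024  # contiguous, non-overlapping blocks
```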
"initializer_range": 0.02, + "intermediate_size": model_dim * 4, + "max_position_embeddings": 75 * 60 * 5, + "model_type": "llama", + "num_attention_heads": 16, + "num_hidden_layers": 12, + "num_key_value_heads": 16, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 10000.0, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0", + "use_cache": False, + "vocab_size": n_tokens + }, out_dir / "config.json", pretty=True ) + return state_dict - """ - n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1 - token_dim = 1024 - embedding = torch.nn.Embedding(n_tokens, token_dim) - embedding.weight.requires_grad = False - - def move_value(k): - v = state_dict['module'][k] - del state_dict['module'][k] - return v - - separator = move_value('sep') - out_proj = move_value('classifier.weight') - text_emb = move_value('text_emb.weight') - langs_emb = move_value('langs_emb.weight') - tasks_emb = move_value('tasks_emb.weight') - tones_emb = move_value('tones_emb.weight') - - proms_emb_weight = [ move_value(f'proms_emb.weight.{i}').item() for i in range(8) ] if "proms_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ] - resps_emb_weight = [ move_value(f'resps_emb.weight.{i}').item() for i in range(8) ] if "resps_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ] - - proms_emb = [ move_value(f'proms_emb.embeddings.{i}.weight') for i in range(8) ] - resps_emb = [ move_value(f'resps_emb.embeddings.{i}.weight') for i in range(8) ] - - - start = 0 - for i in range(256): - embedding.weight[start + i] = text_emb[i] - - start = 256 - for layer in range(8): - for i in range(1024): - offset = start + 1024 * layer - embedding.weight[i + offset] = proms_emb[layer][i] * proms_emb_weight[layer] - - start = 256 + 1024 * 8 - for layer in range(8): - for i in range(1024): - offset = start + 1024 * layer - embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer] - - state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict() - # to-do: properly recreate the output head weights or something - state_dict['module']['lm_head.weight'] = out_proj - - del state_dict['module']['classifier.weight'] - del state_dict['module']['classifier.bias'] - - return state_dict - """ - # yanks a LoRA from the training checkpoint def extract_lora( state_dict, config = None, save_path = None, dtype = None ): if dtype is None: diff --git a/vall_e/utils/io.py b/vall_e/utils/io.py index 2a4e3c7..7ab73bb 100644 --- a/vall_e/utils/io.py +++ b/vall_e/utils/io.py @@ -13,13 +13,12 @@ except: from .utils import truncate_json -def json_stringify( data, truncate=False, pretty=False ): +def json_stringify( data, truncate=False, pretty=False, raw=False ): if truncate: return truncate_json( json.dumps( data ) ) if pretty: - if use_orjson: - return json.dumps( data, option=json.OPT_INDENT_2 ).decode('utf-8') - return json.dumps( data, indent='\t' ).decode('utf-8') + s = json.dumps( data, option=json.OPT_INDENT_2 ) if use_orjson else json.dumps( data, indent='\t' ) + return s if raw and use_orjson else s.decode('utf-8') return json.dumps( data ) def json_parse( string ): @@ -34,11 +33,11 @@ def json_read( path, default=None ): with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" ) ) as f: return json_parse( f.read() ) -def json_write( data, path, truncate=False ): +def json_write( data, path, **kwargs ): path = coerce_path( path ) with (open( str(path), "wb" ) if use_orjson else open( 
str(path), "w", encoding="utf-8" ) ) as f: - f.write( json_stringify( data, truncate=truncate ) ) + f.write( json_stringify( data, raw=use_orjson, **kwargs ) ) def coerce_path( path ): return path if isinstance( path, Path ) else Path(path) @@ -94,7 +93,7 @@ def torch_save( data, path, module_key=None ): path = coerce_path(path) ext = path.suffix - if ext in [".safetensor", ".sft"]: + if ext in [".safetensor", ".safetensors", ".sft"]: data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key ) return sft_save( data, path, metadata ) @@ -105,7 +104,7 @@ def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=T path = coerce_path(path) ext = path.suffix - if ext in [".safetensor", ".sft"]: + if ext in [".safetensor", ".safetensors", ".sft"]: state_dict = {} with sft_load(path, framework=framework, device=device) as f: for k in f.keys(): @@ -113,12 +112,13 @@ def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=T if load_metadata: metadata = f.metadata() - for k, v in metadata.items(): - try: - metadata[k] = json.loads( v ) - except Exception as e: - pass - state_dict = { module_key: state_dict } | metadata + if metadata is not None: + for k, v in metadata.items(): + try: + metadata[k] = json.loads( v ) + except Exception as e: + pass + state_dict = { module_key: state_dict } | metadata return state_dict