diff --git a/docs/README.md b/docs/README.md
index 713e11c..ab1271e 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -70,7 +70,7 @@ For the most part, the model is complete. With the `NAR-len` being crammed on, I
 However, while this solution boasts being lightweight, there are some caveats for its given size
 * its at capacity on what it *can* do without additional tasks to augment it further
-  * post-fixing it with additional layers glued on doesn't seem to offer very much work (12 => 16 layers)
+  * post-fixing it with additional layers glued on doesn't seem to offer very much improvement (12 => 16 layers)
 * wrangling it is a bit of a chore, as some voices work fine under the `AR` but not the `NAR-len`, and vice-versa
 * some voices outright refuse to work without LoRA training
 * some sampler settings works on some voices, but others need some tweaking
diff --git a/vall_e/export.py b/vall_e/export.py
index c7318b1..a0212b2 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -10,7 +10,177 @@
 from .models.lora import lora_get_state_dict
 from .utils.io import torch_save, torch_load
 
 # stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
+# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
+@torch.no_grad()
 def convert_to_hf( state_dict, config = None, save_path = None ):
+    # to-do: infer all of this from the existing state_dict, should be easy by checking shape
+    model_dim = 1024
+
+    n_text_tokens = 256
+    n_audio_tokens = 1024
+    n_resp_levels = 8
+    n_len_tokens = 11
+    n_lang_tokens = 4
+    n_task_tokens = 9
+
+    # the new tokenizer to use
+    # NOTE: the special token strings below are a placeholder naming scheme, since the HF tokenizer needs *some* string per id
+    tokenizer_append = {}
+
+    l_tokens = [
+        n_text_tokens, # text
+        n_audio_tokens * n_resp_levels, # prom
+        (n_audio_tokens + 1) * 2, # resp: AR + NAR-len (with stop/mask)
+        (n_audio_tokens) * (n_resp_levels - 1), # NAR
+        n_resp_levels, # RVQ level
+        n_len_tokens, # len tokens
+        1, # separator
+        n_lang_tokens, # langs
+        n_task_tokens, # tasks
+    ]
+
+    n_tokens = sum(l_tokens)
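+
+    # resulting contiguous vocab layout, for the defaults above:
+    #   text       [    0,   256)
+    #   prom       [  256,  8448)   8 levels x 1024 codes
+    #   AR         [ 8448,  9473)   1024 codes + stop
+    #   NAR-len    [ 9473, 10498)   1024 codes + mask
+    #   NAR        [10498, 17666)   7 levels x 1024 codes
+    #   RVQ level  [17666, 17674)
+    #   len        [17674, 17685)
+    #   sep        [17685, 17686)
+    #   langs      [17686, 17690)
+    #   tasks      [17690, 17699)   => n_tokens = 17699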
+
+    lang_map = [
+        "en",
+        "ja",
+        "de",
+        "fr",
+    ]
+    task_map = [
+        "tts",
+        "tts-c",
+        "ns",
+        "sr",
+        "tse",
+        "soe",
+        "mask",
+        "eoe",
+        "stt",
+    ]
+
+    embedding = torch.nn.Embedding( n_tokens, model_dim )
+    classifier = torch.nn.Linear( model_dim, n_tokens )
+
+    #embedding.weight.requires_grad = False
+    #classifier.weight.requires_grad = False
+    #classifier.bias.requires_grad = False
+
+    # inject text tokens
+    token_start = 0
+    token_end = l_tokens[0]
+    embedding.weight[token_start:token_end] = state_dict['module']['text_emb.weight']
+    classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.9.weight']
+    classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.9.bias']
+    # tokenizer already has these tokens
+
+    # inject prom tokens
+    token_start = token_end
+    token_end += l_tokens[1]
+    for l in range(n_resp_levels):
+        start = token_start + (l * n_audio_tokens)
+        end = start + n_audio_tokens
+        embedding.weight[start:end] = state_dict['module'][f'proms_emb.embeddings.{l}.weight']
+        # there's no corresponding classifier
+        #classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
+        #classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
+        for t in range(n_audio_tokens):
+            tokenizer_append[f'<|prom:{l}:{t}|>'] = start + t
+
+    # inject AR
+    token_start = token_end
+    token_end += l_tokens[2] // 2
+    embedding.weight[token_start:token_end] = state_dict['module']['resps_emb.embeddings.0.weight']
+    classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
+    classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
+    for t in range(n_audio_tokens):
+        tokenizer_append[f'<|AR:{t}|>'] = token_start + t
+    tokenizer_append['<|AR:STOP|>'] = token_start + n_audio_tokens
+
+    # inject NAR-len
+    token_start = token_end
+    token_end += l_tokens[2] // 2
+    embedding.weight[token_start:token_end] = state_dict['module']['resps_emb.embeddings.8.weight']
+    classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.8.weight']
+    classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.8.bias']
+    for t in range(n_audio_tokens):
+        tokenizer_append[f'<|NAR-len:{t}|>'] = token_start + t
+    tokenizer_append['<|NAR-len:MASK|>'] = token_start + n_audio_tokens
+
+    # inject NAR
+    token_start = token_end
+    token_end += l_tokens[3]
+    for l in range(1, n_resp_levels):
+        start = token_start + ((l - 1) * n_audio_tokens)
+        end = start + n_audio_tokens
+        embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight']
+        classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
+        classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
+        for t in range(n_audio_tokens):
+            tokenizer_append[f'<|NAR:{l}:{t}|>'] = start + t
+
+    # inject RVQ level
+    token_start = token_end
+    token_end += l_tokens[4]
+    embedding.weight[token_start:token_end] = state_dict['module']['rvq_l_emb.weight']
+    # there is no corresponding classifier
+    for l in range(n_resp_levels):
+        tokenizer_append[f'<|RVQ:{l}|>'] = token_start + l
+
+    # inject len
+    token_start = token_end
+    token_end += l_tokens[5]
+    embedding.weight[token_start:token_end] = state_dict['module']['len_emb.weight']
+    classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
+    classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
+    for t in range(n_len_tokens):
+        tokenizer_append[f'<|len:{t}|>'] = token_start + t
+
+    # inject sep
+    token_start = token_end
+    token_end += l_tokens[6]
+    embedding.weight[token_start:token_end] = state_dict['module']['sep']
+    tokenizer_append['<|sep|>'] = token_start
+    # there is no corresponding classifier
+
+    # inject langs
+    token_start = token_end
+    token_end += l_tokens[7]
+    embedding.weight[token_start:token_end] = state_dict['module']['langs_emb.weight']
+    for l in range(n_lang_tokens):
+        lang = lang_map[l]
+        tokenizer_append[f'<|lang:{lang}|>'] = token_start + l
+    # there is no corresponding classifier
+
+    # inject tasks
+    token_start = token_end
+    token_end += l_tokens[8]
+    embedding.weight[token_start:token_end] = state_dict['module']['tasks_emb.weight']
+    for l in range(n_task_tokens):
+        task = task_map[l]
+        tokenizer_append[f'<|task:{task}|>'] = token_start + l
+    # there is no corresponding classifier
+
+    model_dict = {}
+    # filter out the underlying model weights and extract them
+    for k in state_dict['module'].keys():
+        if not k.startswith('model.'):
+            continue
+        model_dict[k] = state_dict['module'][k].clone()
+    del state_dict['module']
+
+    embedding_dict = embedding.state_dict()
+    classifier_dict = classifier.state_dict()
+    model_dict['model.embed_tokens.weight'] = embedding_dict['weight']
+    model_dict['lm_head.weight'] = classifier_dict['weight']
+    model_dict['lm_head.bias'] = classifier_dict['bias']
+
+    state_dict['module'] = model_dict
+    state_dict['vocab'] = tokenizer_append
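+    # NOTE: 'module' now holds the HF-style weights (model.*, plus model.embed_tokens / lm_head),
+    # while 'vocab' carries the token string => id map for building the matching tokenizer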
+
+    return state_dict
+
+    """
     n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
     token_dim = 1024
 
     embedding = torch.nn.Embedding(n_tokens, token_dim)
@@ -59,6 +229,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     del state_dict['module']['classifier.bias']
 
     return state_dict
+    """
 
 # yanks a LoRA from the training checkpoint
 def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
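
For reference (outside the patch itself), the converted checkpoint is meant to slot into a stock HF LLaMA skeleton. A minimal sketch, assuming the defaults above, that the filtered `model.*` weights already follow HF LLaMA naming (which the filter suggests), and a hypothetical output path; `LlamaConfig`/`LlamaForCausalLM` are the standard `transformers` classes, and the layer/head counts are assumptions that must match the trained model:

    import torch
    from transformers import LlamaConfig, LlamaForCausalLM

    # checkpoint produced by convert_to_hf(); the path is hypothetical
    state_dict = torch.load("converted.pth", map_location="cpu")

    config = LlamaConfig(
        vocab_size=17699,        # sum(l_tokens) for the defaults above
        hidden_size=1024,        # model_dim
        num_hidden_layers=12,    # assumption: the 12-layer model the README mentions
        num_attention_heads=16,  # assumption; must match the trained model
    )
    model = LlamaForCausalLM(config)

    # strict=False: the stock LLaMA head has no lm_head.bias slot, so it is
    # reported as an unexpected key rather than raising
    missing, unexpected = model.load_state_dict(state_dict["module"], strict=False)
    print("missing:", missing, "unexpected:", unexpected)

If the dimensions line up, the report should be empty apart from `lm_head.bias`; anything more than that means the config does not match the trained model.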