cringe code to convert to LlamaForCausalLM-happy weights + tokenizer dict (still need to write logic to actually use these weights for proper inferencing)
parent 84a05acb6d
commit 31ab90d84a
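The calling side isn't part of this diff, but given the signature of `convert_to_hf` added below, the intended flow is presumably something along these lines (a sketch only; the paths and the use of plain `torch.load`/`torch.save` here are assumptions):

```python
# hypothetical usage, not part of this commit
import torch
from vall_e.export import convert_to_hf

state_dict = torch.load("ckpt/fp32.pth", map_location="cpu")  # training checkpoint with a 'module' weight dict
state_dict = convert_to_hf( state_dict )                      # stitched weights plus a 'vocab' token map
torch.save( state_dict, "hf/model.pth" )
```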
@@ -70,7 +70,7 @@ For the most part, the model is complete. With the `NAR-len` being crammed on, I
 However, while this solution boasts being lightweight, there are some caveats for its given size
 * it's at capacity on what it *can* do without additional tasks to augment it further
-* post-fixing it with additional layers glued on doesn't seem to offer very much work (12 => 16 layers)
+* post-fixing it with additional layers glued on doesn't seem to offer very much improvement (12 => 16 layers)
 * wrangling it is a bit of a chore, as some voices work fine under the `AR` but not the `NAR-len`, and vice-versa
 * some voices outright refuse to work without LoRA training
 * some sampler settings work on some voices, but others need some tweaking
vall_e/export.py (+171)
@@ -10,7 +10,177 @@ from .models.lora import lora_get_state_dict
from .utils.io import torch_save, torch_load

# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
@torch.no_grad()
def convert_to_hf( state_dict, config = None, save_path = None ):
	# to-do: infer all of this from the existing state_dict, should be easy by checking shape
	model_dim = 1024

	n_text_tokens = 256
	n_audio_tokens = 1024
	n_resp_levels = 8
	n_len_tokens = 11
	n_lang_tokens = 4
	n_task_tokens = 9

	# the new tokenizer to use
	tokenizer_append = {}

	l_tokens = [
		n_text_tokens, # text
		n_audio_tokens * n_resp_levels, # prom
		(n_audio_tokens + 1) * 2, # resp: AR + NAR-len (with stop/mask)
		(n_audio_tokens) * (n_resp_levels - 1), # NAR
		n_resp_levels, # RVQ level
		n_len_tokens, # len tokens
		1, # separator
		n_lang_tokens, # langs
		n_task_tokens, # tasks
	]

	n_tokens = sum(l_tokens)
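	# with the sizes above this comes out to 256 + 8192 + 2050 + 7168 + 8 + 11 + 1 + 4 + 9 = 17699 entries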

	lang_map = [
		"en",
		"ja",
		"de",
		"fr",
	]
	task_map = [
		"tts",
		"tts-c",
		"ns",
		"sr",
		"tse",
		"soe",
		"mask",
		"eoe",
		"stt",
	]

	embedding = torch.nn.Embedding( n_tokens, model_dim )
	classifier = torch.nn.Linear( model_dim, n_tokens )

	#embedding.weight.requires_grad = False
	#classifier.weight.requires_grad = False
	#classifier.bias.requires_grad = False

	# inject text tokens
	token_start = 0
	token_end = l_tokens[0]
	embedding.weight[token_start:token_end] = state_dict['module']['text_emb.weight']
	classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.9.weight']
	classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.9.bias']
	# tokenizer already has these tokens

	# inject prom tokens
	token_start = token_end
	token_end += l_tokens[1]
	for l in range(n_resp_levels):
		start = token_start + (l*n_audio_tokens)
		end = start + n_audio_tokens
		embedding.weight[start:end] = state_dict['module'][f'proms_emb.embeddings.{l}.weight']
		# there's no corresponding classifier
		#classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
		#classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
		for t in range(n_audio_tokens):
			tokenizer_append[f'<P:{l}:{t}>'] = start + t

	# inject AR
	token_start = token_end
	token_end += l_tokens[2] // 2
	embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.0.weight']
	classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
	classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
	for t in range(n_audio_tokens):
		tokenizer_append[f'<AR:0:0:{t}>'] = token_start + t
	tokenizer_append[f'<AR:0:0:STOP>'] = token_start + 1024
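	# with the sizes above, the AR slice starts at 256 + 8192 = 8448, so <AR:0:0:0> maps to id 8448 and <AR:0:0:STOP> to 9472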

	# inject NAR-len
	token_start = token_end
	token_end += l_tokens[2] // 2
	embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight']
	classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.8.weight']
	classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.8.bias']
	for t in range(n_audio_tokens):
		tokenizer_append[f'<NAR:0:0:{t}>'] = token_start + t
	tokenizer_append[f'<NAR:0:0:STOP>'] = token_start + 1024

	# inject NAR
	token_start = token_end
	token_end += l_tokens[3]
	for l in range(1, n_resp_levels):
		start = token_start + ((l-1)*n_audio_tokens)
		end = start + n_audio_tokens
		embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight']
		classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
		classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
		for t in range(n_audio_tokens):
			tokenizer_append[f'<NAR:{l-1}:{l}:{t}>'] = start + t

	# inject RVQ level
	token_start = token_end
	token_end += l_tokens[4]
	embedding.weight[token_start:token_end] = state_dict['module'][f'rvq_l_emb.weight']
	# there is no corresponding classifier
	for l in range(n_resp_levels):
		tokenizer_append[f'<RVQ:{l}>'] = token_start + l

	# inject len
	token_start = token_end
	token_end += l_tokens[5]
	embedding.weight[token_start:token_end] = state_dict['module'][f'len_emb.weight']
	classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
	classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
	for t in range(n_len_tokens):
		tokenizer_append[f'<len:{t}>'] = token_start + t

	# inject sep
	token_start = token_end
	token_end += l_tokens[6]
	embedding.weight[token_start:token_end] = state_dict['module']['sep']
	tokenizer_append['<sep>'] = token_start
	# there is no corresponding classifier

	# inject langs
	token_start = token_end
	token_end += l_tokens[7]
	embedding.weight[token_start:token_end] = state_dict['module']['langs_emb.weight']
	for l in range(n_lang_tokens):
		lang = lang_map[l]
		tokenizer_append[f'<lang:{lang}>'] = token_start + l
	# there is no corresponding classifier

	# inject tasks
	token_start = token_end
	token_end += l_tokens[8]
	embedding.weight[token_start:token_end] = state_dict['module']['tasks_emb.weight']
	for l in range(n_task_tokens):
		task = task_map[l]
		tokenizer_append[f'<task:{task}>'] = token_start + l
	# there is no corresponding classifier


	model_dict = {}
	# filter out the underlying model weights and extract them
	for k in state_dict['module'].keys():
		if not k.startswith('model.'):
			continue
		model_dict[k] = state_dict['module'][k].clone()
	del state_dict['module']

	embedding_dict = embedding.state_dict()
	classifier_dict = classifier.state_dict()
	model_dict['model.embed_tokens.weight'] = embedding_dict['weight']
	model_dict['lm_head.weight'] = classifier_dict['weight']
	model_dict['lm_head.bias'] = classifier_dict['bias']

	state_dict['module'] = model_dict
	state_dict['vocab'] = tokenizer_append

	return state_dict

"""
	n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
	token_dim = 1024
	embedding = torch.nn.Embedding(n_tokens, token_dim)
@@ -59,6 +229,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
	del state_dict['module']['classifier.bias']

	return state_dict
"""

# yanks a LoRA from the training checkpoint
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
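As the commit message notes, the logic to actually run inference off these weights still needs to be written. A rough sketch of what consuming the export might look like is below; it is not part of this commit, the paths are hypothetical, and apart from `hidden_size` and `vocab_size` (which follow from the export code above) every Llama config value is a placeholder, with the 12-layer figure taken from the earlier "12 => 16 layers" note.

```python
# rough sketch, not part of this commit: loading the converted dict into a stock LlamaForCausalLM
import torch
from transformers import LlamaConfig, LlamaForCausalLM

export = torch.load("hf/model.pth", map_location="cpu")   # output of convert_to_hf, saved beforehand
weights, vocab = export["module"], export["vocab"]

config = LlamaConfig(
	vocab_size = 17699,        # sum(l_tokens) from the export code
	hidden_size = 1024,        # model_dim
	num_hidden_layers = 12,    # assumed from the "12 => 16 layers" note
	num_attention_heads = 16,  # placeholder
	intermediate_size = 4096,  # placeholder
)
model = LlamaForCausalLM(config)

# LlamaForCausalLM builds its lm_head without a bias, so the exported 'lm_head.bias'
# has nowhere to go under this layout; strict=False drops it instead of erroring
missing, unexpected = model.load_state_dict(weights, strict=False)

# 'vocab' is a plain token -> id map over the stitched embedding; text tokens keep ids 0..255
ids = [ vocab["<lang:en>"], vocab["<sep>"], vocab["<RVQ:0>"] ]
logits = model(torch.tensor([ids])).logits
```

Since `vocab` is just a dict rather than a full HF tokenizer, prompts have to be assembled by hand as above until a proper tokenizer and config are wired up.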