corrected export.py's --hf
commit d85273609e
parent 59bf6b8b33
@@ -192,7 +192,7 @@ def normalize_text(text, language="auto", full=True):
     return text

 @cache
-def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/tongue_twisters.txt") ):
+def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/harvard_sentences.txt") ):
     duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
     sentences = [
         "The birch canoe slid on the smooth planks.",
vall_e/export.py (134 changed lines)
@@ -7,7 +7,7 @@ from .data import get_phone_symmap
 from .engines import load_engines
 from .config import cfg
 from .models.lora import lora_get_state_dict
-from .utils.io import torch_save, torch_load
+from .utils.io import torch_save, torch_load, json_read, json_write, Path

 # stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
 # *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
@@ -22,7 +22,20 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]

     # the new tokenizer to use
-    tokenizer_append = {}
+    tokenizer = {}
+    tokenizer_vocab = {}
+
+    tokenizer_path = cfg.rel_path / cfg.tokenizer_path
+    if not tokenizer_path.exists():
+        tokenizer_path = Path("./data/") / cfg.tokenizer_path
+    if tokenizer_path.exists():
+        tokenizer = json_read( tokenizer_path )
+    else:
+        tokenizer = {
+            "model": {
+                "vocab": get_phone_symmap()
+            }
+        }

     l_tokens = [
         n_text_tokens, # text
@@ -77,14 +90,14 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     token_start = token_end
     token_end += l_tokens[1]
     for l in range(n_resp_levels):
-        start = token_start + (l*n_resp_levels)
+        start = token_start + (l * n_audio_tokens)
         end = start + n_audio_tokens
         embedding.weight[start:end] = state_dict['module'][f'proms_emb.embeddings.{l}.weight']
         # there's no corresponding classifier
         #classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
         #classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
         for t in range(n_audio_tokens):
-            tokenizer_append[f'<P:{l}:{t}>'] = start + t
+            tokenizer_vocab[f'<|P|{l}:{t}|>'] = start + t

     # inject AR
     token_start = token_end
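The stride change in this hunk is the substance of the fix: each prom level needs a full n_audio_tokens-wide slice of the stitched embedding, but the old expression only advanced by n_resp_levels per level, so every level after the first landed on top of the previous one. A minimal illustration, not part of the commit, with assumed sizes:

    # Illustration only: why the old stride corrupted the prom section layout.
    # Sizes are assumptions; the real values come from the checkpoint's embedding shapes.
    n_audio_tokens = 1024
    n_resp_levels  = 8
    token_start    = 256   # hypothetical start of the prom section

    for l in range(n_resp_levels):
        old_start = token_start + (l * n_resp_levels)    # levels only 8 ids apart, ranges overlap
        new_start = token_start + (l * n_audio_tokens)   # levels 1024 ids apart, ranges are disjoint
        print(l, old_start, new_start)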
@@ -93,8 +106,8 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
     classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
     for t in range(n_audio_tokens):
-        tokenizer_append[f'<AR:0:0:{t}>'] = token_start + t
-    tokenizer_append[f'<AR:0:0:STOP>'] = token_start + 1024
+        tokenizer_vocab[f'<|AR|0:0|{t}|>'] = token_start + t
+    tokenizer_vocab[f'<AR|0:0|STOP|>'] = token_start + 1024

     # inject NAR-len
     token_start = token_end
@@ -103,20 +116,20 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
     classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
     for t in range(n_audio_tokens):
-        tokenizer_append[f'<NAR:0:0:{t}>'] = token_start + t
-    tokenizer_append[f'<NAR:0:0:STOP>'] = token_start + 1024
+        tokenizer_vocab[f'<NAR|0:0|{t}|>'] = token_start + t
+    tokenizer_vocab[f'<NAR|0:0|STOP|>'] = token_start + 1024

     # inject NAR
     token_start = token_end
     token_end += l_tokens[3]
     for l in range(1, n_resp_levels):
-        start = token_start + ((l-1)*n_resp_levels)
+        start = token_start + ((l-1) * n_audio_tokens)
         end = start + n_audio_tokens
         embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight']
         classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
         classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
         for t in range(n_audio_tokens):
-            tokenizer_append[f'<NAR:{l-1}:{l}:{t}>'] = start + t
+            tokenizer_vocab[f'<|NAR|{l-1}:{l}|{t}|>'] = start + t

     # inject RVQ level
     token_start = token_end
@@ -124,7 +137,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     embedding.weight[token_start:token_end] = state_dict['module'][f'rvq_l_emb.weight']
     # there is no corresponding classifier
     for l in range(n_resp_levels):
-        tokenizer_append[f'<RVQ:{l}>'] = token_start + l
+        tokenizer_vocab[f'<|RVQ:{l}|>'] = token_start + l

     # inject len
     token_start = token_end
@@ -133,13 +146,13 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
     classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
     for t in range(n_len_tokens):
-        tokenizer_append[f'<len:{t}>'] = token_start + t
+        tokenizer_vocab[f'<|len:{t}|>'] = token_start + t

     # inject sep
     token_start = token_end
     token_end += l_tokens[6]
     embedding.weight[token_start:token_end] = state_dict['module']['sep']
-    tokenizer_append['<sep>'] = token_start
+    tokenizer_vocab['<|sep|>'] = token_start
     # there is no corresponding classifier

     # inject langs
@@ -148,7 +161,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     embedding.weight[token_start:token_end] = state_dict['module']['langs_emb.weight']
     for l in range(n_lang_tokens):
         lang = lang_map[l]
-        tokenizer_append[f'<lang:{lang}>'] = token_start + l
+        tokenizer_vocab[f'<|lang:{lang}|>'] = token_start + l
     # there is no corresponding classifier

     # inject tasks
@@ -157,7 +170,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     embedding.weight[token_start:token_end] = state_dict['module']['tasks_emb.weight']
     for l in range(n_task_tokens):
         task = task_map[l]
-        tokenizer_append[f'<task:{task}>'] = token_start + l
+        tokenizer_vocab[f'<|task:{task}|>'] = token_start + l
     # there is no corresponding classifier

@@ -167,7 +180,6 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
         if not k.startswith('model.'):
             continue
         model_dict[k] = state_dict['module'][k].clone()
-    del state_dict['module']

     embedding_dict = embedding.state_dict()
     classifier_dict = classifier.state_dict()
@@ -175,62 +187,46 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     model_dict['lm_head.weight'] = classifier_dict['weight']
     model_dict['lm_head.bias'] = classifier_dict['bias']

-    state_dict['module'] = model_dict
-    state_dict['vocab'] = tokenizer_append
+    # write files in an HF compatible way
+    out_dir = cfg.rel_path / "hf"
+    out_dir.mkdir(parents=True, exist_ok=True)
+    # write weights
+    torch_save( model_dict, out_dir / "model.safetensors" )
+    # write vocab.json
+    tokenizer['model']['vocab'] |= tokenizer_vocab
+    json_write(tokenizer, out_dir / "tokenizer.json", pretty=True)
+    # write config.json
+    json_write({
+        "architectures": [
+            "LlamaForCausalLM"
+        ],
+        "attention_bias": False,
+        "attention_dropout": 0.0,
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "hidden_act": "gelu",
+        "hidden_size": model_dim,
+        "initializer_range": 0.02,
+        "intermediate_size": model_dim * 4,
+        "max_position_embeddings": 75 * 60 * 5,
+        "model_type": "llama",
+        "num_attention_heads": 16,
+        "num_hidden_layers": 12,
+        "num_key_value_heads": 16,
+        "pretraining_tp": 1,
+        "rms_norm_eps": 1e-06,
+        "rope_scaling": None,
+        "rope_theta": 10000.0,
+        "tie_word_embeddings": False,
+        "torch_dtype": "bfloat16",
+        "transformers_version": "4.40.0",
+        "use_cache": False,
+        "vocab_size": n_tokens
+    }, out_dir / "config.json", pretty=True )
+

     return state_dict

-    """
-    n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
-    token_dim = 1024
-    embedding = torch.nn.Embedding(n_tokens, token_dim)
-    embedding.weight.requires_grad = False
-
-    def move_value(k):
-        v = state_dict['module'][k]
-        del state_dict['module'][k]
-        return v
-
-    separator = move_value('sep')
-    out_proj = move_value('classifier.weight')
-    text_emb = move_value('text_emb.weight')
-    langs_emb = move_value('langs_emb.weight')
-    tasks_emb = move_value('tasks_emb.weight')
-    tones_emb = move_value('tones_emb.weight')
-
-    proms_emb_weight = [ move_value(f'proms_emb.weight.{i}').item() for i in range(8) ] if "proms_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
-    resps_emb_weight = [ move_value(f'resps_emb.weight.{i}').item() for i in range(8) ] if "resps_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
-
-    proms_emb = [ move_value(f'proms_emb.embeddings.{i}.weight') for i in range(8) ]
-    resps_emb = [ move_value(f'resps_emb.embeddings.{i}.weight') for i in range(8) ]
-
-
-    start = 0
-    for i in range(256):
-        embedding.weight[start + i] = text_emb[i]
-
-    start = 256
-    for layer in range(8):
-        for i in range(1024):
-            offset = start + 1024 * layer
-            embedding.weight[i + offset] = proms_emb[layer][i] * proms_emb_weight[layer]
-
-    start = 256 + 1024 * 8
-    for layer in range(8):
-        for i in range(1024):
-            offset = start + 1024 * layer
-            embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer]
-
-    state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
-    # to-do: properly recreate the output head weights or something
-    state_dict['module']['lm_head.weight'] = out_proj
-
-    del state_dict['module']['classifier.weight']
-    del state_dict['module']['classifier.bias']
-
-    return state_dict
-    """
-
 # yanks a LoRA from the training checkpoint
 def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
     if dtype is None:
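With this change, --hf now writes model.safetensors, tokenizer.json, and config.json under the "hf" subdirectory instead of stashing everything back into the checkpoint dict. A rough, untested sketch of how those artifacts might be consumed, assuming the merged tokenizer.json is a complete HF tokenizers file and that transformers tolerates the extra lm_head.bias key (neither is verified here):

    # Sketch only, not part of this commit; paths and behaviour are assumptions.
    from pathlib import Path
    from tokenizers import Tokenizer                   # reads the merged tokenizer.json
    from transformers import LlamaForCausalLM          # matches "architectures" in config.json

    out_dir = Path("./training/hf")                    # hypothetical cfg.rel_path / "hf"
    tokenizer = Tokenizer.from_file(str(out_dir / "tokenizer.json"))
    model = LlamaForCausalLM.from_pretrained(out_dir)  # picks up config.json + model.safetensors

    # the stitched special tokens resolve to contiguous ids in the single vocab, e.g.:
    sep_id = tokenizer.token_to_id("<|sep|>")
    rvq0_id = tokenizer.token_to_id("<|RVQ:0|>")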
vall_e/utils/io.py

@@ -13,13 +13,12 @@ except:

 from .utils import truncate_json

-def json_stringify( data, truncate=False, pretty=False ):
+def json_stringify( data, truncate=False, pretty=False, raw=False ):
     if truncate:
         return truncate_json( json.dumps( data ) )
     if pretty:
-        if use_orjson:
-            return json.dumps( data, option=json.OPT_INDENT_2 ).decode('utf-8')
-        return json.dumps( data, indent='\t' ).decode('utf-8')
+        s = json.dumps( data, option=json.OPT_INDENT_2 ) if use_orjson else json.dumps( data, indent='\t' )
+        return s if raw and use_orjson else s.decode('utf-8')
     return json.dumps( data )

 def json_parse( string ):
@@ -34,11 +33,11 @@ def json_read( path, default=None ):
     with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" ) ) as f:
         return json_parse( f.read() )

-def json_write( data, path, truncate=False ):
+def json_write( data, path, **kwargs ):
     path = coerce_path( path )

     with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" ) ) as f:
-        f.write( json_stringify( data, truncate=truncate ) )
+        f.write( json_stringify( data, raw=use_orjson, **kwargs ) )

 def coerce_path( path ):
     return path if isinstance( path, Path ) else Path(path)
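The **kwargs forwarding is what lets the export code above request pretty output through json_write. An illustrative set of calls under the assumption that orjson is available (use_orjson True), as read off the diff rather than separately tested:

    json_write( {"a": 1}, "out.json" )                      # compact, same behaviour as before
    json_write( {"a": 1}, "out.json", pretty=True )         # forwarded to json_stringify(pretty=True)
    s = json_stringify( {"a": 1}, pretty=True )             # decoded str for display
    b = json_stringify( {"a": 1}, pretty=True, raw=True )   # raw orjson bytes, suited to the "wb" file handle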
@@ -94,7 +93,7 @@ def torch_save( data, path, module_key=None ):
     path = coerce_path(path)
     ext = path.suffix

-    if ext in [".safetensor", ".sft"]:
+    if ext in [".safetensor", ".safetensors", ".sft"]:
         data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )

         return sft_save( data, path, metadata )
@@ -105,7 +104,7 @@ def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module" ):
     path = coerce_path(path)
     ext = path.suffix

-    if ext in [".safetensor", ".sft"]:
+    if ext in [".safetensor", ".safetensors", ".sft"]:
         state_dict = {}
         with sft_load(path, framework=framework, device=device) as f:
             for k in f.keys():
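This matters because convert_to_hf() above saves to "model.safetensors": under the old check that suffix fell through to the plain torch.save/torch.load branch. A trivial illustration of the widened suffix test:

    # Quick check of the extension handling (illustrative only).
    from pathlib import Path
    for name in ("model.safetensors", "model.sft", "model.pth"):
        ext = Path(name).suffix
        print(name, ext in [".safetensor", ".safetensors", ".sft"])
    # -> model.safetensors True, model.sft True, model.pth False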
@@ -113,12 +112,13 @@ def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module" ):

             if load_metadata:
                 metadata = f.metadata()
-                for k, v in metadata.items():
-                    try:
-                        metadata[k] = json.loads( v )
-                    except Exception as e:
-                        pass
-                state_dict = { module_key: state_dict } | metadata
+                if metadata is not None:
+                    for k, v in metadata.items():
+                        try:
+                            metadata[k] = json.loads( v )
+                        except Exception as e:
+                            pass
+                    state_dict = { module_key: state_dict } | metadata

     return state_dict

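For context on the new guard: safetensors reports no stored metadata by returning None from f.metadata(), so the old unconditional loop raised on .items(). A minimal illustration of the behaviour the guard is assumed to avoid (not taken from the repo's tests):

    # metadata may be None for safetensors files saved without a metadata dict
    metadata = None

    # old: AttributeError, 'NoneType' object has no attribute 'items'
    # new: the block is skipped and the raw state_dict is returned as-is
    if metadata is not None:
        for k, v in metadata.items():
            pass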