corrected export.py's --hf
This commit is contained in:
parent
59bf6b8b33
commit
d85273609e
|
@ -192,7 +192,7 @@ def normalize_text(text, language="auto", full=True):
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/tongue_twisters.txt") ):
|
def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/harvard_sentences.txt") ):
|
||||||
duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
|
duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
|
||||||
sentences = [
|
sentences = [
|
||||||
"The birch canoe slid on the smooth planks.",
|
"The birch canoe slid on the smooth planks.",
|
||||||
|
|
134
vall_e/export.py
134
vall_e/export.py
|
@ -7,7 +7,7 @@ from .data import get_phone_symmap
|
||||||
from .engines import load_engines
|
from .engines import load_engines
|
||||||
from .config import cfg
|
from .config import cfg
|
||||||
from .models.lora import lora_get_state_dict
|
from .models.lora import lora_get_state_dict
|
||||||
from .utils.io import torch_save, torch_load
|
from .utils.io import torch_save, torch_load, json_read, json_write, Path
|
||||||
|
|
||||||
# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
|
# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
|
||||||
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
|
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
|
||||||
|
@ -22,7 +22,20 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]
|
n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]
|
||||||
|
|
||||||
# the new tokenizer to use
|
# the new tokenizer to use
|
||||||
tokenizer_append = {}
|
tokenizer = {}
|
||||||
|
tokenizer_vocab = {}
|
||||||
|
|
||||||
|
tokenizer_path = cfg.rel_path / cfg.tokenizer_path
|
||||||
|
if not tokenizer_path.exists():
|
||||||
|
tokenizer_path = Path("./data/") / cfg.tokenizer_path
|
||||||
|
if tokenizer_path.exists():
|
||||||
|
tokenizer = json_read( tokenizer_path )
|
||||||
|
else:
|
||||||
|
tokenizer = {
|
||||||
|
"model": {
|
||||||
|
"vocab": get_phone_symmap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
l_tokens = [
|
l_tokens = [
|
||||||
n_text_tokens, # text
|
n_text_tokens, # text
|
||||||
|
@ -77,14 +90,14 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
token_start = token_end
|
token_start = token_end
|
||||||
token_end += l_tokens[1]
|
token_end += l_tokens[1]
|
||||||
for l in range(n_resp_levels):
|
for l in range(n_resp_levels):
|
||||||
start = token_start + (l*n_resp_levels)
|
start = token_start + (l * n_audio_tokens)
|
||||||
end = start + n_audio_tokens
|
end = start + n_audio_tokens
|
||||||
embedding.weight[start:end] = state_dict['module'][f'proms_emb.embeddings.{l}.weight']
|
embedding.weight[start:end] = state_dict['module'][f'proms_emb.embeddings.{l}.weight']
|
||||||
# there's no corresponding classifier
|
# there's no corresponding classifier
|
||||||
#classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
|
#classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
|
||||||
#classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
|
#classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
|
||||||
for t in range(n_audio_tokens):
|
for t in range(n_audio_tokens):
|
||||||
tokenizer_append[f'<P:{l}:{t}>'] = start + t
|
tokenizer_vocab[f'<|P|{l}:{t}|>'] = start + t
|
||||||
|
|
||||||
# inject AR
|
# inject AR
|
||||||
token_start = token_end
|
token_start = token_end
|
||||||
|
@ -93,8 +106,8 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
|
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
|
||||||
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
|
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
|
||||||
for t in range(n_audio_tokens):
|
for t in range(n_audio_tokens):
|
||||||
tokenizer_append[f'<AR:0:0:{t}>'] = token_start + t
|
tokenizer_vocab[f'<|AR|0:0|{t}|>'] = token_start + t
|
||||||
tokenizer_append[f'<AR:0:0:STOP>'] = token_start + 1024
|
tokenizer_vocab[f'<AR|0:0|STOP|>'] = token_start + 1024
|
||||||
|
|
||||||
# inject NAR-len
|
# inject NAR-len
|
||||||
token_start = token_end
|
token_start = token_end
|
||||||
|
@ -103,20 +116,20 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
|
classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
|
||||||
classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
|
classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
|
||||||
for t in range(n_audio_tokens):
|
for t in range(n_audio_tokens):
|
||||||
tokenizer_append[f'<NAR:0:0:{t}>'] = token_start + t
|
tokenizer_vocab[f'<NAR|0:0|{t}|>'] = token_start + t
|
||||||
tokenizer_append[f'<NAR:0:0:STOP>'] = token_start + 1024
|
tokenizer_vocab[f'<NAR|0:0|STOP|>'] = token_start + 1024
|
||||||
|
|
||||||
# inject NAR
|
# inject NAR
|
||||||
token_start = token_end
|
token_start = token_end
|
||||||
token_end += l_tokens[3]
|
token_end += l_tokens[3]
|
||||||
for l in range(1, n_resp_levels):
|
for l in range(1, n_resp_levels):
|
||||||
start = token_start + ((l-1)*n_resp_levels)
|
start = token_start + ((l-1) * n_audio_tokens)
|
||||||
end = start + n_audio_tokens
|
end = start + n_audio_tokens
|
||||||
embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight']
|
embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight']
|
||||||
classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
|
classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
|
||||||
classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
|
classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
|
||||||
for t in range(n_audio_tokens):
|
for t in range(n_audio_tokens):
|
||||||
tokenizer_append[f'<NAR:{l-1}:{l}:{t}>'] = start + t
|
tokenizer_vocab[f'<|NAR|{l-1}:{l}|{t}|>'] = start + t
|
||||||
|
|
||||||
# inject RVQ level
|
# inject RVQ level
|
||||||
token_start = token_end
|
token_start = token_end
|
||||||
|
@ -124,7 +137,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
embedding.weight[token_start:token_end] = state_dict['module'][f'rvq_l_emb.weight']
|
embedding.weight[token_start:token_end] = state_dict['module'][f'rvq_l_emb.weight']
|
||||||
# there is no corresponding classifier
|
# there is no corresponding classifier
|
||||||
for l in range(n_resp_levels):
|
for l in range(n_resp_levels):
|
||||||
tokenizer_append[f'<RVQ:{l}>'] = token_start + l
|
tokenizer_vocab[f'<|RVQ:{l}|>'] = token_start + l
|
||||||
|
|
||||||
# inject len
|
# inject len
|
||||||
token_start = token_end
|
token_start = token_end
|
||||||
|
@ -133,13 +146,13 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
|
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
|
||||||
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
|
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
|
||||||
for t in range(n_len_tokens):
|
for t in range(n_len_tokens):
|
||||||
tokenizer_append[f'<len:{t}>'] = token_start + t
|
tokenizer_vocab[f'<|len:{t}|>'] = token_start + t
|
||||||
|
|
||||||
# inject sep
|
# inject sep
|
||||||
token_start = token_end
|
token_start = token_end
|
||||||
token_end += l_tokens[6]
|
token_end += l_tokens[6]
|
||||||
embedding.weight[token_start:token_end] = state_dict['module']['sep']
|
embedding.weight[token_start:token_end] = state_dict['module']['sep']
|
||||||
tokenizer_append['<sep>'] = token_start
|
tokenizer_vocab['<|sep|>'] = token_start
|
||||||
# there is no corresponding classifier
|
# there is no corresponding classifier
|
||||||
|
|
||||||
# inject langs
|
# inject langs
|
||||||
|
@ -148,7 +161,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
embedding.weight[token_start:token_end] = state_dict['module']['langs_emb.weight']
|
embedding.weight[token_start:token_end] = state_dict['module']['langs_emb.weight']
|
||||||
for l in range(n_lang_tokens):
|
for l in range(n_lang_tokens):
|
||||||
lang = lang_map[l]
|
lang = lang_map[l]
|
||||||
tokenizer_append[f'<lang:{lang}>'] = token_start + l
|
tokenizer_vocab[f'<|lang:{lang}|>'] = token_start + l
|
||||||
# there is no corresponding classifier
|
# there is no corresponding classifier
|
||||||
|
|
||||||
# inject tasks
|
# inject tasks
|
||||||
|
@ -157,7 +170,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
embedding.weight[token_start:token_end] = state_dict['module']['tasks_emb.weight']
|
embedding.weight[token_start:token_end] = state_dict['module']['tasks_emb.weight']
|
||||||
for l in range(n_task_tokens):
|
for l in range(n_task_tokens):
|
||||||
task = task_map[l]
|
task = task_map[l]
|
||||||
tokenizer_append[f'<task:{task}>'] = token_start + l
|
tokenizer_vocab[f'<|task:{task}|>'] = token_start + l
|
||||||
# there is no corresponding classifier
|
# there is no corresponding classifier
|
||||||
|
|
||||||
|
|
||||||
|
@ -167,7 +180,6 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
if not k.startswith('model.'):
|
if not k.startswith('model.'):
|
||||||
continue
|
continue
|
||||||
model_dict[k] = state_dict['module'][k].clone()
|
model_dict[k] = state_dict['module'][k].clone()
|
||||||
del state_dict['module']
|
|
||||||
|
|
||||||
embedding_dict = embedding.state_dict()
|
embedding_dict = embedding.state_dict()
|
||||||
classifier_dict = classifier.state_dict()
|
classifier_dict = classifier.state_dict()
|
||||||
|
@ -175,62 +187,46 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
model_dict['lm_head.weight'] = classifier_dict['weight']
|
model_dict['lm_head.weight'] = classifier_dict['weight']
|
||||||
model_dict['lm_head.bias'] = classifier_dict['bias']
|
model_dict['lm_head.bias'] = classifier_dict['bias']
|
||||||
|
|
||||||
state_dict['module'] = model_dict
|
# write files in an HF compatible way
|
||||||
state_dict['vocab'] = tokenizer_append
|
out_dir = cfg.rel_path / "hf"
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
# write weights
|
||||||
|
torch_save( model_dict, out_dir / "model.safetensors" )
|
||||||
|
# write vocab.json
|
||||||
|
tokenizer['model']['vocab'] |= tokenizer_vocab
|
||||||
|
json_write(tokenizer, out_dir / "tokenizer.json", pretty=True)
|
||||||
|
# write config.json
|
||||||
|
json_write({
|
||||||
|
"architectures": [
|
||||||
|
"LlamaForCausalLM"
|
||||||
|
],
|
||||||
|
"attention_bias": False,
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"bos_token_id": 1,
|
||||||
|
"eos_token_id": 2,
|
||||||
|
"hidden_act": "gelu",
|
||||||
|
"hidden_size": model_dim,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": model_dim * 4,
|
||||||
|
"max_position_embeddings": 75 * 60 * 5,
|
||||||
|
"model_type": "llama",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_hidden_layers": 12,
|
||||||
|
"num_key_value_heads": 16,
|
||||||
|
"pretraining_tp": 1,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_scaling": None,
|
||||||
|
"rope_theta": 10000.0,
|
||||||
|
"tie_word_embeddings": False,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"transformers_version": "4.40.0",
|
||||||
|
"use_cache": False,
|
||||||
|
"vocab_size": n_tokens
|
||||||
|
}, out_dir / "config.json", pretty=True )
|
||||||
|
|
||||||
|
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
"""
|
|
||||||
n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
|
|
||||||
token_dim = 1024
|
|
||||||
embedding = torch.nn.Embedding(n_tokens, token_dim)
|
|
||||||
embedding.weight.requires_grad = False
|
|
||||||
|
|
||||||
def move_value(k):
|
|
||||||
v = state_dict['module'][k]
|
|
||||||
del state_dict['module'][k]
|
|
||||||
return v
|
|
||||||
|
|
||||||
separator = move_value('sep')
|
|
||||||
out_proj = move_value('classifier.weight')
|
|
||||||
text_emb = move_value('text_emb.weight')
|
|
||||||
langs_emb = move_value('langs_emb.weight')
|
|
||||||
tasks_emb = move_value('tasks_emb.weight')
|
|
||||||
tones_emb = move_value('tones_emb.weight')
|
|
||||||
|
|
||||||
proms_emb_weight = [ move_value(f'proms_emb.weight.{i}').item() for i in range(8) ] if "proms_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
|
|
||||||
resps_emb_weight = [ move_value(f'resps_emb.weight.{i}').item() for i in range(8) ] if "resps_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
|
|
||||||
|
|
||||||
proms_emb = [ move_value(f'proms_emb.embeddings.{i}.weight') for i in range(8) ]
|
|
||||||
resps_emb = [ move_value(f'resps_emb.embeddings.{i}.weight') for i in range(8) ]
|
|
||||||
|
|
||||||
|
|
||||||
start = 0
|
|
||||||
for i in range(256):
|
|
||||||
embedding.weight[start + i] = text_emb[i]
|
|
||||||
|
|
||||||
start = 256
|
|
||||||
for layer in range(8):
|
|
||||||
for i in range(1024):
|
|
||||||
offset = start + 1024 * layer
|
|
||||||
embedding.weight[i + offset] = proms_emb[layer][i] * proms_emb_weight[layer]
|
|
||||||
|
|
||||||
start = 256 + 1024 * 8
|
|
||||||
for layer in range(8):
|
|
||||||
for i in range(1024):
|
|
||||||
offset = start + 1024 * layer
|
|
||||||
embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer]
|
|
||||||
|
|
||||||
state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
|
|
||||||
# to-do: properly recreate the output head weights or something
|
|
||||||
state_dict['module']['lm_head.weight'] = out_proj
|
|
||||||
|
|
||||||
del state_dict['module']['classifier.weight']
|
|
||||||
del state_dict['module']['classifier.bias']
|
|
||||||
|
|
||||||
return state_dict
|
|
||||||
"""
|
|
||||||
|
|
||||||
# yanks a LoRA from the training checkpoint
|
# yanks a LoRA from the training checkpoint
|
||||||
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
|
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
|
|
|
@ -13,13 +13,12 @@ except:
|
||||||
|
|
||||||
from .utils import truncate_json
|
from .utils import truncate_json
|
||||||
|
|
||||||
def json_stringify( data, truncate=False, pretty=False ):
|
def json_stringify( data, truncate=False, pretty=False, raw=False ):
|
||||||
if truncate:
|
if truncate:
|
||||||
return truncate_json( json.dumps( data ) )
|
return truncate_json( json.dumps( data ) )
|
||||||
if pretty:
|
if pretty:
|
||||||
if use_orjson:
|
s = json.dumps( data, option=json.OPT_INDENT_2 ) if use_orjson else json.dumps( data, indent='\t' )
|
||||||
return json.dumps( data, option=json.OPT_INDENT_2 ).decode('utf-8')
|
return s if raw and use_orjson else s.decode('utf-8')
|
||||||
return json.dumps( data, indent='\t' ).decode('utf-8')
|
|
||||||
return json.dumps( data )
|
return json.dumps( data )
|
||||||
|
|
||||||
def json_parse( string ):
|
def json_parse( string ):
|
||||||
|
@ -34,11 +33,11 @@ def json_read( path, default=None ):
|
||||||
with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" ) ) as f:
|
with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" ) ) as f:
|
||||||
return json_parse( f.read() )
|
return json_parse( f.read() )
|
||||||
|
|
||||||
def json_write( data, path, truncate=False ):
|
def json_write( data, path, **kwargs ):
|
||||||
path = coerce_path( path )
|
path = coerce_path( path )
|
||||||
|
|
||||||
with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" ) ) as f:
|
with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" ) ) as f:
|
||||||
f.write( json_stringify( data, truncate=truncate ) )
|
f.write( json_stringify( data, raw=use_orjson, **kwargs ) )
|
||||||
|
|
||||||
def coerce_path( path ):
|
def coerce_path( path ):
|
||||||
return path if isinstance( path, Path ) else Path(path)
|
return path if isinstance( path, Path ) else Path(path)
|
||||||
|
@ -94,7 +93,7 @@ def torch_save( data, path, module_key=None ):
|
||||||
path = coerce_path(path)
|
path = coerce_path(path)
|
||||||
ext = path.suffix
|
ext = path.suffix
|
||||||
|
|
||||||
if ext in [".safetensor", ".sft"]:
|
if ext in [".safetensor", ".safetensors", ".sft"]:
|
||||||
data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )
|
data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )
|
||||||
|
|
||||||
return sft_save( data, path, metadata )
|
return sft_save( data, path, metadata )
|
||||||
|
@ -105,7 +104,7 @@ def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=T
|
||||||
path = coerce_path(path)
|
path = coerce_path(path)
|
||||||
ext = path.suffix
|
ext = path.suffix
|
||||||
|
|
||||||
if ext in [".safetensor", ".sft"]:
|
if ext in [".safetensor", ".safetensors", ".sft"]:
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
with sft_load(path, framework=framework, device=device) as f:
|
with sft_load(path, framework=framework, device=device) as f:
|
||||||
for k in f.keys():
|
for k in f.keys():
|
||||||
|
@ -113,6 +112,7 @@ def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=T
|
||||||
|
|
||||||
if load_metadata:
|
if load_metadata:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
|
if metadata is not None:
|
||||||
for k, v in metadata.items():
|
for k, v in metadata.items():
|
||||||
try:
|
try:
|
||||||
metadata[k] = json.loads( v )
|
metadata[k] = json.loads( v )
|
||||||
|
|
Loading…
Reference in New Issue
Block a user