corrected export.py's --hf

This commit is contained in:
mrq 2024-12-20 15:17:13 -06:00
parent 59bf6b8b33
commit d85273609e
3 changed files with 80 additions and 84 deletions

View File

@ -192,7 +192,7 @@ def normalize_text(text, language="auto", full=True):
return text
@cache
def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/tongue_twisters.txt") ):
def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/harvard_sentences.txt") ):
duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
sentences = [
"The birch canoe slid on the smooth planks.",

View File

@ -7,7 +7,7 @@ from .data import get_phone_symmap
from .engines import load_engines
from .config import cfg
from .models.lora import lora_get_state_dict
from .utils.io import torch_save, torch_load
from .utils.io import torch_save, torch_load, json_read, json_write, Path
# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
@ -22,7 +22,20 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]
# the new tokenizer to use
tokenizer_append = {}
tokenizer = {}
tokenizer_vocab = {}
tokenizer_path = cfg.rel_path / cfg.tokenizer_path
if not tokenizer_path.exists():
tokenizer_path = Path("./data/") / cfg.tokenizer_path
if tokenizer_path.exists():
tokenizer = json_read( tokenizer_path )
else:
tokenizer = {
"model": {
"vocab": get_phone_symmap()
}
}
l_tokens = [
n_text_tokens, # text
@ -77,14 +90,14 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
token_start = token_end
token_end += l_tokens[1]
for l in range(n_resp_levels):
start = token_start + (l*n_resp_levels)
start = token_start + (l * n_audio_tokens)
end = start + n_audio_tokens
embedding.weight[start:end] = state_dict['module'][f'proms_emb.embeddings.{l}.weight']
# there's no corresponding classifier
#classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
#classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
for t in range(n_audio_tokens):
tokenizer_append[f'<P:{l}:{t}>'] = start + t
tokenizer_vocab[f'<|P|{l}:{t}|>'] = start + t
# inject AR
token_start = token_end
@ -93,8 +106,8 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
for t in range(n_audio_tokens):
tokenizer_append[f'<AR:0:0:{t}>'] = token_start + t
tokenizer_append[f'<AR:0:0:STOP>'] = token_start + 1024
tokenizer_vocab[f'<|AR|0:0|{t}|>'] = token_start + t
tokenizer_vocab[f'<AR|0:0|STOP|>'] = token_start + 1024
# inject NAR-len
token_start = token_end
@ -103,20 +116,20 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
for t in range(n_audio_tokens):
tokenizer_append[f'<NAR:0:0:{t}>'] = token_start + t
tokenizer_append[f'<NAR:0:0:STOP>'] = token_start + 1024
tokenizer_vocab[f'<NAR|0:0|{t}|>'] = token_start + t
tokenizer_vocab[f'<NAR|0:0|STOP|>'] = token_start + 1024
# inject NAR
token_start = token_end
token_end += l_tokens[3]
for l in range(1, n_resp_levels):
start = token_start + ((l-1)*n_resp_levels)
start = token_start + ((l-1) * n_audio_tokens)
end = start + n_audio_tokens
embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight']
classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
for t in range(n_audio_tokens):
tokenizer_append[f'<NAR:{l-1}:{l}:{t}>'] = start + t
tokenizer_vocab[f'<|NAR|{l-1}:{l}|{t}|>'] = start + t
# inject RVQ level
token_start = token_end
@ -124,7 +137,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
embedding.weight[token_start:token_end] = state_dict['module'][f'rvq_l_emb.weight']
# there is no corresponding classifier
for l in range(n_resp_levels):
tokenizer_append[f'<RVQ:{l}>'] = token_start + l
tokenizer_vocab[f'<|RVQ:{l}|>'] = token_start + l
# inject len
token_start = token_end
@ -133,13 +146,13 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
for t in range(n_len_tokens):
tokenizer_append[f'<len:{t}>'] = token_start + t
tokenizer_vocab[f'<|len:{t}|>'] = token_start + t
# inject sep
token_start = token_end
token_end += l_tokens[6]
embedding.weight[token_start:token_end] = state_dict['module']['sep']
tokenizer_append['<sep>'] = token_start
tokenizer_vocab['<|sep|>'] = token_start
# there is no corresponding classifier
# inject langs
@ -148,7 +161,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
embedding.weight[token_start:token_end] = state_dict['module']['langs_emb.weight']
for l in range(n_lang_tokens):
lang = lang_map[l]
tokenizer_append[f'<lang:{lang}>'] = token_start + l
tokenizer_vocab[f'<|lang:{lang}|>'] = token_start + l
# there is no corresponding classifier
# inject tasks
@ -157,7 +170,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
embedding.weight[token_start:token_end] = state_dict['module']['tasks_emb.weight']
for l in range(n_task_tokens):
task = task_map[l]
tokenizer_append[f'<task:{task}>'] = token_start + l
tokenizer_vocab[f'<|task:{task}|>'] = token_start + l
# there is no corresponding classifier
@ -167,7 +180,6 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
if not k.startswith('model.'):
continue
model_dict[k] = state_dict['module'][k].clone()
del state_dict['module']
embedding_dict = embedding.state_dict()
classifier_dict = classifier.state_dict()
@ -175,62 +187,46 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
model_dict['lm_head.weight'] = classifier_dict['weight']
model_dict['lm_head.bias'] = classifier_dict['bias']
state_dict['module'] = model_dict
state_dict['vocab'] = tokenizer_append
# write files in an HF compatible way
out_dir = cfg.rel_path / "hf"
out_dir.mkdir(parents=True, exist_ok=True)
# write weights
torch_save( model_dict, out_dir / "model.safetensors" )
# write vocab.json
tokenizer['model']['vocab'] |= tokenizer_vocab
json_write(tokenizer, out_dir / "tokenizer.json", pretty=True)
# write config.json
json_write({
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": model_dim,
"initializer_range": 0.02,
"intermediate_size": model_dim * 4,
"max_position_embeddings": 75 * 60 * 5,
"model_type": "llama",
"num_attention_heads": 16,
"num_hidden_layers": 12,
"num_key_value_heads": 16,
"pretraining_tp": 1,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 10000.0,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.40.0",
"use_cache": False,
"vocab_size": n_tokens
}, out_dir / "config.json", pretty=True )
return state_dict
"""
n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
token_dim = 1024
embedding = torch.nn.Embedding(n_tokens, token_dim)
embedding.weight.requires_grad = False
def move_value(k):
v = state_dict['module'][k]
del state_dict['module'][k]
return v
separator = move_value('sep')
out_proj = move_value('classifier.weight')
text_emb = move_value('text_emb.weight')
langs_emb = move_value('langs_emb.weight')
tasks_emb = move_value('tasks_emb.weight')
tones_emb = move_value('tones_emb.weight')
proms_emb_weight = [ move_value(f'proms_emb.weight.{i}').item() for i in range(8) ] if "proms_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
resps_emb_weight = [ move_value(f'resps_emb.weight.{i}').item() for i in range(8) ] if "resps_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
proms_emb = [ move_value(f'proms_emb.embeddings.{i}.weight') for i in range(8) ]
resps_emb = [ move_value(f'resps_emb.embeddings.{i}.weight') for i in range(8) ]
start = 0
for i in range(256):
embedding.weight[start + i] = text_emb[i]
start = 256
for layer in range(8):
for i in range(1024):
offset = start + 1024 * layer
embedding.weight[i + offset] = proms_emb[layer][i] * proms_emb_weight[layer]
start = 256 + 1024 * 8
for layer in range(8):
for i in range(1024):
offset = start + 1024 * layer
embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer]
state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
# to-do: properly recreate the output head weights or something
state_dict['module']['lm_head.weight'] = out_proj
del state_dict['module']['classifier.weight']
del state_dict['module']['classifier.bias']
return state_dict
"""
# yanks a LoRA from the training checkpoint
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
if dtype is None:

View File

@ -13,13 +13,12 @@ except:
from .utils import truncate_json
def json_stringify( data, truncate=False, pretty=False ):
def json_stringify( data, truncate=False, pretty=False, raw=False ):
if truncate:
return truncate_json( json.dumps( data ) )
if pretty:
if use_orjson:
return json.dumps( data, option=json.OPT_INDENT_2 ).decode('utf-8')
return json.dumps( data, indent='\t' ).decode('utf-8')
s = json.dumps( data, option=json.OPT_INDENT_2 ) if use_orjson else json.dumps( data, indent='\t' )
return s if raw and use_orjson else s.decode('utf-8')
return json.dumps( data )
def json_parse( string ):
@ -34,11 +33,11 @@ def json_read( path, default=None ):
with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" ) ) as f:
return json_parse( f.read() )
def json_write( data, path, truncate=False ):
def json_write( data, path, **kwargs ):
path = coerce_path( path )
with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" ) ) as f:
f.write( json_stringify( data, truncate=truncate ) )
f.write( json_stringify( data, raw=use_orjson, **kwargs ) )
def coerce_path( path ):
return path if isinstance( path, Path ) else Path(path)
@ -94,7 +93,7 @@ def torch_save( data, path, module_key=None ):
path = coerce_path(path)
ext = path.suffix
if ext in [".safetensor", ".sft"]:
if ext in [".safetensor", ".safetensors", ".sft"]:
data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )
return sft_save( data, path, metadata )
@ -105,7 +104,7 @@ def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=T
path = coerce_path(path)
ext = path.suffix
if ext in [".safetensor", ".sft"]:
if ext in [".safetensor", ".safetensors", ".sft"]:
state_dict = {}
with sft_load(path, framework=framework, device=device) as f:
for k in f.keys():
@ -113,12 +112,13 @@ def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=T
if load_metadata:
metadata = f.metadata()
for k, v in metadata.items():
try:
metadata[k] = json.loads( v )
except Exception as e:
pass
state_dict = { module_key: state_dict } | metadata
if metadata is not None:
for k, v in metadata.items():
try:
metadata[k] = json.loads( v )
except Exception as e:
pass
state_dict = { module_key: state_dict } | metadata
return state_dict