This commit is contained in:
mrq 2024-12-21 22:52:10 -06:00
parent 2542ed067d
commit 353e478e68
7 changed files with 268 additions and 50 deletions

52
scripts/hf_test.py Normal file
View File

@ -0,0 +1,52 @@
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
# tokenizer = LlamaTokenizer.from_pretrained("./training/llama-encodec-ar+nar-len/hf/")
model = LlamaForCausalLM.from_pretrained("./training/llama-encodec-ar+nar-len/hf/")
phns = [1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2]
proms = [
[780,835,835,835,339,395,798,537,537,537,537,222,76,989,548,65,705,375,261,375,297,503,529,571,707,346,464,862,148,496,574,115,115,438,934,339,865,876,63,40,779,461,602,794,10,220,398,869,639,705,869,917,705,893,215,705,869,938,439,175,139,506,375,529,297,705,651,238,962,461,195,441,377,581,473,795,644,626,459,981,767,670,696,73,779,257,408,1017,1019,133,133,1017,835,604,699,626,67,92,707,92,179,179,772,869,441,799,917,238,745,904,904,904,106,133,1019,1017,1017,395,883,87,519,594,1002,682,996,540,186,1019,430,202,347,889,61,92,542,297,67,669,571,707,346,67,359,571,707,669,604,25,1008,810,35,621,67,600,333,123,284,568,817,243,778,464,638,610,359,538,464,975,321,700,377,484,179,284,284,621,538,464,745,171,171,159,744,159,287,461,69,15,529,67,92,669,464,515,605,24,822,865,293,62,172,638,359,562,138,839,846,775,556,688,1006,917,297,312,148,331,496,646,67,314,15,705,131,855,662,287,172,85,538,519,762,450,391,609,643,778,80,287,794,794,115,785,794,461,699,519,932,522,652,262,508,902,932,932,391,769,18,507,90,442,762,610,610,669,605,310,855,56,989,863,195,464,604,257,904,632,786,951,461,239,195,878,771,146,481,146,481,434,643,917,280,67,464,115,744,744,115,115,115,819,709,63,368,359,519,996,616,464,996,616,519,762,917,841,772,568,954,600,422,893,592,464,626,86,143,615,171,744,744,196,115,821,415,521,799,654,839,644,473,592,953,523,855,738,855,876,876,1017,63,329]
]
sep = [17685]
rvq_lvl = [17666]
lang = [17686]
len_seq = [17674]
# offset the audio prompt codes into the flattened prompt-token range of the vocab
for i, t in enumerate( proms[0] ):
	proms[0][i] = t + 256 + 1024
# flattened input sequence: phonemes | lang | RVQ level | audio prompt | len marker, joined by <|sep|>
ids = torch.tensor(phns + sep + lang + sep + rvq_lvl + sep + proms[0] + sep + len_seq, device="cuda", dtype=torch.int32)
# position ids restart for each segment of the sequence
pos_ids = torch.tensor( [*range(len(phns)+1)] + [*range(2)] + [*range(2)] + [*range(len(proms[0])+1)] + [0], device="cuda", dtype=torch.int32)
start = 17674 # 8448
end = start + 10 # 1025
with torch.no_grad():
	# swap the full output head for a slice covering only the target token range
	original_lm_head = model.lm_head.weight
	model.lm_head = torch.nn.Linear(1024, end - start, bias=False)
	model.lm_head.weight.copy_(original_lm_head[start:end])

	model.to(device="cuda", dtype=torch.float16)
	model.eval()

	n_decoded = 0
	while True:
		out = model(input_ids=ids.unsqueeze(0), position_ids=pos_ids.unsqueeze(0))
		#logits = out.logits[0, -1:, start:end]
		logits = out.logits[0, -1:, :]
		tokens = logits.argmax(dim=-1)

		n_decoded += 1
		print( n_decoded, tokens )

		# quick test: cap the number of decode steps
		if end in tokens or n_decoded > 5:
			break

		ids = torch.concat( [ ids, tokens + start ] )
		pos_ids = torch.concat( [ pos_ids, torch.tensor([n_decoded]).to(pos_ids) ] )

	print( out )
	print( ids )

View File

@ -10,7 +10,9 @@ Populate `./include/` with the `llama.cpp` and `encodec.cpp` headers.
Populate `./libs/` with the compiled libraries of `llama.cpp` and `encodec.cpp`.
* `encodec.cpp` requires updating `ggml` to the latest version and doing a quick hack to make it work on the CPU backend.
* `llama.cpp` currently requires no hacks, but would be *very* nice to hack in a way to retrieve a model's `tok_embd`.
* `llama.cpp` currently requires no hacks, but:
* would be *very* nice to retrieve a model's `tok_embd` through the API.
* would be ***very*** nice to only specify a slice of the output head through the API.
Run `make`.
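As a point of reference for the output-head wish above: that slicing currently has to be emulated outside of `llama.cpp`. A minimal PyTorch sketch of the same trick (mirroring what `scripts/hf_test.py` does against the exported HF model; the token range below is a placeholder) looks like:

import torch
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("./training/llama-encodec-ar+nar-len/hf/")
start, end = 17674, 17674 + 11  # e.g. the <|len:...|> slice of the flattened vocab (illustrative range)

with torch.no_grad():
	full_head = model.lm_head.weight  # [n_vocab, d_model]
	sliced = torch.nn.Linear(full_head.shape[1], end - start, bias=False)
	sliced.weight.copy_(full_head[start:end])
	model.lm_head = sliced  # logits now only cover the requested token range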

View File

@ -250,18 +250,19 @@ std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector
}
const int EMBEDDING_MODE_PROM = 0;
const int EMBEDDING_MODE_RESP_AR_NAR = 0;
const int EMBEDDING_MODE_RESP_NAR_LEN = 0;
const int EMBEDDING_MODE_RESP_AR_NAR = 1;
const int EMBEDDING_MODE_RESP_NAR_LEN = 2;
const int INFERENCE_MODE_LEN = 0;
const int INFERENCE_MODE_AR = 1;
const int INFERENCE_MODE_NAR_DEMASK = 2;
const int INFERENCE_MODE_NAR = 4;
const int INFERENCE_MODE_NAR = 3;
const int MODALITY_AR_NAR = 0;
const int MODALITY_NAR_LEN = 0;
const int MODALITY_NAR_LEN = 1;
const int MAX_DURATION = 75; // * 12;
const int CTX_SIZE = 2048;
// sums embeddings over a 2D "tensor"
std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, float** embds, int mode = EMBEDDING_MODE_PROM ) {
@ -457,14 +458,14 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
int main(int argc, char ** argv) {
// to-do: replace all of this with proper loading code
int32_t ngl = 0;
int modality = MODALITY_AR_NAR;
int modality = MODALITY_NAR_LEN;
input_t input{};
embeddings_t embeddings_map{};
// input.phonemes = "hˈɛloː ʋˈɔrlt";
input.phn = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; // <bos>hˈɛloː ʋˈɔrlt</eos>
std::string vall_e_model_path = "./data/vall_e-F16.gguf";
std::string vall_e_model_path = "./data/vall_e-f16.gguf";
std::string encodec_model_path = "./data/encodec.bin";
std::string input_prompt_path = "./data/prom.wav";
std::string output_response_path = "./data/resp.wav";
@ -497,9 +498,9 @@ int main(int argc, char ** argv) {
// initialize the context
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 22500;
ctx_params.n_batch = 22500;
ctx_params.n_ubatch = 22500;
ctx_params.n_ctx = CTX_SIZE;
ctx_params.n_batch = CTX_SIZE;
ctx_params.n_ubatch = CTX_SIZE;
ctx_params.no_perf = false;
ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL;
@ -519,7 +520,7 @@ int main(int argc, char ** argv) {
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_k(20));
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_p(0.9, 20));
llama_sampler_chain_add(smpl_ar, llama_sampler_init_temp (1.0));
// llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130));
llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130));
llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy());
@ -542,13 +543,13 @@ int main(int argc, char ** argv) {
if ( input.phonemes != "" ) {
const int n_prompt = -llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), NULL, 0, true, true);
// allocate space for the tokens and tokenize the input.phonemes
input.phns.resize(n_prompt)
if (llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), input.phns.data(), input.phns.size(), true, true) < 0) {
input.phn.resize(n_prompt);
if (llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), input.phn.data(), input.phn.size(), true, true) < 0) {
fprintf(stderr, "%s: error: failed to tokenize: %s\n", __func__, input.phonemes.c_str());
return 1;
}
for ( auto& token : input.phns ) printf("%i ", token );
for ( auto& token : input.phn ) printf("%i ", token );
printf("\n");
}

View File

@ -114,6 +114,7 @@ def load_engines(training=True, **model_kwargs):
"lr": cfg.hyperparameters.learning_rate,
}
if cfg.hyperparameters.optimizer.lower() == "adamw":
	params["betas"] = (0.9, 0.96)
	params["eps"] = 1e-07

View File

@ -71,20 +71,25 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
"stt",
]
classifier_bias = False
classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias
split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers
embedding = torch.nn.Embedding( n_tokens, model_dim )
classifier = torch.nn.Linear( model_dim, n_tokens, bias=classifier_bias )
if not split_classifiers:
classifier.weight[:] = state_dict['module']['classifier.weight'][:]
# to-do: ignore classifier for RVQ level 7
# inject text tokens
token_start = 0
token_end = l_tokens[0]
embedding.weight[token_start:token_end] = state_dict['module']['text_emb.weight']
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.9.weight']
if classifier_bias:
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.9.bias']
if split_classifiers:
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.9.weight']
if classifier_bias:
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.9.bias']
# tokenizer already has these tokens
# inject prom tokens
@ -104,9 +109,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
token_start = token_end
token_end += l_tokens[2] // 2
embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.0.weight']
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
if classifier_bias:
	classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
if split_classifiers:
	classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
	if classifier_bias:
		classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
for t in range(n_audio_tokens):
	tokenizer_vocab[f'<|AR|0:0|{t}|>'] = token_start + t
tokenizer_vocab[f'<AR|0:0|STOP|>'] = token_start + 1024
@ -115,9 +121,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
token_start = token_end
token_end += l_tokens[2] // 2
embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight']
classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
if classifier_bias:
	classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
if split_classifiers:
	classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
	if classifier_bias:
		classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
for t in range(n_audio_tokens):
	tokenizer_vocab[f'<|NAR|0:0|{t}|>'] = token_start + t
tokenizer_vocab[f'<|NAR|0:0|STOP|>'] = token_start + 1024
@ -129,9 +136,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
start = token_start + ((l-1) * n_audio_tokens)
end = start + n_audio_tokens
embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight']
classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
if classifier_bias:
	classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
if split_classifiers:
	classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
	if classifier_bias:
		classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
for t in range(n_audio_tokens):
	tokenizer_vocab[f'<|NAR|{l-1}:{l}|{t}|>'] = start + t
@ -147,9 +155,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
token_start = token_end
token_end += l_tokens[5]
embedding.weight[token_start:token_end] = state_dict['module'][f'len_emb.weight']
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
if classifier_bias:
	classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
if split_classifiers:
	classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
	if classifier_bias:
		classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
for t in range(n_len_tokens):
	tokenizer_vocab[f'<|len:{t}|>'] = token_start + t
@ -197,7 +206,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
out_dir = cfg.rel_path / "hf"
out_dir.mkdir(parents=True, exist_ok=True)
# write weights
torch_save( model_dict, out_dir / "model.safetensors" )
torch_save( { "module": model_dict, "format": "pt" }, out_dir / "model.safetensors" )
# write tokenizer.json
tokenizer['model']['vocab'] |= tokenizer_vocab
json_write(tokenizer, out_dir / "tokenizer.json", pretty=True)
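Taken together, the injections above lay the split embeddings and classifiers out into one flattened vocab. A small sanity-check sketch of that layout, assuming the segment sizes implied by this diff (256 text tokens, 8 prompt levels of 1024, 1025-entry AR and NAR:0:0 response levels, seven 1024-entry NAR levels, 8 RVQ-level tokens, 11 len tokens, then the sep/lang/task specials):

# recompute the flattened-vocab offsets from per-segment sizes and check them
# against the constants quoted in _get_offsets() / scripts/hf_test.py
segments = [
	("text",         256),
	("prom",    8 * 1024),
	("AR:0:0",      1025),  # 1024 codes + STOP
	("NAR:0:0",     1025),  # 1024 codes + STOP
	("NAR:l-1:l", 7 * 1024),
	("quant_level",    8),  # <|RVQ:l|>
	("len",           11),
	("sep",            1),
	("lang",           6),
	("task",           9),
]

offsets, cursor = {}, 0
for name, size in segments:
	offsets[name] = cursor
	cursor += size

assert offsets["prom"] == 256
assert offsets["AR:0:0"] == 8448
assert offsets["NAR:0:0"] == 9473
assert offsets["NAR:l-1:l"] == 10498
assert offsets["quant_level"] == 17666
assert offsets["len"] == 17674
assert offsets["sep"] == 17685
assert offsets["lang"] == 17686
assert offsets["task"] == 17692
assert cursor == 17701  # n_vocab when not splitting classifiers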

View File

@ -55,6 +55,38 @@ task_outputs = {
"len": "len",
}
# yuck
def _get_offsets():
return {
"text": 0, # <unk>
"quant_level": 17666, # <|RVQ:0>
"len": 17674, # <|len:0|>
"lang": 17686, # <|lang:en|>"
"task": 17692, # <|task:tts|>
"sep": 17685, # <|sep|>
"prom": [
256 + (1024 * 0), # <|P|0:0|>
256 + (1024 * 1), # <|P|1:0|>
256 + (1024 * 2), # <|P|2:0|>
256 + (1024 * 3), # <|P|3:0|>
256 + (1024 * 4), # <|P|4:0|>
256 + (1024 * 5), # <|P|5:0|>
256 + (1024 * 6), # <|P|6:0|>
256 + (1024 * 7), # <|P|7:0|>
],
"resp": [
8448, # <|AR|0:0|>
9473, # <|NAR|0:0|>
10498 + (1024 * 0), # <|NAR|0:1|>
10498 + (1024 * 1), # <|NAR|1:2|>
10498 + (1024 * 2), # <|NAR|2:3|>
10498 + (1024 * 3), # <|NAR|3:4|>
10498 + (1024 * 4), # <|NAR|4:5|>
10498 + (1024 * 5), # <|NAR|5:6|>
10498 + (1024 * 6), # <|NAR|6:7|>
]
}
def _dropout_mask( input, p=None ):
# cosine scheduling
if p is None:
@ -494,6 +526,9 @@ class Base(nn.Module):
classifier_l_tokens += [ 11 ]
classifier_l_names += ["len"]
n_vocab = 17701 if not split_classifiers else n_resp_tokens + 1
self.n_vocab = n_vocab
self.unified_position_ids = unified_position_ids
self.interleave = interleave
self.layerskip = layerskip
@ -601,7 +636,7 @@ class Base(nn.Module):
elif self.arch_type in ["mistral", "mixtral"]:
if n_experts <= 1:
self.model = MistralModel(MistralConfig(
vocab_size=n_resp_tokens,
vocab_size=n_vocab,
hidden_size=d_model,
max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
intermediate_size=d_model*4,
@ -647,7 +682,7 @@ class Base(nn.Module):
if n_experts <= 1:
self.model = LlamaClass(LlamaConfig(
vocab_size=n_resp_tokens,
vocab_size=n_vocab,
hidden_size=d_model,
max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
intermediate_size=d_model*4,
@ -700,7 +735,7 @@ class Base(nn.Module):
))
elif self.arch_type == "retnet":
kwargs = dict(
vocab_size=n_resp_tokens,
vocab_size=n_vocab,
decoder_embed_dim=d_model,
decoder_value_embed_dim =d_model * 2,
decoder_retention_heads=n_heads,
@ -732,7 +767,7 @@ class Base(nn.Module):
self.model = RetNetDecoder(RetNetConfig(**kwargs))
elif self.arch_type in ["mamba2"]:
self.model = Mamba2Model(Mamba2Config(
vocab_size=n_resp_tokens,
vocab_size=n_vocab,
hidden_size=d_model,
expand=2,
num_hidden_layers=n_layers*2,
@ -744,7 +779,7 @@ class Base(nn.Module):
))
elif self.arch_type in ["mamba"]:
self.model = MambaModel(MambaConfig(
vocab_size=n_resp_tokens,
vocab_size=n_vocab,
hidden_size=d_model,
expand=2,
num_hidden_layers=n_layers*2,
@ -761,11 +796,11 @@ class Base(nn.Module):
del self.model.embeddings
if not split_classifiers:
self.classifier = nn.Linear(d_model, n_resp_tokens)
self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias)
self.classifiers = None
self.accuracy_metric = MulticlassAccuracy(
n_resp_tokens,
n_vocab,
top_k=10,
average="micro",
multidim_average="global",
@ -773,7 +808,7 @@ class Base(nn.Module):
)
self.precision_metric = MulticlassPrecision(
n_resp_tokens,
n_vocab,
top_k=10,
average="micro",
multidim_average="global",
@ -1031,6 +1066,48 @@ class Base(nn.Module):
raise Exception(f'Unrecognized task: {task_type}')
return inputs
def offset_inputs(
	self,
	inputs: list,
	direction: int = 1, # -1 to de-offset
):
	offsets = _get_offsets()
	for batch_index, batch_input in enumerate(inputs):
		quant_level = None
		classifier_level = None
		# pre-iterate
		for name, input in batch_input:
			if name == "quant_level":
				quant_level = input
			elif name == "classifier_level":
				classifier_level = input

		for name, input in batch_input:
			if name not in offsets:
				continue

			if not isinstance( input, torch.Tensor ):
				continue

			offset = offsets[name]
			if name in ["prom", "resp"]:
				l = quant_level
				if name == "resp":
					if classifier_level == "AR:0:0":
						l = 0
					elif classifier_level == "NAR:0:0":
						l = 1
					else:
						l = 2 + (quant_level-1)
				offset = offset[l]

			for i, t in enumerate( input ):
				input[i] += offset * direction

	return inputs
def inputs_to_embeddings(
self,
inputs: list,
@ -1366,6 +1443,49 @@ class Base(nn.Module):
if not isinstance(token, torch.Tensor):
	continue

# offset to flattened vocab ranges
if self.classifier is not None:
	offsets = _get_offsets()

	if name in offsets:
		offset = offsets[name]

		# yes there's a better way
		if name == "prom":
			offset = offset[quant_level]
		elif name == "resp":
			"""
			if classifier_level == "AR:0:0":
				offset = offset[0]
			elif classifier_level == "NAR:0:0":
				offset = offset[1]
			elif classifier_level == "NAR:0:1":
				offset = offset[2]
			elif classifier_level == "NAR:1:2":
				offset = offset[3]
			elif classifier_level == "NAR:2:3":
				offset = offset[4]
			elif classifier_level == "NAR:3:4":
				offset = offset[5]
			elif classifier_level == "NAR:4:5":
				offset = offset[6]
			elif classifier_level == "NAR:5:6":
				offset = offset[7]
			elif classifier_level == "NAR:6:7":
				offset = offset[8]
			else:
				continue
			"""
			if classifier_level == "AR:0:0":
				offset = offset[0]
			elif classifier_level == "NAR:0:0":
				offset = offset[1]
			else:
				offset = offset[2 + (quant_level-1)]

		for i, t in enumerate( token ):
			if t == self.ignore_index:
				continue
			token[i] += offset

if token.is_floating_point():
	ignored = True
@ -1422,7 +1542,7 @@ class Base(nn.Module):
# perform loss calculation on the entire sequence
if not self.config.loss_factors:
	target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
	target = _join( target, torch.tensor(self.ignore_index if self.classifier is None else 17685, device=target[-1].device) )
	logit = logits[batch_index]

	# shift if causal
@ -1606,6 +1726,37 @@ class Base(nn.Module):
self.loss = None
self.stats = None
# de-offset if needed
if self.classifier is not None:
	offsets = _get_offsets()
	for batch_index, classifier_level in enumerate( classifier_levels ):
		# yes there's a better way
		if classifier_level == "len":
			offset = offsets["len"], 11
		elif classifier_level == "AR:0:0":
			offset = offsets["resp"][0], 1025
		elif classifier_level == "NAR:0:0":
			offset = offsets["resp"][1], 1024
		elif classifier_level == "NAR:0:1":
			offset = offsets["resp"][2], 1024
		elif classifier_level == "NAR:1:2":
			offset = offsets["resp"][3], 1024
		elif classifier_level == "NAR:2:3":
			offset = offsets["resp"][4], 1024
		elif classifier_level == "NAR:3:4":
			offset = offsets["resp"][5], 1024
		elif classifier_level == "NAR:4:5":
			offset = offsets["resp"][6], 1024
		elif classifier_level == "NAR:5:6":
			offset = offsets["resp"][7], 1024
		elif classifier_level == "NAR:6:7":
			offset = offsets["resp"][8], 1024
		else:
			continue

		logits[batch_index] = logits[batch_index][offset[0]:offset[0]+offset[1], :]
else:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
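As a rough illustration of the offsetting that `offset_inputs`, the loss targets, and the logit de-offsetting above all share, here is a minimal round-trip sketch using the same numbers as `_get_offsets()` (the helper below is hypothetical, not part of the module):

# hypothetical helper mirroring the offset logic above: map a raw token to the
# flattened vocab for a given classifier level, then recover it again
RESP_OFFSETS = [8448, 9473] + [10498 + 1024 * i for i in range(7)]

def resp_offset(classifier_level: str, quant_level: int) -> int:
	if classifier_level == "AR:0:0":
		return RESP_OFFSETS[0]
	if classifier_level == "NAR:0:0":
		return RESP_OFFSETS[1]
	return RESP_OFFSETS[2 + (quant_level - 1)]

token = 123                                              # raw EnCodec code
flat = token + resp_offset("NAR:1:2", quant_level=2)     # 123 + 10498 + 1024
assert flat == 11645
assert flat - resp_offset("NAR:1:2", quant_level=2) == token  # de-offset recovers it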

View File

@ -60,14 +60,13 @@ def is_dict_of( d, t ):
# handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors
def state_dict_to_tensor_metadata( data: dict, module_key=None ):
	metadata = None
	metadata = {}

	# is a state_dict, no need to coerce
	if is_dict_of( data, torch.Tensor ):
		return data, metadata

	# is maybe a dict with a state dict + metadata, coerce it
	metadata = {}
	target = module_key
	if not target:
		for k, v in data.items():
@ -78,7 +77,8 @@ def state_dict_to_tensor_metadata( data: dict, module_key=None ):
# not a dict of tensors, put it as metadata
try:
	metadata[k] = json.dumps(v)
	metadata[k] = json_stringify(v) if any([isinstance( v, dict ), isinstance( v, list )]) else v

	if isinstance( metadata[k], bytes ):
		metadata[k] = metadata[k].decode('utf-8')
except Exception as e:
@ -96,6 +96,9 @@ def torch_save( data, path, module_key=None ):
if ext in [".safetensor", ".safetensors", ".sft"]:
data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )
if metadata is None:
metadata = {}
return sft_save( data, path, metadata )
return torch.save( data, path )
@ -112,13 +115,12 @@ def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=T
if load_metadata:
	metadata = f.metadata()

	if metadata is not None:
		for k, v in metadata.items():
			try:
				metadata[k] = json.loads( v )
			except Exception as e:
				pass

		state_dict = { module_key: state_dict } | metadata
	for k, v in metadata.items():
		try:
			metadata[k] = json.loads( v )
		except Exception as e:
			pass

	state_dict = { module_key: state_dict } | metadata

return state_dict
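For context, the metadata handling above leans on safetensors' string-only metadata field. A minimal, standalone sketch of that round trip, using the safetensors API directly rather than the helpers above:

import json
import torch
from safetensors import safe_open
from safetensors.torch import save_file

tensors = { "weight": torch.zeros(4, 4) }
metadata = { "config": json.dumps({ "dim": 4 }), "format": "pt" }  # values must be strings

save_file(tensors, "example.safetensors", metadata=metadata)

with safe_open("example.safetensors", framework="pt", device="cpu") as f:
	loaded_meta = f.metadata()                  # plain dict of strings
	config = json.loads(loaded_meta["config"])  # JSON-decode on the way back in
	weight = f.get_tensor("weight")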