mrq 2024-12-21 11:56:22 -06:00
parent 5788db849b
commit 979c1f797c
2 changed files with 72 additions and 42 deletions

View File

@@ -15,7 +15,7 @@ Probably something like:
* [x] converted model to GGUF
* [ ] convert it without modifying any of the existing code
* [x] basic framework
* [x] load the model
* [x] load the quantized model
* [x] orchestrate the required embeddings
* [x] juggle the output head / classifier properly
* [ ] phonemize text
@@ -23,8 +23,8 @@ Probably something like:
* [ ] load audio from disk
* [ ] encode audio
* [ ] sum embeddings for the `prom` and prior `resp`s
* [x] `AR` sampling
* [ ] `NAR-len` demasking sampling
* [ ] `NAR` sampling
* [ ] decode audio to disk
* [ ] a functional CLI
* [ ] quantize the model (properly)
* [ ] a functional CLI

View File

@@ -168,7 +168,7 @@ int main(int argc, char ** argv) {
std::vector<std::vector<llama_token>> response_tokens = {
{922,395,869,869,354,989,762,762,762,610,975,626,626,866,609,442,762,762,762,610,610,610,610,212,869,869,51,336,352,352,352,570,148,893,76,535,568,568,270,568,568,560,597,86,744,744,744,203,738,408,1019,700,707,92,707,464,744,171,171,159,196,192,697,261,261,568,638,605,904,904,779,832,570,519,223,459,459,459,459,90,90,570,700,53,372,621,610,869,473,869,917,654,473,917,893,654,644,384,558,911,864,521,1,19,665},
};
std::string model_path = "./vall_e/Vall_E-238M-F16.gguf";
std::string model_path = "./vall_e/Vall_E-238M-Q8_0.gguf";
// load dynamic backends
ggml_backend_load_all();
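Switching the path from the F16 file to the Q8_0 one does not change the load path itself; llama.cpp reads the quantization type from the GGUF header. A minimal sketch of the load step under that assumption, using stock default parameters (the surrounding program presumably already does the equivalent):

```cpp
// sketch: load the (possibly quantized) GGUF with default model parameters
llama_model_params mparams = llama_model_default_params();
llama_model * model = llama_load_model_from_file( model_path.c_str(), mparams );
if ( model == NULL ) {
    fprintf( stderr, "failed to load %s\n", model_path.c_str() );
    return 1;
}
```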
@@ -207,48 +207,82 @@ int main(int argc, char ** argv) {
// prepare batch
auto n_embd = llama_n_embd( model );
auto n_vocab = llama_n_vocab( model );
// float* embd = (float*) llama_get_embedding_weights( model )->data;
float* embds = (float*) (model->tok_embd->data);
llama_batch batch = llama_batch_init( ctx_params.n_ctx, n_embd, ctx_params.n_ctx );
// grab input embeddings
std::vector<float> embds( n_embd * n_vocab );
auto* qtype = ggml_get_type_traits(model->tok_embd->type);
// dequantize if needed
if ( ggml_is_quantized(model->tok_embd->type) ) {
qtype->to_float(model->tok_embd->data, embds.data(), embds.size());
}
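The branch above only fills `embds` when the table is actually quantized; an F16 or F32 `tok_embd` would leave the vector zero-filled. A hedged sketch of a variant that covers those storage types as well, using ggml's row-conversion helpers (the function name is an assumption):

```cpp
#include <cstring>
#include <vector>
#include "ggml.h"

// sketch: copy the token-embedding table into f32, whatever its storage type
static std::vector<float> embeddings_to_f32( const struct ggml_tensor * tok_embd, int64_t n_embd, int64_t n_vocab ) {
    std::vector<float> out( (size_t) (n_embd * n_vocab) );
    if ( ggml_is_quantized( tok_embd->type ) ) {
        const auto * traits = ggml_get_type_traits( tok_embd->type );
        traits->to_float( tok_embd->data, out.data(), out.size() );
    } else if ( tok_embd->type == GGML_TYPE_F16 ) {
        ggml_fp16_to_fp32_row( (const ggml_fp16_t *) tok_embd->data, out.data(), out.size() );
    } else { // assume GGML_TYPE_F32
        memcpy( out.data(), tok_embd->data, out.size() * sizeof( float ) );
    }
    return out;
}
```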
// to-do: derive these offsets from the tokenizer itself
float* text_embds = embds + (0 * n_embd); // <bos>
float* rvq_level_embd = embds + (17666 * n_embd); // <|RVQ:0|>
float* len_embd = embds + (17674 * n_embd); // <|len:0|>
float* lang_embd = embds + (17686 * n_embd); // <|lang:en|>
float* task_embd = embds + (17692 * n_embd); // <|task:tts|>
float* sep_embd = embds + (17685 * n_embd); // <|sep|>
// to-do: clean this up, probably make it at parity to inputs_to_embeddings
int text_embd_start = 0; // <unk>
int rvq_level_embd_start = 17666; // <|RVQ:0|>
int len_embd_start = 17674; // <|len:0|>
int lang_embd_start = 17686; // <|lang:en|>
int task_embd_start = 17692; // <|task:tts|>
int sep_embd_start = 17685; // <|sep|>
int prom_embd_start[] = {
256 + (1024 * 0), // <|P|0:0|>
256 + (1024 * 1), // <|P|1:0|>
256 + (1024 * 2), // <|P|2:0|>
256 + (1024 * 3), // <|P|3:0|>
256 + (1024 * 4), // <|P|4:0|>
256 + (1024 * 5), // <|P|5:0|>
256 + (1024 * 6), // <|P|6:0|>
256 + (1024 * 7), // <|P|7:0|>
};
int resp_embd_start[] = {
8448, // <|AR|0:0|>
9473, // <|NAR|0:0|>
10498 + (1024 * 0), // <|NAR|0:1|>
10498 + (1024 * 1), // <|NAR|1:2|>
10498 + (1024 * 2), // <|NAR|2:3|>
10498 + (1024 * 3), // <|NAR|3:4|>
10498 + (1024 * 4), // <|NAR|4:5|>
10498 + (1024 * 5), // <|NAR|5:6|>
10498 + (1024 * 6), // <|NAR|6:7|>
};
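These starts are still hardcoded (per the to-do above). A hedged sketch of deriving them from the loaded vocab instead, by scanning for the exact special-token piece; `token_id_for` is a made-up helper name, and it assumes the pieces (e.g. `<|RVQ:0|>`, `<|sep|>`) are stored verbatim in the GGUF tokenizer:

```cpp
#include <cstring>

// sketch: look up a special token's id by its text instead of hardcoding the offset
static llama_token token_id_for( const llama_model * model, const char * piece ) {
    const int n_vocab = llama_n_vocab( model );
    for ( llama_token id = 0; id < n_vocab; ++id ) {
        if ( strcmp( llama_token_get_text( model, id ), piece ) == 0 ) return id;
    }
    return -1; // not found
}

// usage (should reproduce the constants above):
// int rvq_level_embd_start = token_id_for( model, "<|RVQ:0|>" );
// int sep_embd_start       = token_id_for( model, "<|sep|>" );
```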
float* text_embds = &embds[text_embd_start * n_embd];
float* rvq_level_embd = &embds[rvq_level_embd_start * n_embd];
float* len_embd = &embds[len_embd_start * n_embd];
float* lang_embd = &embds[lang_embd_start * n_embd];
float* task_embd = &embds[task_embd_start * n_embd];
float* sep_embd = &embds[sep_embd_start * n_embd];
float* prom_embds[] = {
embds + (256 + (1024 * 0) * n_embd), // <|P|0:0|>
embds + (256 + (1024 * 1) * n_embd), // <|P|1:0|>
embds + (256 + (1024 * 2) * n_embd), // <|P|2:0|>
embds + (256 + (1024 * 3) * n_embd), // <|P|3:0|>
embds + (256 + (1024 * 4) * n_embd), // <|P|4:0|>
embds + (256 + (1024 * 5) * n_embd), // <|P|5:0|>
embds + (256 + (1024 * 6) * n_embd), // <|P|6:0|>
embds + (256 + (1024 * 7) * n_embd), // <|P|7:0|>
&embds[prom_embd_start[0] * n_embd],
&embds[prom_embd_start[1] * n_embd],
&embds[prom_embd_start[2] * n_embd],
&embds[prom_embd_start[3] * n_embd],
&embds[prom_embd_start[4] * n_embd],
&embds[prom_embd_start[5] * n_embd],
&embds[prom_embd_start[6] * n_embd],
&embds[prom_embd_start[7] * n_embd],
};
float* resps_embds[] = {
embds + (8448 * n_embd), // <|AR|0:0|>
embds + (9473 * n_embd), // <|NAR|0:0|>
embds + (10498 + (1024 * 0) * n_embd), // <|NAR|0:1|>
embds + (10498 + (1024 * 1) * n_embd), // <|NAR|1:2|>
embds + (10498 + (1024 * 2) * n_embd), // <|NAR|2:3|>
embds + (10498 + (1024 * 3) * n_embd), // <|NAR|3:4|>
embds + (10498 + (1024 * 4) * n_embd), // <|NAR|4:5|>
embds + (10498 + (1024 * 5) * n_embd), // <|NAR|5:6|>
embds + (10498 + (1024 * 6) * n_embd), // <|NAR|6:7|>
&embds[resp_embd_start[0] * n_embd],
&embds[resp_embd_start[1] * n_embd],
&embds[resp_embd_start[2] * n_embd],
&embds[resp_embd_start[3] * n_embd],
&embds[resp_embd_start[4] * n_embd],
&embds[resp_embd_start[5] * n_embd],
&embds[resp_embd_start[6] * n_embd],
&embds[resp_embd_start[7] * n_embd],
&embds[resp_embd_start[8] * n_embd],
};
llama_batch batch = llama_batch_init( ctx_params.n_ctx, n_embd, ctx_params.n_ctx );
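The block below leans on a `batch_add` helper whose definition is outside this hunk. The following is only a sketch of what such a helper plausibly does for an embeddings batch, not the author's implementation: the `llama_batch` field names are llama.cpp's, while the interpretation of `id` as a row index into the given sub-table is an assumption drawn from how it is called below.

```cpp
#include <cstring>

// sketch: append one position to an embeddings batch by copying the row for `id`
// out of the given sub-table (e.g. text_embds), tagging it with `pos`, and
// optionally requesting logits for it
static void batch_add( llama_batch & batch, llama_token id, int n_embd,
                       const float * embds, llama_pos pos, bool logits ) {
    const int i = batch.n_tokens;
    memcpy( batch.embd + (size_t) i * n_embd, embds + (size_t) id * n_embd, n_embd * sizeof( float ) );
    batch.pos[i]       = pos;
    batch.n_seq_id[i]  = 1;
    batch.seq_id[i][0] = 0;      // single sequence
    batch.logits[i]    = logits; // usually only the last position needs logits
    batch.n_tokens++;
}
```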
// insert into batch
{
// keeps track of the position for each sequence
size_t pos = 0;
// insert text tokens
for ( auto& id : phoneme_tokens ) {
batch_add( batch, id, n_embd, text_embds, pos++, false );
}
for ( auto& id : phoneme_tokens ) batch_add( batch, id, n_embd, text_embds, pos++, false );
batch_add( batch, 0, n_embd, sep_embd, pos++, false );
pos = 0;
// insert lang token
@@ -262,18 +296,14 @@ int main(int argc, char ** argv) {
// insert prom tokens
// to-do: handle summing
for ( auto l = 0; l < prompt_tokens.size(); ++l ) {
for ( auto& id : prompt_tokens[l] ) {
batch_add( batch, id, n_embd, prom_embds[l], pos++, false );
}
for ( auto& id : prompt_tokens[l] ) batch_add( batch, id, n_embd, prom_embds[l], pos++, false );
}
batch_add( batch, 0, n_embd, sep_embd, pos++, is_ar );
pos = 0;
// fill in masked tokens
if ( !is_ar ) {
for ( auto i = 0; i < response_tokens[0].size(); ++i ) {
batch_add( batch, response_tokens[0][i], n_embd, resps_embds[1], pos++, true );
}
for ( auto i = 0; i < response_tokens[0].size(); ++i ) batch_add( batch, response_tokens[0][i], n_embd, resps_embds[1], pos++, true );
}
pos = 0;
}
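Per the `// to-do: handle summing` above (and the matching checklist item in the first file), the prom levels are meant to be summed per frame rather than appended level after level, which is what the current loop does. A hedged sketch of the summed form, assuming a hypothetical `batch_add_row` helper that copies an already-prepared embedding row into the next batch position:

```cpp
// sketch: one batch position per prom frame, with all RVQ levels summed into it
std::vector<float> row( n_embd );
for ( size_t t = 0; t < prompt_tokens[0].size(); ++t ) {
    std::fill( row.begin(), row.end(), 0.0f );
    for ( size_t l = 0; l < prompt_tokens.size(); ++l ) {
        const float * src = prom_embds[l] + (size_t) prompt_tokens[l][t] * n_embd;
        for ( int d = 0; d < n_embd; ++d ) row[d] += src[d];
    }
    batch_add_row( batch, row.data(), n_embd, pos++, false ); // hypothetical helper
}
```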
@@ -293,7 +323,7 @@ int main(int argc, char ** argv) {
// align to AR's classifier
// to-do: derive from tokenizer
int range[] = { 8448, 8448 + 1024 }; // { <|AR|0:0|>, <|AR|0:STOP|> }
int range[] = { resp_embd_start[0], resp_embd_start[1] };
auto* logits = llama_get_logits_ith( ctx, -1 );
for ( auto i = 0; i < n_vocab; ++i ) {
if ( i < range[0] || i >= range[1] ) {
@@ -305,7 +335,7 @@ int main(int argc, char ** argv) {
auto t = llama_sampler_sample(smpl, ctx, -1);
// is stop token
if ( t == 9472 ) { // <|AR|0:STOP|>
if ( t == resp_embd_start[1] - 1 ) { // <|AR|0:STOP|>
break;
}
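Because the logits were masked down to the AR slice of the vocabulary, the sampled id still carries the `resp_embd_start[0]` offset; subtracting it recovers the raw RVQ codebook index that the audio decoder will eventually need. A short sketch of that mapping (collecting into a fresh vector is an assumption; the real loop may store it elsewhere):

```cpp
// sketch: map a sampled id `t` back to a raw codebook index in 0..1023
std::vector<llama_token> generated_codes;
if ( t != resp_embd_start[1] - 1 ) {                      // not <|AR|0:STOP|>
    generated_codes.push_back( t - resp_embd_start[0] );  // e.g. 9000 -> 552
}
```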