diff --git a/vall_e.cpp/README.md b/vall_e.cpp/README.md index 89bcdaa..cd8611e 100644 --- a/vall_e.cpp/README.md +++ b/vall_e.cpp/README.md @@ -15,7 +15,7 @@ Probably something like: * [x] converted model to GGUF * [ ] convert it without modifying any of the existing code * [x] basic framework - * [x] load the model + * [x] load the quantized model * [x] orchestrate the required embeddings * [x] juggle the output head / classifier properly * [ ] phonemize text @@ -23,8 +23,8 @@ Probably something like: * [ ] load audio from disk * [ ] encode audio * [ ] sum embeddings for the `prom` and prior `resp`s +* [x] `AR` sampling * [ ] `NAR-len` demasking sampling * [ ] `NAR` sampling * [ ] decode audio to disk -* [ ] a functional CLI -* [ ] quantize the model (properly) \ No newline at end of file +* [ ] a functional CLI \ No newline at end of file diff --git a/vall_e.cpp/vall_e.cpp b/vall_e.cpp/vall_e.cpp index 19efd8e..bb8e540 100644 --- a/vall_e.cpp/vall_e.cpp +++ b/vall_e.cpp/vall_e.cpp @@ -168,7 +168,7 @@ int main(int argc, char ** argv) { std::vector> response_tokens = { {922,395,869,869,354,989,762,762,762,610,975,626,626,866,609,442,762,762,762,610,610,610,610,212,869,869,51,336,352,352,352,570,148,893,76,535,568,568,270,568,568,560,597,86,744,744,744,203,738,408,1019,700,707,92,707,464,744,171,171,159,196,192,697,261,261,568,638,605,904,904,779,832,570,519,223,459,459,459,459,90,90,570,700,53,372,621,610,869,473,869,917,654,473,917,893,654,644,384,558,911,864,521,1,19,665}, }; - std::string model_path = "./vall_e/Vall_E-238M-F16.gguf"; + std::string model_path = "./vall_e/Vall_E-238M-Q8_0.gguf"; // load dynamic backends ggml_backend_load_all(); @@ -207,48 +207,82 @@ int main(int argc, char ** argv) { // prepare batch auto n_embd = llama_n_embd( model ); auto n_vocab = llama_n_vocab( model ); - // float* embd = (float*) llama_get_embedding_weights( model )->data; - float* embds = (float*) (model->tok_embd->data); + llama_batch batch = llama_batch_init( ctx_params.n_ctx, n_embd, ctx_params.n_ctx ); + + // grab input embeddings + std::vector embds( n_embd * n_vocab ); + auto* qtype = ggml_get_type_traits(model->tok_embd->type); + // dequantize if needed + if ( ggml_is_quantized(model->tok_embd->type) ) { + qtype->to_float(model->tok_embd->data, embds.data(), embds.size()); + } // to-do: derive these offsets from the tokenizer itself - float* text_embds = embds + (0 * n_embd); // - float* rvq_level_embd = embds + (17666 * n_embd); // <|RVQ:0> - float* len_embd = embds + (17674 * n_embd); // <|len:0|> - float* lang_embd = embds + (17686 * n_embd); // <|lang:en|> - float* task_embd = embds + (17692 * n_embd); // <|task:tts|> - float* sep_embd = embds + (17685 * n_embd); // <|sep|> + // to-do: clean this up, probably make it at parity to inputs_to_embeddings + int text_embd_start = 0; // + int rvq_level_embd_start = 17666; // <|RVQ:0> + int len_embd_start = 17674; // <|len:0|> + int lang_embd_start = 17686; // <|lang:en|> + int task_embd_start = 17692; // <|task:tts|> + int sep_embd_start = 17685; // <|sep|> + int prom_embd_start[] = { + 256 + (1024 * 0), // <|P|0:0|> + 256 + (1024 * 1), // <|P|1:0|> + 256 + (1024 * 2), // <|P|2:0|> + 256 + (1024 * 3), // <|P|3:0|> + 256 + (1024 * 4), // <|P|4:0|> + 256 + (1024 * 5), // <|P|5:0|> + 256 + (1024 * 6), // <|P|6:0|> + 256 + (1024 * 7), // <|P|7:0|> + }; + int resp_embd_start[] = { + 8448, // <|AR|0:0|> + 9473, // <|NAR|0:0|> + 10498 + (1024 * 0), // <|NAR|0:1|> + 10498 + (1024 * 1), // <|NAR|1:2|> + 10498 + (1024 * 2), // <|NAR|2:3|> + 10498 + (1024 * 3), // <|NAR|3:4|> + 10498 + (1024 * 4), // <|NAR|4:5|> + 10498 + (1024 * 5), // <|NAR|5:6|> + 10498 + (1024 * 6), // <|NAR|6:7|> + }; + + float* text_embds = &embds[text_embd_start * n_embd]; + float* rvq_level_embd = &embds[rvq_level_embd_start * n_embd]; + float* len_embd = &embds[len_embd_start * n_embd]; + float* lang_embd = &embds[lang_embd_start * n_embd]; + float* task_embd = &embds[task_embd_start * n_embd]; + float* sep_embd = &embds[sep_embd_start * n_embd]; float* prom_embds[] = { - embds + (256 + (1024 * 0) * n_embd), // <|P|0:0|> - embds + (256 + (1024 * 1) * n_embd), // <|P|1:0|> - embds + (256 + (1024 * 2) * n_embd), // <|P|2:0|> - embds + (256 + (1024 * 3) * n_embd), // <|P|3:0|> - embds + (256 + (1024 * 4) * n_embd), // <|P|4:0|> - embds + (256 + (1024 * 5) * n_embd), // <|P|5:0|> - embds + (256 + (1024 * 6) * n_embd), // <|P|6:0|> - embds + (256 + (1024 * 7) * n_embd), // <|P|7:0|> + &embds[prom_embd_start[0] * n_embd], + &embds[prom_embd_start[1] * n_embd], + &embds[prom_embd_start[2] * n_embd], + &embds[prom_embd_start[3] * n_embd], + &embds[prom_embd_start[4] * n_embd], + &embds[prom_embd_start[5] * n_embd], + &embds[prom_embd_start[6] * n_embd], + &embds[prom_embd_start[7] * n_embd], }; float* resps_embds[] = { - embds + (8448 * n_embd), // <|AR|0:0|> - embds + (9473 * n_embd), // <|NAR|0:0|> - embds + (10498 + (1024 * 0) * n_embd), // <|NAR|0:1|> - embds + (10498 + (1024 * 1) * n_embd), // <|NAR|1:2|> - embds + (10498 + (1024 * 2) * n_embd), // <|NAR|2:3|> - embds + (10498 + (1024 * 3) * n_embd), // <|NAR|3:4|> - embds + (10498 + (1024 * 4) * n_embd), // <|NAR|4:5|> - embds + (10498 + (1024 * 5) * n_embd), // <|NAR|5:6|> - embds + (10498 + (1024 * 6) * n_embd), // <|NAR|6:7|> + &embds[resp_embd_start[0] * n_embd], + &embds[resp_embd_start[1] * n_embd], + &embds[resp_embd_start[2] * n_embd], + &embds[resp_embd_start[3] * n_embd], + &embds[resp_embd_start[4] * n_embd], + &embds[resp_embd_start[5] * n_embd], + &embds[resp_embd_start[6] * n_embd], + &embds[resp_embd_start[7] * n_embd], + &embds[resp_embd_start[8] * n_embd], }; - - llama_batch batch = llama_batch_init( ctx_params.n_ctx, n_embd, ctx_params.n_ctx ); + + // insert into batch { // keeps track of the position for each sequence size_t pos = 0; // insert text tokens - for ( auto& id : phoneme_tokens ) { - batch_add( batch, id, n_embd, text_embds, pos++, false ); - } + for ( auto& id : phoneme_tokens ) batch_add( batch, id, n_embd, text_embds, pos++, false ); batch_add( batch, 0, n_embd, sep_embd, pos++, false ); pos = 0; // insert lang token @@ -262,18 +296,14 @@ int main(int argc, char ** argv) { // insert prom tokens // to-do: handle summing for ( auto l = 0; l < prompt_tokens.size(); ++l ) { - for ( auto& id : prompt_tokens[l] ) { - batch_add( batch, id, n_embd, prom_embds[l], pos++, false ); - } + for ( auto& id : prompt_tokens[l] ) batch_add( batch, id, n_embd, prom_embds[l], pos++, false ); } batch_add( batch, 0, n_embd, sep_embd, pos++, is_ar ); pos = 0; // fill in masked tokens if ( !is_ar ) { - for ( auto i = 0; i < response_tokens[0].size(); ++i ) { - batch_add( batch, response_tokens[0][i], n_embd, resps_embds[1], pos++, true ); - } + for ( auto i = 0; i < response_tokens[0].size(); ++i ) batch_add( batch, response_tokens[0][i], n_embd, resps_embds[1], pos++, true ); } pos = 0; } @@ -293,7 +323,7 @@ int main(int argc, char ** argv) { // align to AR's classifier // to-do: derive from tokenizer - int range[] = { 8448, 8448 + 1024 }; // { <|AR|0:0|>, <|AR|0:STOP|> } + int range[] = { resp_embd_start[0], resp_embd_start[1] }; auto* logits = llama_get_logits_ith( ctx, -1 ); for ( auto i = 0; i < n_vocab; ++i ) { if ( i < range[0] || i >= range[1] ) { @@ -305,7 +335,7 @@ int main(int argc, char ** argv) { auto t = llama_sampler_sample(smpl, ctx, -1); // is stop token - if ( t == 9472 ) { // <|AR|0:STOP|> + if ( t == resp_embd_start[1] - 1 ) { // <|AR|0:STOP|> break; }