mrq 2024-12-21 11:56:22 -06:00
parent 5788db849b
commit 979c1f797c
2 changed files with 72 additions and 42 deletions

View File

@@ -15,7 +15,7 @@ Probably something like:
 * [x] converted model to GGUF
 * [ ] convert it without modifying any of the existing code
 * [x] basic framework
-* [x] load the model
+* [x] load the quantized model
 * [x] orchestrate the required embeddings
 * [x] juggle the output head / classifier properly
 * [ ] phonemize text
@@ -23,8 +23,8 @@ Probably something like:
 * [ ] load audio from disk
 * [ ] encode audio
 * [ ] sum embeddings for the `prom` and prior `resp`s
-* [x] `AR` sampling
 * [ ] `NAR-len` demasking sampling
 * [ ] `NAR` sampling
 * [ ] decode audio to disk
 * [ ] a functional CLI
+* [ ] quantize the model (properly)
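
The unchecked `sum embeddings` item refers to how the `prom` (and any prior `resp`) is fed to the model: each RVQ level has its own embedding table, and the per-level vectors for a frame are summed into a single input embedding rather than concatenated. A minimal sketch of that summation, with hypothetical names (nothing below is code from this repo):

```cpp
#include <cstddef>
#include <vector>

// sum the per-level embeddings of a prompt into one vector per audio frame
// prom_tokens: [level][frame] RVQ codes; prom_embds: per-level embedding tables
std::vector<float> sum_prom_embeddings(
	const std::vector<std::vector<int>>& prom_tokens,
	float** prom_embds,
	int n_embd
) {
	size_t n_frames = prom_tokens[0].size();
	std::vector<float> summed( n_frames * n_embd, 0.0f );
	for ( size_t l = 0; l < prom_tokens.size(); ++l ) {
		for ( size_t i = 0; i < n_frames; ++i ) {
			const float* row = prom_embds[l] + prom_tokens[l][i] * n_embd;
			for ( int d = 0; d < n_embd; ++d ) summed[i * n_embd + d] += row[d];
		}
	}
	return summed;
}
```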

View File

@@ -168,7 +168,7 @@ int main(int argc, char ** argv) {
 	std::vector<std::vector<llama_token>> response_tokens = {
 		{922,395,869,869,354,989,762,762,762,610,975,626,626,866,609,442,762,762,762,610,610,610,610,212,869,869,51,336,352,352,352,570,148,893,76,535,568,568,270,568,568,560,597,86,744,744,744,203,738,408,1019,700,707,92,707,464,744,171,171,159,196,192,697,261,261,568,638,605,904,904,779,832,570,519,223,459,459,459,459,90,90,570,700,53,372,621,610,869,473,869,917,654,473,917,893,654,644,384,558,911,864,521,1,19,665},
 	};
-	std::string model_path = "./vall_e/Vall_E-238M-F16.gguf";
+	std::string model_path = "./vall_e/Vall_E-238M-Q8_0.gguf";

 	// load dynamic backends
 	ggml_backend_load_all();
@@ -207,48 +207,82 @@ int main(int argc, char ** argv) {
 	// prepare batch
 	auto n_embd = llama_n_embd( model );
 	auto n_vocab = llama_n_vocab( model );
-	// float* embd = (float*) llama_get_embedding_weights( model )->data;
-	float* embds = (float*) (model->tok_embd->data);
+	llama_batch batch = llama_batch_init( ctx_params.n_ctx, n_embd, ctx_params.n_ctx );
+
+	// grab input embeddings
+	std::vector<float> embds( n_embd * n_vocab );
+	auto* qtype = ggml_get_type_traits(model->tok_embd->type);
+	// dequantize if needed
+	if ( ggml_is_quantized(model->tok_embd->type) ) {
+		qtype->to_float(model->tok_embd->data, embds.data(), embds.size());
+	}

 	// to-do: derive these offsets from the tokenizer itself
-	float* text_embds = embds + (0 * n_embd); // <bos>
-	float* rvq_level_embd = embds + (17666 * n_embd); // <|RVQ:0>
-	float* len_embd = embds + (17674 * n_embd); // <|len:0|>
-	float* lang_embd = embds + (17686 * n_embd); // <|lang:en|>
-	float* task_embd = embds + (17692 * n_embd); // <|task:tts|>
-	float* sep_embd = embds + (17685 * n_embd); // <|sep|>
+	// to-do: clean this up, probably make it at parity to inputs_to_embeddings
+	int text_embd_start = 0; // <unk>
+	int rvq_level_embd_start = 17666; // <|RVQ:0>
+	int len_embd_start = 17674; // <|len:0|>
+	int lang_embd_start = 17686; // <|lang:en|>
+	int task_embd_start = 17692; // <|task:tts|>
+	int sep_embd_start = 17685; // <|sep|>
+	int prom_embd_start[] = {
+		256 + (1024 * 0), // <|P|0:0|>
+		256 + (1024 * 1), // <|P|1:0|>
+		256 + (1024 * 2), // <|P|2:0|>
+		256 + (1024 * 3), // <|P|3:0|>
+		256 + (1024 * 4), // <|P|4:0|>
+		256 + (1024 * 5), // <|P|5:0|>
+		256 + (1024 * 6), // <|P|6:0|>
+		256 + (1024 * 7), // <|P|7:0|>
+	};
+	int resp_embd_start[] = {
+		8448, // <|AR|0:0|>
+		9473, // <|NAR|0:0|>
+		10498 + (1024 * 0), // <|NAR|0:1|>
+		10498 + (1024 * 1), // <|NAR|1:2|>
+		10498 + (1024 * 2), // <|NAR|2:3|>
+		10498 + (1024 * 3), // <|NAR|3:4|>
+		10498 + (1024 * 4), // <|NAR|4:5|>
+		10498 + (1024 * 5), // <|NAR|5:6|>
+		10498 + (1024 * 6), // <|NAR|6:7|>
+	};
+	float* text_embds = &embds[text_embd_start * n_embd];
+	float* rvq_level_embd = &embds[rvq_level_embd_start * n_embd];
+	float* len_embd = &embds[len_embd_start * n_embd];
+	float* lang_embd = &embds[lang_embd_start * n_embd];
+	float* task_embd = &embds[task_embd_start * n_embd];
+	float* sep_embd = &embds[sep_embd_start * n_embd];
 	float* prom_embds[] = {
-		embds + (256 + (1024 * 0) * n_embd), // <|P|0:0|>
-		embds + (256 + (1024 * 1) * n_embd), // <|P|1:0|>
-		embds + (256 + (1024 * 2) * n_embd), // <|P|2:0|>
-		embds + (256 + (1024 * 3) * n_embd), // <|P|3:0|>
-		embds + (256 + (1024 * 4) * n_embd), // <|P|4:0|>
-		embds + (256 + (1024 * 5) * n_embd), // <|P|5:0|>
-		embds + (256 + (1024 * 6) * n_embd), // <|P|6:0|>
-		embds + (256 + (1024 * 7) * n_embd), // <|P|7:0|>
+		&embds[prom_embd_start[0] * n_embd],
+		&embds[prom_embd_start[1] * n_embd],
+		&embds[prom_embd_start[2] * n_embd],
+		&embds[prom_embd_start[3] * n_embd],
+		&embds[prom_embd_start[4] * n_embd],
+		&embds[prom_embd_start[5] * n_embd],
+		&embds[prom_embd_start[6] * n_embd],
+		&embds[prom_embd_start[7] * n_embd],
 	};
 	float* resps_embds[] = {
-		embds + (8448 * n_embd), // <|AR|0:0|>
-		embds + (9473 * n_embd), // <|NAR|0:0|>
-		embds + (10498 + (1024 * 0) * n_embd), // <|NAR|0:1|>
-		embds + (10498 + (1024 * 1) * n_embd), // <|NAR|1:2|>
-		embds + (10498 + (1024 * 2) * n_embd), // <|NAR|2:3|>
-		embds + (10498 + (1024 * 3) * n_embd), // <|NAR|3:4|>
-		embds + (10498 + (1024 * 4) * n_embd), // <|NAR|4:5|>
-		embds + (10498 + (1024 * 5) * n_embd), // <|NAR|5:6|>
-		embds + (10498 + (1024 * 6) * n_embd), // <|NAR|6:7|>
+		&embds[resp_embd_start[0] * n_embd],
+		&embds[resp_embd_start[1] * n_embd],
+		&embds[resp_embd_start[2] * n_embd],
+		&embds[resp_embd_start[3] * n_embd],
+		&embds[resp_embd_start[4] * n_embd],
+		&embds[resp_embd_start[5] * n_embd],
+		&embds[resp_embd_start[6] * n_embd],
+		&embds[resp_embd_start[7] * n_embd],
+		&embds[resp_embd_start[8] * n_embd],
 	};
-	llama_batch batch = llama_batch_init( ctx_params.n_ctx, n_embd, ctx_params.n_ctx );
+
+	// insert into batch
 	{
 		// keeps track of the position for each sequence
 		size_t pos = 0;
 		// insert text tokens
-		for ( auto& id : phoneme_tokens ) {
-			batch_add( batch, id, n_embd, text_embds, pos++, false );
-		}
+		for ( auto& id : phoneme_tokens ) batch_add( batch, id, n_embd, text_embds, pos++, false );
 		batch_add( batch, 0, n_embd, sep_embd, pos++, false );
 		pos = 0;
 		// insert lang token
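
The new one-liner loops above all go through a `batch_add` helper that this diff doesn't show. A plausible reconstruction, assuming it mirrors llama.cpp's `common_batch_add` but copies a row out of the given embedding table instead of storing a token id (the repo's actual helper may differ):

```cpp
#include "llama.h"

// hypothetical reconstruction: append one token to an embeddings-mode batch
// (a llama_batch created with embd != 0 carries float vectors, not token ids)
void batch_add( llama_batch& batch, llama_token id, int n_embd, const float* embds, llama_pos pos, bool output ) {
	// copy the token's embedding row into the batch
	for ( int i = 0; i < n_embd; ++i )
		batch.embd[batch.n_tokens * n_embd + i] = embds[id * n_embd + i];
	batch.pos[batch.n_tokens] = pos;       // absolute position within the sequence
	batch.n_seq_id[batch.n_tokens] = 1;    // single sequence
	batch.seq_id[batch.n_tokens][0] = 0;
	batch.logits[batch.n_tokens] = output; // whether to emit logits for this token
	batch.n_tokens++;
}
```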
@@ -262,18 +296,14 @@ int main(int argc, char ** argv) {
 		// insert prom tokens
 		// to-do: handle summing
 		for ( auto l = 0; l < prompt_tokens.size(); ++l ) {
-			for ( auto& id : prompt_tokens[l] ) {
-				batch_add( batch, id, n_embd, prom_embds[l], pos++, false );
-			}
+			for ( auto& id : prompt_tokens[l] ) batch_add( batch, id, n_embd, prom_embds[l], pos++, false );
 		}
 		batch_add( batch, 0, n_embd, sep_embd, pos++, is_ar );
 		pos = 0;
 		// fill in masked tokens
 		if ( !is_ar ) {
-			for ( auto i = 0; i < response_tokens[0].size(); ++i ) {
-				batch_add( batch, response_tokens[0][i], n_embd, resps_embds[1], pos++, true );
-			}
+			for ( auto i = 0; i < response_tokens[0].size(); ++i ) batch_add( batch, response_tokens[0][i], n_embd, resps_embds[1], pos++, true );
 		}
 		pos = 0;
 	}
@@ -293,7 +323,7 @@ int main(int argc, char ** argv) {
 	// align to AR's classifier
 	// to-do: derive from tokenizer
-	int range[] = { 8448, 8448 + 1024 }; // { <|AR|0:0|>, <|AR|0:STOP|> }
+	int range[] = { resp_embd_start[0], resp_embd_start[1] };
 	auto* logits = llama_get_logits_ith( ctx, -1 );
 	for ( auto i = 0; i < n_vocab; ++i ) {
 		if ( i < range[0] || i >= range[1] ) {
@@ -305,7 +335,7 @@ int main(int argc, char ** argv) {
 	auto t = llama_sampler_sample(smpl, ctx, -1);
 	// is stop token
-	if ( t == 9472 ) { // <|AR|0:STOP|>
+	if ( t == resp_embd_start[1] - 1 ) { // <|AR|0:STOP|>
 		break;
 	}
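
For reference, the core trick of this commit condensed into a standalone form: the token-embedding tensor of a quantized GGUF can't be indexed as `float*` directly, so it is dequantized once into a float buffer via ggml's type traits, after which per-token rows can be sliced by offset exactly as before. A sketch under those assumptions (error handling omitted; an F16 tensor would need `ggml_fp16_to_fp32_row` rather than a plain copy):

```cpp
#include <vector>
#include "ggml.h"

// dequantize a [n_embd, n_vocab] embedding tensor into a flat float buffer
std::vector<float> dequantize_embeddings( const ggml_tensor* tok_embd ) {
	std::vector<float> out( ggml_nelements( tok_embd ) );
	if ( ggml_is_quantized( tok_embd->type ) ) {
		// each quantized type exposes a row-wise dequantizer through its type traits
		auto* traits = ggml_get_type_traits( tok_embd->type );
		traits->to_float( tok_embd->data, out.data(), out.size() );
	} else if ( tok_embd->type == GGML_TYPE_F32 ) {
		const float* src = (const float*) tok_embd->data;
		out.assign( src, src + out.size() );
	}
	return out;
}
```

Keeping the whole table dequantized at load trades memory for simplicity, which is presumably why the README now tracks quantizing the model "(properly)" as follow-up work.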