quant
This commit is contained in:
parent
5788db849b
commit
979c1f797c
|
@ -15,7 +15,7 @@ Probably something like:
|
|||
* [x] converted model to GGUF
|
||||
* [ ] convert it without modifying any of the existing code
|
||||
* [x] basic framework
|
||||
* [x] load the model
|
||||
* [x] load the quantized model
|
||||
* [x] orchestrate the required embeddings
|
||||
* [x] juggle the output head / classifier properly
|
||||
* [ ] phonemize text
|
||||
|
@ -23,8 +23,8 @@ Probably something like:
|
|||
* [ ] load audio from disk
|
||||
* [ ] encode audio
|
||||
* [ ] sum embeddings for the `prom` and prior `resp`s
|
||||
* [x] `AR` sampling
|
||||
* [ ] `NAR-len` demasking sampling
|
||||
* [ ] `NAR` sampling
|
||||
* [ ] decode audio to disk
|
||||
* [ ] a functional CLI
|
||||
* [ ] quantize the model (properly)
|
||||
* [ ] a functional CLI
|
|
@ -168,7 +168,7 @@ int main(int argc, char ** argv) {
|
|||
std::vector<std::vector<llama_token>> response_tokens = {
|
||||
{922,395,869,869,354,989,762,762,762,610,975,626,626,866,609,442,762,762,762,610,610,610,610,212,869,869,51,336,352,352,352,570,148,893,76,535,568,568,270,568,568,560,597,86,744,744,744,203,738,408,1019,700,707,92,707,464,744,171,171,159,196,192,697,261,261,568,638,605,904,904,779,832,570,519,223,459,459,459,459,90,90,570,700,53,372,621,610,869,473,869,917,654,473,917,893,654,644,384,558,911,864,521,1,19,665},
|
||||
};
|
||||
std::string model_path = "./vall_e/Vall_E-238M-F16.gguf";
|
||||
std::string model_path = "./vall_e/Vall_E-238M-Q8_0.gguf";
|
||||
|
||||
// load dynamic backends
|
||||
ggml_backend_load_all();
|
||||
|
@ -207,48 +207,82 @@ int main(int argc, char ** argv) {
|
|||
// prepare batch
|
||||
auto n_embd = llama_n_embd( model );
|
||||
auto n_vocab = llama_n_vocab( model );
|
||||
// float* embd = (float*) llama_get_embedding_weights( model )->data;
|
||||
float* embds = (float*) (model->tok_embd->data);
|
||||
llama_batch batch = llama_batch_init( ctx_params.n_ctx, n_embd, ctx_params.n_ctx );
|
||||
|
||||
// grab input embeddings
|
||||
std::vector<float> embds( n_embd * n_vocab );
|
||||
auto* qtype = ggml_get_type_traits(model->tok_embd->type);
|
||||
// dequantize if needed
|
||||
if ( ggml_is_quantized(model->tok_embd->type) ) {
|
||||
qtype->to_float(model->tok_embd->data, embds.data(), embds.size());
|
||||
}
|
||||
|
||||
// to-do: derive these offsets from the tokenizer itself
|
||||
float* text_embds = embds + (0 * n_embd); // <bos>
|
||||
float* rvq_level_embd = embds + (17666 * n_embd); // <|RVQ:0>
|
||||
float* len_embd = embds + (17674 * n_embd); // <|len:0|>
|
||||
float* lang_embd = embds + (17686 * n_embd); // <|lang:en|>
|
||||
float* task_embd = embds + (17692 * n_embd); // <|task:tts|>
|
||||
float* sep_embd = embds + (17685 * n_embd); // <|sep|>
|
||||
// to-do: clean this up, probably make it at parity to inputs_to_embeddings
|
||||
int text_embd_start = 0; // <unk>
|
||||
int rvq_level_embd_start = 17666; // <|RVQ:0>
|
||||
int len_embd_start = 17674; // <|len:0|>
|
||||
int lang_embd_start = 17686; // <|lang:en|>
|
||||
int task_embd_start = 17692; // <|task:tts|>
|
||||
int sep_embd_start = 17685; // <|sep|>
|
||||
int prom_embd_start[] = {
|
||||
256 + (1024 * 0), // <|P|0:0|>
|
||||
256 + (1024 * 1), // <|P|1:0|>
|
||||
256 + (1024 * 2), // <|P|2:0|>
|
||||
256 + (1024 * 3), // <|P|3:0|>
|
||||
256 + (1024 * 4), // <|P|4:0|>
|
||||
256 + (1024 * 5), // <|P|5:0|>
|
||||
256 + (1024 * 6), // <|P|6:0|>
|
||||
256 + (1024 * 7), // <|P|7:0|>
|
||||
};
|
||||
int resp_embd_start[] = {
|
||||
8448, // <|AR|0:0|>
|
||||
9473, // <|NAR|0:0|>
|
||||
10498 + (1024 * 0), // <|NAR|0:1|>
|
||||
10498 + (1024 * 1), // <|NAR|1:2|>
|
||||
10498 + (1024 * 2), // <|NAR|2:3|>
|
||||
10498 + (1024 * 3), // <|NAR|3:4|>
|
||||
10498 + (1024 * 4), // <|NAR|4:5|>
|
||||
10498 + (1024 * 5), // <|NAR|5:6|>
|
||||
10498 + (1024 * 6), // <|NAR|6:7|>
|
||||
};
|
||||
|
||||
float* text_embds = &embds[text_embd_start * n_embd];
|
||||
float* rvq_level_embd = &embds[rvq_level_embd_start * n_embd];
|
||||
float* len_embd = &embds[len_embd_start * n_embd];
|
||||
float* lang_embd = &embds[lang_embd_start * n_embd];
|
||||
float* task_embd = &embds[task_embd_start * n_embd];
|
||||
float* sep_embd = &embds[sep_embd_start * n_embd];
|
||||
|
||||
float* prom_embds[] = {
|
||||
embds + (256 + (1024 * 0) * n_embd), // <|P|0:0|>
|
||||
embds + (256 + (1024 * 1) * n_embd), // <|P|1:0|>
|
||||
embds + (256 + (1024 * 2) * n_embd), // <|P|2:0|>
|
||||
embds + (256 + (1024 * 3) * n_embd), // <|P|3:0|>
|
||||
embds + (256 + (1024 * 4) * n_embd), // <|P|4:0|>
|
||||
embds + (256 + (1024 * 5) * n_embd), // <|P|5:0|>
|
||||
embds + (256 + (1024 * 6) * n_embd), // <|P|6:0|>
|
||||
embds + (256 + (1024 * 7) * n_embd), // <|P|7:0|>
|
||||
&embds[prom_embd_start[0] * n_embd],
|
||||
&embds[prom_embd_start[1] * n_embd],
|
||||
&embds[prom_embd_start[2] * n_embd],
|
||||
&embds[prom_embd_start[3] * n_embd],
|
||||
&embds[prom_embd_start[4] * n_embd],
|
||||
&embds[prom_embd_start[5] * n_embd],
|
||||
&embds[prom_embd_start[6] * n_embd],
|
||||
&embds[prom_embd_start[7] * n_embd],
|
||||
};
|
||||
float* resps_embds[] = {
|
||||
embds + (8448 * n_embd), // <|AR|0:0|>
|
||||
embds + (9473 * n_embd), // <|NAR|0:0|>
|
||||
embds + (10498 + (1024 * 0) * n_embd), // <|NAR|0:1|>
|
||||
embds + (10498 + (1024 * 1) * n_embd), // <|NAR|1:2|>
|
||||
embds + (10498 + (1024 * 2) * n_embd), // <|NAR|2:3|>
|
||||
embds + (10498 + (1024 * 3) * n_embd), // <|NAR|3:4|>
|
||||
embds + (10498 + (1024 * 4) * n_embd), // <|NAR|4:5|>
|
||||
embds + (10498 + (1024 * 5) * n_embd), // <|NAR|5:6|>
|
||||
embds + (10498 + (1024 * 6) * n_embd), // <|NAR|6:7|>
|
||||
&embds[resp_embd_start[0] * n_embd],
|
||||
&embds[resp_embd_start[1] * n_embd],
|
||||
&embds[resp_embd_start[2] * n_embd],
|
||||
&embds[resp_embd_start[3] * n_embd],
|
||||
&embds[resp_embd_start[4] * n_embd],
|
||||
&embds[resp_embd_start[5] * n_embd],
|
||||
&embds[resp_embd_start[6] * n_embd],
|
||||
&embds[resp_embd_start[7] * n_embd],
|
||||
&embds[resp_embd_start[8] * n_embd],
|
||||
};
|
||||
|
||||
llama_batch batch = llama_batch_init( ctx_params.n_ctx, n_embd, ctx_params.n_ctx );
|
||||
|
||||
// insert into batch
|
||||
{
|
||||
// keeps track of the position for each sequence
|
||||
size_t pos = 0;
|
||||
|
||||
// insert text tokens
|
||||
for ( auto& id : phoneme_tokens ) {
|
||||
batch_add( batch, id, n_embd, text_embds, pos++, false );
|
||||
}
|
||||
for ( auto& id : phoneme_tokens ) batch_add( batch, id, n_embd, text_embds, pos++, false );
|
||||
batch_add( batch, 0, n_embd, sep_embd, pos++, false );
|
||||
pos = 0;
|
||||
// insert lang token
|
||||
|
@ -262,18 +296,14 @@ int main(int argc, char ** argv) {
|
|||
// insert prom tokens
|
||||
// to-do: handle summing
|
||||
for ( auto l = 0; l < prompt_tokens.size(); ++l ) {
|
||||
for ( auto& id : prompt_tokens[l] ) {
|
||||
batch_add( batch, id, n_embd, prom_embds[l], pos++, false );
|
||||
}
|
||||
for ( auto& id : prompt_tokens[l] ) batch_add( batch, id, n_embd, prom_embds[l], pos++, false );
|
||||
}
|
||||
batch_add( batch, 0, n_embd, sep_embd, pos++, is_ar );
|
||||
pos = 0;
|
||||
|
||||
// fill in masked tokens
|
||||
if ( !is_ar ) {
|
||||
for ( auto i = 0; i < response_tokens[0].size(); ++i ) {
|
||||
batch_add( batch, response_tokens[0][i], n_embd, resps_embds[1], pos++, true );
|
||||
}
|
||||
for ( auto i = 0; i < response_tokens[0].size(); ++i ) batch_add( batch, response_tokens[0][i], n_embd, resps_embds[1], pos++, true );
|
||||
}
|
||||
pos = 0;
|
||||
}
|
||||
|
@ -293,7 +323,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// align to AR's classifier
|
||||
// to-do: derive from tokenizer
|
||||
int range[] = { 8448, 8448 + 1024 }; // { <|AR|0:0|>, <|AR|0:STOP|> }
|
||||
int range[] = { resp_embd_start[0], resp_embd_start[1] };
|
||||
auto* logits = llama_get_logits_ith( ctx, -1 );
|
||||
for ( auto i = 0; i < n_vocab; ++i ) {
|
||||
if ( i < range[0] || i >= range[1] ) {
|
||||
|
@ -305,7 +335,7 @@ int main(int argc, char ** argv) {
|
|||
auto t = llama_sampler_sample(smpl, ctx, -1);
|
||||
|
||||
// is stop token
|
||||
if ( t == 9472 ) { // <|AR|0:STOP|>
|
||||
if ( t == resp_embd_start[1] - 1 ) { // <|AR|0:STOP|>
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user