commit 2542ed067d (parent 70a0f5724b)

    ugh
@@ -17,20 +17,22 @@ Run `make`.
 ## To-Do
 
 * [x] converted model to GGUF
-  * [ ] convert it without modifying any of the existing code
+  * [ ] convert it without modifying any of the existing code, as the tokenizer requires some care
 * [x] basic framework
 * [x] load the quantized model
 * [x] orchestrate the required embeddings
 * [x] juggle the output head / classifier properly
 * [ ] phonemize text
+  * with the help of espeak-ng
 * [ ] tokenize phonemes
+  * the tokenizer is being a huge thorn on actual sequences
 * [x] load audio from disk
 * [x] encode audio
 * [x] sum embeddings for the `prom` and prior `resp`s
 * [x] `AR` sampling
 * [ ] `NAR-len` demasking sampling
-* [ ] `NAR` sampling
-* [ ] decode audio to disk
+* [x] `NAR` sampling
+* [x] decode audio to disk
 * [ ] a functional CLI
 * [ ] actually make it work
-  * it seems naively stitching the model together isn't good enough since the output is wrong
+  * it seems naively stitching the model together isn't good enough since the output is wrong; it most likely needs training with a glued-together classifier
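For the "phonemize text with the help of espeak-ng" item above: a minimal sketch of what that step could look like with espeak-ng's C API. The voice name, the IPA mode flag, and the overall wiring are assumptions for illustration, not code from this repo.

```cpp
#include <cstdio>
#include <espeak-ng/speak_lib.h>

int main() {
    // initialize without audio playback; returns the sample rate, or a negative value on error
    if ( espeak_Initialize( AUDIO_OUTPUT_SYNCHRONOUS, 0, NULL, 0 ) < 0 ) return 1;
    espeak_SetVoiceByName( "en" ); // assumed voice; the repo's language handling may differ

    const char* text = "hello world";
    const void* ptr = text; // espeak advances this pointer past the converted clause
    // textmode: UTF-8 input; phonememode bit 1 set = IPA output as UTF-8
    const char* ipa = espeak_TextToPhonemes( &ptr, espeakCHARS_UTF8, 0x02 );
    printf( "%s\n", ipa ); // something like "həlˈəʊ wˈɜːld"
    return 0;
}
```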
@@ -19,7 +19,8 @@
 struct input_t {
     std::string task = "tts";
 
-    std::vector<llama_token> phonemes = {};
+    std::string phonemes = "";
+    std::vector<llama_token> phn = {};
     llama_token lang = 0;
     llama_token rvq_l = 0;
     std::vector<std::vector<llama_token>> prom = {};
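A note on the hunk above: the raw IPA string (`phonemes`) and the pre-tokenized ids (`phn`) now coexist, presumably so the hard-coded token sequence keeps working as a fallback while tokenizing `phonemes` through the GGUF vocab is still broken (see the tokenizer to-dos above and the commented-out assignment further down).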
@@ -297,7 +298,7 @@ void fill_batch( llama_batch& batch, input_t& input, embeddings_t& embeddings_ma
     auto n_embd = embeddings_map.n_embd;
 
     // insert text tokens
-    for ( auto& id : input.phonemes ) batch_add( batch, id, n_embd, embeddings_map.text_embds, pos++, false );
+    for ( auto& id : input.phn ) batch_add( batch, id, n_embd, embeddings_map.text_embds, pos++, false );
     batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
     pos = 0;
     // insert lang token
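The loop above feeds one embedding per token through a `batch_add` helper; note that `pos` restarts at 0 after the separator, so each input segment apparently gets its own position space. The helper itself is not shown in this hunk; below is a hypothetical reconstruction of its shape, assuming the batch was created with `llama_batch_init( n_tokens_max, n_embd, 1 )` so embeddings are submitted instead of token ids.

```cpp
#include <cstring>
#include "llama.h"

// Hypothetical batch_add: look up the token's row in a caller-provided
// embedding table and append it to a llama_batch allocated with a non-zero
// embd size (so batch.embd is populated rather than batch.token).
void batch_add( llama_batch& batch, llama_token id, int n_embd, const float* embds, llama_pos pos, bool output ) {
    // copy the embedding row for `id` into the next slot of the batch
    memcpy( batch.embd + batch.n_tokens * n_embd, embds + id * n_embd, n_embd * sizeof(float) );
    batch.pos[batch.n_tokens]       = pos;
    batch.n_seq_id[batch.n_tokens]  = 1;      // single sequence
    batch.seq_id[batch.n_tokens][0] = 0;
    batch.logits[batch.n_tokens]    = output; // whether to request logits for this position
    batch.n_tokens++;
}
```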
@@ -350,7 +351,7 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
         if ( batch.logits[i] ) ++n_logits;
     }
 
-    if ( verbose ) printf("Prompt size: %i | Logits: %i\n", batch.n_tokens, n_logits);
+    if ( verbose ) printf("Prompt size: %i | Outputs: %i\n", batch.n_tokens, n_logits);
 
     // NAR mode, cap at one step
     if ( n_logits > 1 ) {
@@ -379,8 +380,8 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
 
         stop_token = embeddings_map.resp_embd_start[2] - 1; // <|NAR|0:STOP|>
     } else if ( mode == INFERENCE_MODE_NAR ) {
-        logit_range[0] = embeddings_map.resp_embd_start[2+rvq_l];
-        logit_range[1] = embeddings_map.resp_embd_start[3+rvq_l];
+        logit_range[0] = embeddings_map.resp_embd_start[2+rvq_l-1];
+        logit_range[1] = embeddings_map.resp_embd_start[3+rvq_l-1];
 
         embds = embeddings_map.resps_embds[2];
     } else if ( mode == INFERENCE_MODE_LEN ) {
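For context on the `-1` fix above: judging from the surrounding code, `resp_embd_start` holds the start offset of each response level's slice of the output head. For NAR quantizer level `rvq_l`, the corrected logit window is `[resp_embd_start[2+rvq_l-1], resp_embd_start[3+rvq_l-1])`; without the `-1`, every NAR step read the slice belonging to the level above it and so sampled the wrong codebook's tokens.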
@@ -460,7 +461,8 @@ int main(int argc, char ** argv) {
     input_t input{};
     embeddings_t embeddings_map{};
 
-    input.phonemes = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; // <bos>hˈɛloː ʋˈɔrlt</eos>
+    // input.phonemes = "hˈɛloː ʋˈɔrlt";
+    input.phn = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; // <bos>hˈɛloː ʋˈɔrlt</eos>
 
     std::string vall_e_model_path = "./data/vall_e-F16.gguf";
     std::string encodec_model_path = "./data/encodec.bin";
@@ -535,6 +537,21 @@ int main(int argc, char ** argv) {
     // update mapping
     embeddings_map.init( n_embd, n_vocab, embds.data() );
 
+    // tokenize phonemes
+    // to-do: make this work, the vocab does not work
+    if ( input.phonemes != "" ) {
+        const int n_prompt = -llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), NULL, 0, true, true);
+        // allocate space for the tokens and tokenize the input.phonemes
+        input.phn.resize(n_prompt);
+        if (llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), input.phn.data(), input.phn.size(), true, true) < 0) {
+            fprintf(stderr, "%s: error: failed to tokenize: %s\n", __func__, input.phonemes.c_str());
+            return 1;
+        }
+
+        for ( auto& token : input.phn ) printf("%i ", token );
+        printf("\n");
+    }
+
     // inference
     std::vector<llama_token> output_tokens;
     // NAR-len demasking
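The block added above uses llama.cpp's two-pass tokenization idiom: a first `llama_tokenize` call with a null output buffer returns the negated token count, and a second call fills the allocated buffer. A self-contained sketch of that idiom as a helper (the function name and wrapping are mine, not the repo's; the flags mirror the commit's own calls):

```cpp
#include <string>
#include <vector>
#include "llama.h"

// Tokenize `text` with the model's vocab, returning an empty vector on failure.
static std::vector<llama_token> tokenize_text( llama_model* model, const std::string& text ) {
    // pass 1: NULL buffer; the return value is the negated required token count
    const int n_tokens = -llama_tokenize( model, text.c_str(), text.size(), NULL, 0, true, true );
    if ( n_tokens <= 0 ) return {};

    std::vector<llama_token> tokens( n_tokens );
    // pass 2: fill the buffer; a negative return signals an error
    if ( llama_tokenize( model, text.c_str(), text.size(), tokens.data(), tokens.size(), true, true ) < 0 ) return {};
    return tokens;
}
```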