mrq 2024-12-21 19:59:56 -06:00
parent 70a0f5724b
commit 2542ed067d
2 changed files with 29 additions and 10 deletions

View File

@@ -17,20 +17,22 @@ Run `make`.
## To-Do
* [x] converted model to GGUF
* [ ] convert it without modifying any of the existing code
* [ ] convert it without modifying any of the existing code, as the tokenizer requires some care
* [x] basic framework
* [x] load the quantized model
* [x] orchestrate the required embeddings
* [x] juggle the output head / classifier properly
* [ ] phonemize text
* with the help of espeak-ng (see the sketch after this list)
* [ ] tokenize phonemes
* the tokenizer is proving to be a huge thorn with actual sequences
* [x] load audio from disk
* [x] encode audio
* [x] sum embeddings for the `prom` and prior `resp`s
* [x] `AR` sampling
* [ ] `NAR-len` demasking sampling
* [ ] `NAR` sampling
* [ ] decode audio to disk
* [x] `NAR` sampling
* [x] decode audio to disk
* [ ] a functional CLI
* [ ] actually make it work
* it seems naively stitching the model together isn't good enough since the output is wrong
* it seems naively stitching the model together isn't good enough, since the output is wrong; it most likely needs training with a glued-together classifier
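
Since the phonemize step is still open, here is a minimal sketch of what it could look like through espeak-ng's C API; the voice name, the IPA flag, and the clause-by-clause loop are assumptions, and the resulting string still has to survive the tokenizer:

```cpp
#include <espeak-ng/speak_lib.h>

#include <string>

// hedged sketch: convert text to an IPA phoneme string via espeak-ng;
// the phoneme-mode flags the model's tokenizer expects are assumed here
std::string phonemize( const std::string& text, const char* voice = "en-us" ) {
	espeak_Initialize( AUDIO_OUTPUT_RETRIEVAL, 0, NULL, 0 );
	espeak_SetVoiceByName( voice );
	std::string phonemes;
	const void* ptr = text.c_str();
	// espeak_TextToPhonemes consumes one clause per call, advancing ptr and
	// setting it to NULL once the whole input has been processed
	while ( ptr != NULL ) {
		const char* clause = espeak_TextToPhonemes( &ptr, espeakCHARS_UTF8, espeakPHONEMES_IPA );
		if ( clause ) phonemes += clause;
	}
	espeak_Terminate();
	return phonemes;
}
```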

View File

@@ -19,7 +19,8 @@
struct input_t {
std::string task = "tts";
std::vector<llama_token> phonemes = {};
std::string phonemes = "";
std::vector<llama_token> phn = {};
llama_token lang = 0;
llama_token rvq_l = 0;
std::vector<std::vector<llama_token>> prom = {};
@@ -297,7 +298,7 @@ void fill_batch( llama_batch& batch, input_t& input, embeddings_t& embeddings_ma
auto n_embd = embeddings_map.n_embd;
// insert text tokens
for ( auto& id : input.phonemes ) batch_add( batch, id, n_embd, embeddings_map.text_embds, pos++, false );
for ( auto& id : input.phn ) batch_add( batch, id, n_embd, embeddings_map.text_embds, pos++, false );
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
pos = 0;
// insert lang token
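
For context, `batch_add` here presumably bypasses token ids entirely and writes the token's row from the given embedding table straight into the batch, since each input stream (text, lang, prom, resp) uses its own table; a hedged sketch under that assumption:

```cpp
#include <cstring>

#include <llama.h>

// hedged sketch: append one position to a llama_batch by copying the token's
// row from the given embedding table into the batch's embd buffer
void batch_add( llama_batch& batch, llama_token id, int n_embd,
                const float* embds, llama_pos pos, bool output ) {
	memcpy( batch.embd + batch.n_tokens * n_embd, embds + id * n_embd, n_embd * sizeof(float) );
	batch.pos[batch.n_tokens]       = pos;
	batch.n_seq_id[batch.n_tokens]  = 1;
	batch.seq_id[batch.n_tokens][0] = 0;      // single sequence
	batch.logits[batch.n_tokens]    = output; // whether to request logits at this position
	batch.n_tokens++;
}
```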
@@ -350,7 +351,7 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
if ( batch.logits[i] ) ++n_logits;
}
if ( verbose ) printf("Prompt size: %i | Logits: %i\n", batch.n_tokens, n_logits);
if ( verbose ) printf("Prompt size: %i | Outputs: %i\n", batch.n_tokens, n_logits);
// NAR mode, cap at one step
if ( n_logits > 1 ) {
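
The output count is what separates the two decoding regimes: the NAR requests logits at every response position and resolves them in one forward pass, while the AR only requests the last position and loops. A short sketch of that dispatch, where `max_duration` is a hypothetical name:

```cpp
// hedged sketch: more than one requested output implies NAR-style decoding,
// which finishes in a single evaluation; the AR instead appends one token
// per step until it samples its stop token
const bool is_nar    = n_logits > 1;
const int  max_steps = is_nar ? 1 : max_duration; // max_duration is hypothetical
```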
@@ -379,8 +380,8 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
stop_token = embeddings_map.resp_embd_start[2] - 1; // <|NAR|0:STOP|>
} else if ( mode == INFERENCE_MODE_NAR ) {
logit_range[0] = embeddings_map.resp_embd_start[2+rvq_l];
logit_range[1] = embeddings_map.resp_embd_start[3+rvq_l];
logit_range[0] = embeddings_map.resp_embd_start[2+rvq_l-1];
logit_range[1] = embeddings_map.resp_embd_start[3+rvq_l-1];
embds = embeddings_map.resps_embds[2];
} else if ( mode == INFERENCE_MODE_LEN ) {
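
The hunk above shifts the NAR's vocab slice down by one codebook's worth of entries, presumably because the NAR's first level is `rvq_l = 1`. Sampling then has to stay inside that slice; a minimal greedy sketch, with names assumed:

```cpp
#include <llama.h>

// hedged sketch: greedy-pick a token restricted to the [start, end) vocab
// slice that belongs to the current RVQ level's codebook
llama_token sample_in_range( const float* logits, int start, int end ) {
	int best = start;
	for ( int i = start + 1; i < end; ++i ) {
		if ( logits[i] > logits[best] ) best = i;
	}
	return best; // subtract `start` to recover the raw codebook entry
}

// usage: sample_in_range( llama_get_logits_ith( ctx, i ), logit_range[0], logit_range[1] );
```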
@@ -460,7 +461,8 @@ int main(int argc, char ** argv) {
input_t input{};
embeddings_t embeddings_map{};
input.phonemes = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; // <bos>hˈɛloː ʋˈɔrlt</eos>
// input.phonemes = "hˈɛloː ʋˈɔrlt";
input.phn = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; // <bos>hˈɛloː ʋˈɔrlt</eos>
std::string vall_e_model_path = "./data/vall_e-F16.gguf";
std::string encodec_model_path = "./data/encodec.bin";
@@ -535,6 +537,21 @@ int main(int argc, char ** argv) {
// update mapping
embeddings_map.init( n_embd, n_vocab, embds.data() );
// tokenize phonemes
// to-do: make this work; the vocab does not tokenize correctly yet
if ( !input.phonemes.empty() ) {
// a negative return value reports how many tokens the text requires
const int n_prompt = -llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), NULL, 0, true, true);
// allocate space for the tokens and tokenize the input.phonemes
input.phn.resize(n_prompt);
if (llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), input.phn.data(), input.phn.size(), true, true) < 0) {
fprintf(stderr, "%s: error: failed to tokenize: %s\n", __func__, input.phonemes.c_str());
return 1;
}
for ( auto& token : input.phn ) printf("%i ", token );
printf("\n");
}
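// hypothetical sanity check: once the vocab works, tokenizing "hˈɛloː ʋˈɔrlt"
// should reproduce the hard-coded sequence used above
std::vector<llama_token> expected = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2};
if ( input.phn != expected ) fprintf(stderr, "tokenizer sanity check failed\n");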
// inference
std::vector<llama_token> output_tokens;
// NAR-len demasking