From 70a0f5724b0fc4c931f027def6337757fe7e70e4 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 21 Dec 2024 19:40:19 -0600 Subject: [PATCH] i hate learning APIs so much --- vall_e.cpp/vall_e.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vall_e.cpp/vall_e.cpp b/vall_e.cpp/vall_e.cpp index 96007d2..90ee430 100644 --- a/vall_e.cpp/vall_e.cpp +++ b/vall_e.cpp/vall_e.cpp @@ -260,7 +260,7 @@ const int INFERENCE_MODE_NAR = 4; const int MODALITY_AR_NAR = 0; const int MODALITY_NAR_LEN = 0; -const int MAX_DURATION = 75 * 12; +const int MAX_DURATION = 75; // * 12; // sums embeddings over a 2D "tensor" std::vector> sum_embeddings( const std::vector>& input, int n_embd, int rvq_l, float** embds, int mode = EMBEDDING_MODE_PROM ) { @@ -417,7 +417,8 @@ std::vector generate( llama_context* ctx, llama_model* model, llama if ( verbose ) { // print token for debugging char buf[256]; - llama_token_to_piece( model, t, buf, sizeof(buf), 0, true ); + int n = llama_token_to_piece( model, t, buf, sizeof(buf), 0, true ); + if ( n < 256 ) buf[n] = '\0'; printf("%s\n", buf ); } @@ -459,9 +460,9 @@ int main(int argc, char ** argv) { input_t input{}; embeddings_t embeddings_map{}; - input.phonemes = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; + input.phonemes = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; // hˈɛloː ʋˈɔrlt - std::string vall_e_model_path = "./data/vall_e-q8_0.gguf"; + std::string vall_e_model_path = "./data/vall_e-F16.gguf"; std::string encodec_model_path = "./data/encodec.bin"; std::string input_prompt_path = "./data/prom.wav"; std::string output_response_path = "./data/resp.wav"; @@ -496,6 +497,7 @@ int main(int argc, char ** argv) { llama_context_params ctx_params = llama_context_default_params(); ctx_params.n_ctx = 22500; ctx_params.n_batch = 22500; + ctx_params.n_ubatch = 22500; ctx_params.no_perf = false; ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; @@ -548,7 +550,7 @@ int main(int argc, char ** argv) { } } // cap for now - if ( len > MAX_DURATION ) len = MAX_DURATION; + if ( len <= 0 || len > MAX_DURATION ) len = MAX_DURATION; // fill with mask tokens input.resp.resize(1);