i hate learning APIs so much

mrq 2024-12-21 19:40:19 -06:00
parent 1b4a69ce29
commit 70a0f5724b


@@ -260,7 +260,7 @@ const int INFERENCE_MODE_NAR = 4;
 const int MODALITY_AR_NAR = 0;
 const int MODALITY_NAR_LEN = 0;
-const int MAX_DURATION = 75 * 12;
+const int MAX_DURATION = 75; // * 12;
 // sums embeddings over a 2D "tensor"
 std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, float** embds, int mode = EMBEDDING_MODE_PROM ) {
@@ -417,7 +417,8 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
 	if ( verbose ) {
 		// print token for debugging
 		char buf[256];
-		llama_token_to_piece( model, t, buf, sizeof(buf), 0, true );
+		int n = llama_token_to_piece( model, t, buf, sizeof(buf), 0, true );
+		if ( n < 256 ) buf[n] = '\0';
 		printf("%s\n", buf );
 	}
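Note: the extra bookkeeping above is needed because llama_token_to_piece does not null-terminate buf; it returns the number of bytes written, or a negative value (the required size) when the buffer is too small. A minimal helper sketch along those lines, assuming the same llama.cpp API used in this file (the helper name token_to_piece is hypothetical, not part of this commit):

#include <string>
#include "llama.h"

// Convert a single token to a std::string. llama_token_to_piece writes raw
// bytes without a terminator and reports the count (negative = buffer too small).
static std::string token_to_piece( const llama_model* model, llama_token t ) {
	std::string piece(8, '\0');
	int n = llama_token_to_piece( model, t, piece.data(), piece.size(), 0, true );
	if ( n < 0 ) {
		piece.resize(-n); // -n is the size actually required; retry once
		n = llama_token_to_piece( model, t, piece.data(), piece.size(), 0, true );
	}
	piece.resize(n > 0 ? n : 0);
	return piece;
}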
@@ -459,9 +460,9 @@ int main(int argc, char ** argv) {
 	input_t input{};
 	embeddings_t embeddings_map{};
-	input.phonemes = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2};
-	std::string vall_e_model_path = "./data/vall_e-q8_0.gguf";
+	input.phonemes = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; // <bos>hˈɛloː ʋˈɔrlt</eos>
+	std::string vall_e_model_path = "./data/vall_e-F16.gguf";
 	std::string encodec_model_path = "./data/encodec.bin";
 	std::string input_prompt_path = "./data/prom.wav";
 	std::string output_response_path = "./data/resp.wav";
@@ -496,6 +497,7 @@ int main(int argc, char ** argv) {
 	llama_context_params ctx_params = llama_context_default_params();
 	ctx_params.n_ctx = 22500;
 	ctx_params.n_batch = 22500;
+	ctx_params.n_ubatch = 22500;
 	ctx_params.no_perf = false;
 	ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL;
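Note: in llama.cpp, n_batch is the logical batch size accepted per llama_decode call, while n_ubatch is the physical micro-batch the backend actually runs (default 512); when a whole sequence has to be processed in one pass, keeping n_ubatch at least as large as n_batch avoids the batch being split. A minimal sketch under that assumption (the helper name make_ctx is hypothetical, not part of this commit):

#include "llama.h"

// Build a context sized so one llama_decode call can process n_tokens at once.
llama_context* make_ctx( llama_model* model, int n_tokens ) {
	llama_context_params p = llama_context_default_params();
	p.n_ctx    = n_tokens; // KV cache / context length
	p.n_batch  = n_tokens; // logical tokens submitted per llama_decode
	p.n_ubatch = n_tokens; // physical micro-batch; keep >= n_batch so the pass isn't split
	return llama_new_context_with_model( model, p );
}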
@@ -548,7 +550,7 @@ int main(int argc, char ** argv) {
 		}
 	}
 	// cap for now
-	if ( len > MAX_DURATION ) len = MAX_DURATION;
+	if ( len <= 0 || len > MAX_DURATION ) len = MAX_DURATION;
 	// fill with mask tokens
 	input.resp.resize(1);