From 25a02f2c3fecaa964bef6903d96c093f023849a4 Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 25 Dec 2024 00:36:19 -0600 Subject: [PATCH] oops --- vall_e.cpp/vall_e.cpp | 6 +++--- vall_e.cpp/vall_e.h | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vall_e.cpp/vall_e.cpp b/vall_e.cpp/vall_e.cpp index 4389c52..81f8336 100644 --- a/vall_e.cpp/vall_e.cpp +++ b/vall_e.cpp/vall_e.cpp @@ -601,7 +601,7 @@ std::vector VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t is_masked[idx] = true; } - if ( verbose ) print_tokens( output_tokens, "Masked tokens: " ); + if ( verbose ) print_tokens( output_tokens, std::string("[")+std::to_string(step)+"/"+std::to_string(steps)+"] Masked tokens: " ); // update batch // to-do: only update the embeddings instead @@ -666,7 +666,7 @@ std::vector VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t llama_sampler_free(smpl); - if ( verbose ) print_tokens( output_tokens ); + if ( verbose ) print_tokens( output_tokens, std::string("[")+std::to_string(step)+"/"+std::to_string(steps)+"]: " ); } } else if ( mode == INFERENCE_MODE_NAR ) { // to-do: assert n_outputs == inputs.resp[rvq_l-1].size() @@ -840,7 +840,7 @@ bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_ } else if (arg == "-ms" || arg == "--max-steps") { args.max_steps = std::stoi(argv[++i]); } else if (arg == "-md" || arg == "--max-duration") { - args.max_duration = std::stoi(argv[++i]); + args.max_duration = std::stoi(argv[++i]) * ENCODEC_FRAMES_PER_SECOND; } else if (arg == "-i" || arg == "--input") { args.prompt_path = argv[++i]; } else if (arg == "-o" || arg == "--output") { diff --git a/vall_e.cpp/vall_e.h b/vall_e.cpp/vall_e.h index 55403f2..5fae805 100644 --- a/vall_e.cpp/vall_e.h +++ b/vall_e.cpp/vall_e.h @@ -34,7 +34,8 @@ const int INFERENCE_MODE_NAR = 3; const int MODALITY_AR_NAR = 0; const int MODALITY_NAR_LEN = 1; -const int MAX_DURATION = 75 * 12; +const int ENCODEC_FRAMES_PER_SECOND = 75; +const int MAX_DURATION = ENCODEC_FRAMES_PER_SECOND * 12; const int CTX_SIZE = 2048; const int N_THREADS = 8; const int N_GPU_LAYERS = 99; @@ -100,7 +101,7 @@ struct vall_e_args_t { std::string language = "en"; int modality = MODALITY_NAR_LEN; int max_steps = 30; - int max_duration = 75 * 12; + int max_duration = MAX_DURATION; }; // stores everything needed for vall_e.cpp struct vall_e_context_t {