This commit is contained in:
mrq 2024-12-25 00:36:19 -06:00
parent b9d2cd5513
commit 25a02f2c3f
2 changed files with 6 additions and 5 deletions

View File

@@ -601,7 +601,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
is_masked[idx] = true; is_masked[idx] = true;
} }
if ( verbose ) print_tokens( output_tokens, "Masked tokens: " ); if ( verbose ) print_tokens( output_tokens, std::string("[")+std::to_string(step)+"/"+std::to_string(steps)+"] Masked tokens: " );
// update batch // update batch
// to-do: only update the embeddings instead // to-do: only update the embeddings instead
@@ -666,7 +666,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
llama_sampler_free(smpl); llama_sampler_free(smpl);
if ( verbose ) print_tokens( output_tokens ); if ( verbose ) print_tokens( output_tokens, std::string("[")+std::to_string(step)+"/"+std::to_string(steps)+"]: " );
} }
} else if ( mode == INFERENCE_MODE_NAR ) { } else if ( mode == INFERENCE_MODE_NAR ) {
// to-do: assert n_outputs == inputs.resp[rvq_l-1].size() // to-do: assert n_outputs == inputs.resp[rvq_l-1].size()
@@ -840,7 +840,7 @@ bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_
} else if (arg == "-ms" || arg == "--max-steps") { } else if (arg == "-ms" || arg == "--max-steps") {
args.max_steps = std::stoi(argv[++i]); args.max_steps = std::stoi(argv[++i]);
} else if (arg == "-md" || arg == "--max-duration") { } else if (arg == "-md" || arg == "--max-duration") {
args.max_duration = std::stoi(argv[++i]); args.max_duration = std::stoi(argv[++i]) * ENCODEC_FRAMES_PER_SECOND;
} else if (arg == "-i" || arg == "--input") { } else if (arg == "-i" || arg == "--input") {
args.prompt_path = argv[++i]; args.prompt_path = argv[++i];
} else if (arg == "-o" || arg == "--output") { } else if (arg == "-o" || arg == "--output") {

View File

@@ -34,7 +34,8 @@ const int INFERENCE_MODE_NAR = 3;
const int MODALITY_AR_NAR = 0; const int MODALITY_AR_NAR = 0;
const int MODALITY_NAR_LEN = 1; const int MODALITY_NAR_LEN = 1;
const int MAX_DURATION = 75 * 12; const int ENCODEC_FRAMES_PER_SECOND = 75;
const int MAX_DURATION = ENCODEC_FRAMES_PER_SECOND * 12;
const int CTX_SIZE = 2048; const int CTX_SIZE = 2048;
const int N_THREADS = 8; const int N_THREADS = 8;
const int N_GPU_LAYERS = 99; const int N_GPU_LAYERS = 99;
@@ -100,7 +101,7 @@ struct vall_e_args_t {
std::string language = "en"; std::string language = "en";
int modality = MODALITY_NAR_LEN; int modality = MODALITY_NAR_LEN;
int max_steps = 30; int max_steps = 30;
int max_duration = 75 * 12; int max_duration = MAX_DURATION;
}; };
// stores everything needed for vall_e.cpp // stores everything needed for vall_e.cpp
struct vall_e_context_t { struct vall_e_context_t {