oops
This commit is contained in:
parent
b9d2cd5513
commit
25a02f2c3f
|
@ -601,7 +601,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
|
||||||
is_masked[idx] = true;
|
is_masked[idx] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if ( verbose ) print_tokens( output_tokens, "Masked tokens: " );
|
if ( verbose ) print_tokens( output_tokens, std::string("[")+std::to_string(step)+"/"+std::to_string(steps)+"] Masked tokens: " );
|
||||||
|
|
||||||
// update batch
|
// update batch
|
||||||
// to-do: only update the embeddings instead
|
// to-do: only update the embeddings instead
|
||||||
|
@ -666,7 +666,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
|
||||||
|
|
||||||
llama_sampler_free(smpl);
|
llama_sampler_free(smpl);
|
||||||
|
|
||||||
if ( verbose ) print_tokens( output_tokens );
|
if ( verbose ) print_tokens( output_tokens, std::string("[")+std::to_string(step)+"/"+std::to_string(steps)+"]: " );
|
||||||
}
|
}
|
||||||
} else if ( mode == INFERENCE_MODE_NAR ) {
|
} else if ( mode == INFERENCE_MODE_NAR ) {
|
||||||
// to-do: assert n_outputs == inputs.resp[rvq_l-1].size()
|
// to-do: assert n_outputs == inputs.resp[rvq_l-1].size()
|
||||||
|
@ -840,7 +840,7 @@ bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_
|
||||||
} else if (arg == "-ms" || arg == "--max-steps") {
|
} else if (arg == "-ms" || arg == "--max-steps") {
|
||||||
args.max_steps = std::stoi(argv[++i]);
|
args.max_steps = std::stoi(argv[++i]);
|
||||||
} else if (arg == "-md" || arg == "--max-duration") {
|
} else if (arg == "-md" || arg == "--max-duration") {
|
||||||
args.max_duration = std::stoi(argv[++i]);
|
args.max_duration = std::stoi(argv[++i]) * ENCODEC_FRAMES_PER_SECOND;
|
||||||
} else if (arg == "-i" || arg == "--input") {
|
} else if (arg == "-i" || arg == "--input") {
|
||||||
args.prompt_path = argv[++i];
|
args.prompt_path = argv[++i];
|
||||||
} else if (arg == "-o" || arg == "--output") {
|
} else if (arg == "-o" || arg == "--output") {
|
||||||
|
|
|
@ -34,7 +34,8 @@ const int INFERENCE_MODE_NAR = 3;
|
||||||
const int MODALITY_AR_NAR = 0;
|
const int MODALITY_AR_NAR = 0;
|
||||||
const int MODALITY_NAR_LEN = 1;
|
const int MODALITY_NAR_LEN = 1;
|
||||||
|
|
||||||
const int MAX_DURATION = 75 * 12;
|
const int ENCODEC_FRAMES_PER_SECOND = 75;
|
||||||
|
const int MAX_DURATION = ENCODEC_FRAMES_PER_SECOND * 12;
|
||||||
const int CTX_SIZE = 2048;
|
const int CTX_SIZE = 2048;
|
||||||
const int N_THREADS = 8;
|
const int N_THREADS = 8;
|
||||||
const int N_GPU_LAYERS = 99;
|
const int N_GPU_LAYERS = 99;
|
||||||
|
@ -100,7 +101,7 @@ struct vall_e_args_t {
|
||||||
std::string language = "en";
|
std::string language = "en";
|
||||||
int modality = MODALITY_NAR_LEN;
|
int modality = MODALITY_NAR_LEN;
|
||||||
int max_steps = 30;
|
int max_steps = 30;
|
||||||
int max_duration = 75 * 12;
|
int max_duration = MAX_DURATION;
|
||||||
};
|
};
|
||||||
// stores everything needed for vall_e.cpp
|
// stores everything needed for vall_e.cpp
|
||||||
struct vall_e_context_t {
|
struct vall_e_context_t {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user