vall_e.cpp cli

This commit is contained in:
mrq 2024-12-25 00:28:34 -06:00
parent 59f56ad099
commit b9d2cd5513
3 changed files with 117 additions and 37 deletions

View File

@@ -42,9 +42,9 @@ Run `make`.
 * [x] working `NAR` output
 * [x] `NAR` sampling
 * [x] decode audio to disk
-* [ ] a functional CLI
+* [x] a functional CLI
 * [x] actually make it work
-* [ ] clean up to make the code usable
+* [x] clean up to make the code usable elsewhere
 * [ ] feature parity with the PyTorch version
 * [ ] vocos
 * [ ] additional tasks (`stt`, `ns`, `sr`, samplers)

View File

@@ -122,7 +122,7 @@ void VALL_E_API vall_e_inputs_map_init( io_map_t& io_map, llama_model* model ) {
 	io_map.n_embd = n_embd;
 	io_map.n_vocab = n_vocab;
 
-	int32_t ctx_size = 24 * 2 * ggml_tensor_overhead(); // 24 embeddings + 24 output heads (generous) (should only really need to do this for output heads since we manually handle embeddings)
+	size_t ctx_size = 24 * 2 * ggml_tensor_overhead(); // 24 embeddings + 24 output heads (generous) (should only really need to do this for output heads since we manually handle embeddings)
 	struct ggml_init_params params = {
 		/*.mem_size =*/ ctx_size,
 		/*.mem_buffer =*/ NULL,
@@ -546,7 +546,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
 	const token_t MASK_TOKEN = 1024; // token value for masking
 	const float PI = 3.141592653589793f;
 	// to-do: derive from sampling arguments
-	int32_t steps = 10; // number of demasking steps
+	int32_t steps = max_tokens;
 	int32_t seq_len = n_outputs;
 	float temperature = 1.5f;
 	float cfg_strength = 2.5f;
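The number of demasking steps now comes from `max_tokens` (the caller's `--max-steps`) rather than a hard-coded 10. The schedule itself sits outside this hunk; purely as an illustration of why `PI` is defined here, a MaskGIT-style scheduler keeps a cosine-shaped fraction of the sequence masked at each step. The helper below is a hypothetical sketch under that assumption, not code from this commit.

#include <cmath>

// Hypothetical sketch (assumed MaskGIT-style cosine schedule, not taken from this commit):
// how many of seq_len positions stay masked after demasking step s (0-based) of `steps`.
int masked_after_step( int s, int steps, int seq_len ) {
	const float PI = 3.141592653589793f;
	float t = float(s + 1) / float(steps);                              // progress in (0, 1]
	int remaining = (int) std::ceil( seq_len * std::cos( t * PI * 0.5f ) );
	return remaining > 0 ? remaining : 0;                               // e.g. steps=10, seq_len=100: 99, 96, ..., 0
}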
@@ -717,15 +717,6 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
 	return output_tokens;
 }
 
-std::string string_replace( const std::string& string, const std::string& search, const std::string& replace ) {
-	std::string res = string;
-	size_t start_pos;
-	while ( (start_pos = res.find(search)) != std::string::npos ) {
-		res.replace(start_pos, search.length(), replace);
-	}
-	return res;
-}
-
 std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::string& text, const std::string& language ) {
 	std::vector<token_t> tokens;
@@ -791,6 +782,83 @@ std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::str
 	return tokens;
 }
 
+void VALL_E_API vall_e_print_usage( char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
+	fprintf(stderr, "usage: %s [options]\n", argv[0]);
+	fprintf(stderr, "\n");
+	fprintf(stderr, "options:\n");
+	fprintf(stderr, " -h, --help Show this help message and exit\n");
+	fprintf(stderr, " -t N, --threads N\n");
+	fprintf(stderr, "    Number of threads to use during computation (default: %d)\n", params.n_threads);
+	fprintf(stderr, " -ngl N, --n-gpu-layers N\n");
+	fprintf(stderr, "    Number of layers to offload to the GPU (default: %d)\n", params.gpu_layers);
+	fprintf(stderr, " -ctx N, --context-size N\n");
+	fprintf(stderr, "    Max context size (default: %d)\n", params.ctx_size);
+	fprintf(stderr, " -v, --verbose\n");
+	fprintf(stderr, "    Verbose output (default: %d)\n", params.verbose);
+	fprintf(stderr, " -m FNAME, --model FNAME\n");
+	fprintf(stderr, "    VALL-E model path (default: %s)\n", params.model_path.c_str());
+	fprintf(stderr, " -em FNAME, --encodec-model FNAME\n");
+	fprintf(stderr, "    Encodec model path (default: %s)\n", params.encodec_path.c_str());
+	fprintf(stderr, " -t TEXT, --text TEXT\n");
+	fprintf(stderr, "    Input text prompt (default: %s)\n", args.text.c_str());
+	fprintf(stderr, " -l TEXT, --language TEXT\n");
+	fprintf(stderr, "    Language for input text / output response (default: %s)\n", args.language.c_str());
+	fprintf(stderr, " -mode MODE, --modality MODE\n");
+	fprintf(stderr, "    Modality for inferencing (default: %s, accepts ['ar+nar', 'nar-len'])\n", args.modality == MODALITY_NAR_LEN ? "nar-len" : "ar+nar");
+	fprintf(stderr, " -ms N, --max-steps N\n");
+	fprintf(stderr, "    Max steps for `nar-len` (default: %i)\n", args.max_steps);
+	fprintf(stderr, " -md N, --max-duration N\n");
+	fprintf(stderr, "    Max duration of the audio (default: %i)\n", args.max_duration);
+	fprintf(stderr, " -i FNAME, --input FNAME\n");
+	fprintf(stderr, "    Input prompt wav (default: %s)\n", args.prompt_path.c_str());
+	fprintf(stderr, " -o FNAME, --output FNAME\n");
+	fprintf(stderr, "    Output audio wav (default: %s)\n", args.output_path.c_str());
+	fprintf(stderr, "\n");
+}
+
+bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
+	for ( int i = 1; i < argc; i++ ) {
+		std::string arg = argv[i];
+		if (arg == "-t" || arg == "--threads") {
+			params.n_threads = std::stoi(argv[++i]);
+		} else if (arg == "-ngl" || arg == "--n-gpu-layers") {
+			params.gpu_layers = std::stoi(argv[++i]);
+		} else if (arg == "-ctx" || arg == "--context-size") {
+			params.ctx_size = std::stoi(argv[++i]);
+		} else if (arg == "-v" || arg == "--verbose") {
+			params.verbose = true;
+		} else if (arg == "-m" || arg == "--model") {
+			params.model_path = argv[++i];
+		} else if (arg == "-em" || arg == "--encodec-model") {
+			params.encodec_path = argv[++i];
+		} else if (arg == "-t" || arg == "--text") {
+			args.text = argv[++i];
+		} else if (arg == "-l" || arg == "--language") {
+			args.language = argv[++i];
+		} else if (arg == "-mode" || arg == "--modality") {
+			args.modality = std::string(argv[++i]) == "ar+nar" ? MODALITY_AR_NAR : MODALITY_NAR_LEN;
} else if (arg == "-ms" || arg == "--max-steps") {
args.max_steps = std::stoi(argv[++i]);
} else if (arg == "-md" || arg == "--max-duration") {
args.max_duration = std::stoi(argv[++i]);
} else if (arg == "-i" || arg == "--input") {
args.prompt_path = argv[++i];
} else if (arg == "-o" || arg == "--output") {
args.output_path = argv[++i];
} else if (arg == "-h" || arg == "--help") {
vall_e_print_usage(argv, params, args);
exit(0);
return false;
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
vall_e_print_usage(argv, params, args);
exit(0);
return false;
}
}
return true;
}
vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params ) { vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params ) {
vall_e_context_t* ctx = new vall_e_context_t(); vall_e_context_t* ctx = new vall_e_context_t();
ctx->params = params; ctx->params = params;
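Two caveats with the parser as committed: every value-taking flag reads `argv[++i]` without checking that a value actually follows, and `-t` is claimed by both `--threads` and `--text`, so the text option is only reachable through its long form. A minimal hardened variant is sketched below; the `next_arg` helper and the `-txt` short flag are hypothetical, not part of this commit.

#include <cstdio>
#include <cstdlib>
#include <string>

// Hypothetical helper (not from this commit): fetch a flag's value with a bounds check
// instead of reading argv[++i] blindly.
static const char* next_arg( int argc, char** argv, int& i, const std::string& flag ) {
	if ( i + 1 >= argc ) {
		fprintf(stderr, "error: missing value for %s\n", flag.c_str());
		exit(1);
	}
	return argv[++i];
}

// Inside the parse loop it would be used as, for example:
//   if (arg == "-t" || arg == "--threads") {
//       params.n_threads = std::stoi( next_arg( argc, argv, i, arg ) );
//   } else if (arg == "-txt" || arg == "--text") {   // hypothetical short flag; "-t" already means --threads
//       args.text = next_arg( argc, argv, i, arg );
//   }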
@@ -813,8 +881,8 @@ vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params
 	ctx_params.n_ctx = params.ctx_size;
 	ctx_params.n_batch = params.ctx_size;
 	ctx_params.n_ubatch = params.ctx_size;
-	ctx_params.n_threads = params.cpu_threads;
-	ctx_params.n_threads_batch = params.cpu_threads;
+	ctx_params.n_threads = params.n_threads;
+	ctx_params.n_threads_batch = params.n_threads;
 	ctx_params.no_perf = false;
 	ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL;
@@ -868,7 +936,7 @@ vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string&
 	return inputs;
 }
 
 // to-do: provide sampling params
-vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int modality ) {
+vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int max_steps, int max_duration, int modality ) {
 	// NAR-len demasking
 	std::vector<token_t> output_tokens;
 	if ( modality == MODALITY_NAR_LEN ) {
@@ -876,7 +944,7 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
 		int len = 0;
 		if ( !len ) {
 			inputs.task = "len";
-			output_tokens = generate( ctx, inputs, 5, INFERENCE_MODE_LEN );
+			output_tokens = generate( ctx, inputs, 5, INFERENCE_MODE_LEN, ctx->params.verbose );
 			{
 				int digit = 1;
 				for (auto it = output_tokens.rbegin(); it < output_tokens.rend(); ++it) {
@ -885,7 +953,7 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
} }
} }
// cap for now // cap for now
if ( len <= 0 || len > MAX_DURATION ) len = MAX_DURATION; if ( len <= 0 || len > max_duration ) len = max_duration;
} }
// fill with mask tokens // fill with mask tokens
inputs.resp.resize(1); inputs.resp.resize(1);
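For reference on the two hunks above: the `len` task returns the requested duration as a sequence of digit tokens, and the reversed loop (whose body the diff elides) evidently accumulates them least-significant digit first, given `int digit = 1` and the reverse iterators. A worked sketch of that decoding, stated as an assumption rather than code from the commit:

#include <vector>

// Hypothetical sketch of the elided digit-decoding loop.
// Example: output_tokens = {2, 5, 0}  ->  len = 0*1 + 5*10 + 2*100 = 250 frames.
int decode_len( const std::vector<int>& output_tokens ) {
	int len = 0;
	int digit = 1;
	for ( auto it = output_tokens.rbegin(); it != output_tokens.rend(); ++it ) {
		len += (*it) * digit;
		digit *= 10;
	}
	return len;
}

The decoded value is then clamped to `max_duration`, as the hunk above shows.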
@@ -897,7 +965,7 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
 		inputs.task = "tts";
 		for ( auto l = 0; l < 8; ++l ) {
 			inputs.rvq_l = l;
-			output_tokens = generate( ctx, inputs, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
+			output_tokens = generate( ctx, inputs, max_steps, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR, ctx->params.verbose );
 			if ( l == 0 ) inputs.resp.clear();
 			inputs.resp.emplace_back( output_tokens );
 		}
@@ -906,7 +974,7 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
 		inputs.task = "tts";
 		for ( auto l = 0; l < 8; ++l ) {
 			inputs.rvq_l = l;
-			output_tokens = generate( ctx, inputs, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
+			output_tokens = generate( ctx, inputs, l == 0 ? max_duration : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR, ctx->params.verbose );
 			inputs.resp.emplace_back( output_tokens );
 		}
 	}
@@ -926,21 +994,22 @@ int main( int argc, char** argv ) {
 	// to-do: parse CLI args
 	vall_e_context_params_t params;
-	params.model_path = "./data/vall_e.gguf";
-	params.encodec_path = "./data/encodec.bin";
-	params.gpu_layers = N_GPU_LAYERS;
-	params.cpu_threads = N_THREADS;
+	vall_e_args_t args;
+	if ( !vall_e_args_parse( argc, argv, params, args ) ) {
+		fprintf(stderr, "%s: failed to parse arguments\n", __func__);
+		return 1;
+	}
 
 	vall_e_context_t* ctx = vall_e_load( params );
+	if ( !ctx || !ctx->llama.model || !ctx->llama.ctx || !ctx->encodec.ctx ) {
+		fprintf(stderr, "%s: failed to initialize vall_e.cpp\n", __func__);
+		return 1;
+	}
 
-	std::string text = "Hello world.";
-	std::string prompt_path = "./data/prom.wav";
-	std::string output_path = "./data/resp.wav";
-	std::string language = "en";
-	int modality = MODALITY_NAR_LEN;
-
-	auto inputs = vall_e_prepare_inputs( ctx, text, prompt_path, language );
-	auto output_audio_codes = vall_e_generate( ctx, inputs, modality );
-	write_audio_to_disk( decode_audio( ctx->encodec.ctx, output_audio_codes ), output_path );
+	auto inputs = vall_e_prepare_inputs( ctx, args.text, args.prompt_path, args.language );
+	auto output_audio_codes = vall_e_generate( ctx, inputs, args.max_steps, args.max_duration, args.modality );
+	write_audio_to_disk( decode_audio( ctx->encodec.ctx, output_audio_codes ), args.output_path );
 
 	vall_e_free( ctx );
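With `main()` now driven by the parser, a hypothetical invocation (assuming the built binary is named `vall_e`; the flags and default paths are the ones defined in this commit) would look like:

./vall_e --model ./data/vall_e.gguf --encodec-model ./data/encodec.bin --text "Hello world." --language en --modality nar-len --max-steps 30 --input ./data/prom.wav --output ./data/resp.wav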

View File

@@ -37,7 +37,7 @@ const int MODALITY_NAR_LEN = 1;
 const int MAX_DURATION = 75 * 12;
 const int CTX_SIZE = 2048;
 const int N_THREADS = 8;
-const int N_GPU_LAYERS = 0;
+const int N_GPU_LAYERS = 99;
 
 typedef llama_token token_t;
 typedef std::vector<std::vector<token_t>> vall_e_audio_codes_t;
@@ -86,13 +86,22 @@ struct merge_entry_t {
 };
 
 struct vall_e_context_params_t {
-	std::string model_path;
-	std::string encodec_path;
+	std::string model_path = "./data/vall_e.gguf";
+	std::string encodec_path = "./data/encodec.bin";
 	int32_t gpu_layers = N_GPU_LAYERS;
-	int32_t cpu_threads = N_THREADS;
+	int32_t n_threads = N_THREADS;
 	int32_t ctx_size = CTX_SIZE;
 	bool verbose = false;
 };
+
+struct vall_e_args_t {
+	std::string text = "Hello world.";
+	std::string prompt_path = "./data/prom.wav";
+	std::string output_path = "./data/resp.wav";
+	std::string language = "en";
+	int modality = MODALITY_NAR_LEN;
+	int max_steps = 30;
+	int max_duration = 75 * 12;
+};
 
 // stores everything needed for vall_e.cpp
 struct vall_e_context_t {
 	vall_e_context_params_t params;
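Since the commit's stated goal includes making the code usable elsewhere, the defaulted structs let a library consumer override only what they need and drive the same load/prepare/generate/decode chain as the new `main()`. A rough sketch under two assumptions: the header is named `vall_e.h`, and the call uses the five-argument `vall_e_generate` as defined in the .cpp above (the declaration further down still carries the older three-argument form).

#include "vall_e.h"   // assumed header name for this project

int run_tts_example() {
	vall_e_context_params_t params;   // model_path / encodec_path now default to ./data/
	params.n_threads = 4;             // override only what you need

	vall_e_args_t args;               // text, prompt/output paths, modality, limits all defaulted
	args.text = "The quick brown fox.";
	args.modality = MODALITY_NAR_LEN;

	vall_e_context_t* ctx = vall_e_load( params );
	if ( !ctx ) return 1;

	auto inputs = vall_e_prepare_inputs( ctx, args.text, args.prompt_path, args.language );
	auto codes = vall_e_generate( ctx, inputs, args.max_steps, args.max_duration, args.modality );
	write_audio_to_disk( decode_audio( ctx->encodec.ctx, codes ), args.output_path );

	vall_e_free( ctx );
	return 0;
}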
@@ -151,6 +160,8 @@ int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( io_map_t& inputs_map, c
 void VALL_E_API vall_e_inputs_map_init( io_map_t&, llama_model* model );
 
 // context management
+void VALL_E_API vall_e_print_usage( char** argv, const vall_e_context_params_t& params, const vall_e_args_t& args );
+bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args );
 vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params );
 vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& lang );
 vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int modality = MODALITY_NAR_LEN );