vall_e.cpp cli
This commit is contained in:
parent
59f56ad099
commit
b9d2cd5513
|
@ -42,9 +42,9 @@ Run `make`.
|
||||||
* [x] working `NAR` output
|
* [x] working `NAR` output
|
||||||
* [x] `NAR` sampling
|
* [x] `NAR` sampling
|
||||||
* [x] decode audio to disk
|
* [x] decode audio to disk
|
||||||
* [ ] a functional CLI
|
* [x] a functional CLI
|
||||||
* [x] actually make it work
|
* [x] actually make it work
|
||||||
* [ ] clean up to make the code usable
|
* [x] clean up to make the code usable elsewhere
|
||||||
* [ ] feature parity with the PyTorch version
|
* [ ] feature parity with the PyTorch version
|
||||||
* [ ] vocos
|
* [ ] vocos
|
||||||
* [ ] additional tasks (`stt`, `ns`, `sr`, samplers)
|
* [ ] additional tasks (`stt`, `ns`, `sr`, samplers)
|
|
@ -122,7 +122,7 @@ void VALL_E_API vall_e_inputs_map_init( io_map_t& io_map, llama_model* model ) {
|
||||||
io_map.n_embd = n_embd;
|
io_map.n_embd = n_embd;
|
||||||
io_map.n_vocab = n_vocab;
|
io_map.n_vocab = n_vocab;
|
||||||
|
|
||||||
int32_t ctx_size = 24 * 2 * ggml_tensor_overhead(); // 24 embeddings + 24 output heads (generous) (should only really need to do this for output heads since we manually handle embeddings)
|
size_t ctx_size = 24 * 2 * ggml_tensor_overhead(); // 24 embeddings + 24 output heads (generous) (should only really need to do this for output heads since we manually handle embeddings)
|
||||||
struct ggml_init_params params = {
|
struct ggml_init_params params = {
|
||||||
/*.mem_size =*/ ctx_size,
|
/*.mem_size =*/ ctx_size,
|
||||||
/*.mem_buffer =*/ NULL,
|
/*.mem_buffer =*/ NULL,
|
||||||
|
@ -546,7 +546,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
|
||||||
const token_t MASK_TOKEN = 1024; // token value for masking
|
const token_t MASK_TOKEN = 1024; // token value for masking
|
||||||
const float PI = 3.141592653589793f;
|
const float PI = 3.141592653589793f;
|
||||||
// to-do: derive from sampling arguments
|
// to-do: derive from sampling arguments
|
||||||
int32_t steps = 10; // number of demasking steps
|
int32_t steps = max_tokens;
|
||||||
int32_t seq_len = n_outputs;
|
int32_t seq_len = n_outputs;
|
||||||
float temperature = 1.5f;
|
float temperature = 1.5f;
|
||||||
float cfg_strength = 2.5f;
|
float cfg_strength = 2.5f;
|
||||||
|
@ -717,15 +717,6 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
|
||||||
return output_tokens;
|
return output_tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a copy of `string` with every non-overlapping occurrence of `search`
// replaced by `replace`, scanning left to right.
//
// The scan resumes *after* each inserted replacement, so a replacement that
// itself contains `search` (e.g. "." -> "..") cannot be matched again; restarting
// the search from the beginning each time would never terminate in that case.
// An empty `search` matches at every position, so it is treated as "nothing to
// replace" and the input is returned unchanged.
std::string string_replace( const std::string& string, const std::string& search, const std::string& replace ) {
	std::string res = string;
	if ( search.empty() ) {
		return res;
	}

	size_t start_pos = 0;
	while ( (start_pos = res.find(search, start_pos)) != std::string::npos ) {
		res.replace(start_pos, search.length(), replace);
		// skip over what was just written so it is never rescanned
		start_pos += replace.length();
	}
	return res;
}
|
|
||||||
|
|
||||||
std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::string& text, const std::string& language ) {
|
std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::string& text, const std::string& language ) {
|
||||||
std::vector<token_t> tokens;
|
std::vector<token_t> tokens;
|
||||||
|
|
||||||
|
@ -791,6 +782,83 @@ std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::str
|
||||||
return tokens;
|
return tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void VALL_E_API vall_e_print_usage( char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
|
||||||
|
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
fprintf(stderr, "options:\n");
|
||||||
|
fprintf(stderr, " -h, --help Show this help message and exit\n");
|
||||||
|
fprintf(stderr, " -t N, --threads N\n");
|
||||||
|
fprintf(stderr, " Number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||||
|
fprintf(stderr, " -ngl N, --n-gpu-layers N\n");
|
||||||
|
fprintf(stderr, " Number of layers to offload to the GPU (default: %d)\n", params.gpu_layers);
|
||||||
|
fprintf(stderr, " -ctx N, --context-size N\n");
|
||||||
|
fprintf(stderr, " Max context size (default: %d)\n", params.ctx_size);
|
||||||
|
fprintf(stderr, " -v, --verbose\n");
|
||||||
|
fprintf(stderr, " Verbose output (default: %d)\n", params.verbose);
|
||||||
|
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||||
|
fprintf(stderr, " VALL-E model path (default: %s)\n", params.model_path.c_str());
|
||||||
|
fprintf(stderr, " -em FNAME, --encodec-model FNAME\n");
|
||||||
|
fprintf(stderr, " Encodec model path (default: %s)\n", params.encodec_path.c_str());
|
||||||
|
fprintf(stderr, " -t TEXT, --text TEXT\n");
|
||||||
|
fprintf(stderr, " Input text prompt (default: %s)\n", args.text.c_str());
|
||||||
|
fprintf(stderr, " -l TEXT, --language TEXT\n");
|
||||||
|
fprintf(stderr, " Language for input text / output response (default: %s)\n", args.language.c_str());
|
||||||
|
fprintf(stderr, " -mode MODE, --modality MODE\n");
|
||||||
|
fprintf(stderr, " Modality for inferencing (default: %s, accepts ['ar+nar', 'nar-len'])\n", args.modality == MODALITY_NAR_LEN ? "nar-len" : "ar+nar");
|
||||||
|
fprintf(stderr, " -ms N, --max-steps N\n");
|
||||||
|
fprintf(stderr, " Max steps for `nar-len` (default: %i)\n", args.max_steps);
|
||||||
|
fprintf(stderr, " -md N, --max-duration N\n");
|
||||||
|
fprintf(stderr, " Max duration of the audio (default: %i)\n", args.max_duration);
|
||||||
|
fprintf(stderr, " -i FNAME, --input FNAME\n");
|
||||||
|
fprintf(stderr, " Input prompt wav (default: %s)\n", args.prompt_path.c_str());
|
||||||
|
fprintf(stderr, " -o FNAME, --output FNAME\n");
|
||||||
|
fprintf(stderr, " Output audio wav (default: %s)\n", args.output_path.c_str());
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
}
|
||||||
|
bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
|
||||||
|
for ( int i = 1; i < argc; i++ ) {
|
||||||
|
std::string arg = argv[i];
|
||||||
|
|
||||||
|
if (arg == "-t" || arg == "--threads") {
|
||||||
|
params.n_threads = std::stoi(argv[++i]);
|
||||||
|
} else if (arg == "-ngl" || arg == "--n-gpu-layers") {
|
||||||
|
params.gpu_layers = std::stoi(argv[++i]);
|
||||||
|
} else if (arg == "-ctx" || arg == "--context-size") {
|
||||||
|
params.ctx_size = std::stoi(argv[++i]);
|
||||||
|
} else if (arg == "-v" || arg == "--verbose") {
|
||||||
|
params.verbose = true;
|
||||||
|
} else if (arg == "-m" || arg == "--model") {
|
||||||
|
params.model_path = argv[++i];
|
||||||
|
} else if (arg == "-em" || arg == "--encodec-model") {
|
||||||
|
params.encodec_path = argv[++i];
|
||||||
|
} else if (arg == "-t" || arg == "--text") {
|
||||||
|
args.text = argv[++i];
|
||||||
|
} else if (arg == "-l" || arg == "--language") {
|
||||||
|
args.language = argv[++i];
|
||||||
|
} else if (arg == "-mode" || arg == "--modality") {
|
||||||
|
args.modality = argv[++i] == "ar+nar" ? MODALITY_AR_NAR : MODALITY_NAR_LEN;
|
||||||
|
} else if (arg == "-ms" || arg == "--max-steps") {
|
||||||
|
args.max_steps = std::stoi(argv[++i]);
|
||||||
|
} else if (arg == "-md" || arg == "--max-duration") {
|
||||||
|
args.max_duration = std::stoi(argv[++i]);
|
||||||
|
} else if (arg == "-i" || arg == "--input") {
|
||||||
|
args.prompt_path = argv[++i];
|
||||||
|
} else if (arg == "-o" || arg == "--output") {
|
||||||
|
args.output_path = argv[++i];
|
||||||
|
} else if (arg == "-h" || arg == "--help") {
|
||||||
|
vall_e_print_usage(argv, params, args);
|
||||||
|
exit(0);
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||||
|
vall_e_print_usage(argv, params, args);
|
||||||
|
exit(0);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params ) {
|
vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params ) {
|
||||||
vall_e_context_t* ctx = new vall_e_context_t();
|
vall_e_context_t* ctx = new vall_e_context_t();
|
||||||
ctx->params = params;
|
ctx->params = params;
|
||||||
|
@ -813,8 +881,8 @@ vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params
|
||||||
ctx_params.n_ctx = params.ctx_size;
|
ctx_params.n_ctx = params.ctx_size;
|
||||||
ctx_params.n_batch = params.ctx_size;
|
ctx_params.n_batch = params.ctx_size;
|
||||||
ctx_params.n_ubatch = params.ctx_size;
|
ctx_params.n_ubatch = params.ctx_size;
|
||||||
ctx_params.n_threads = params.cpu_threads;
|
ctx_params.n_threads = params.n_threads;
|
||||||
ctx_params.n_threads_batch = params.cpu_threads;
|
ctx_params.n_threads_batch = params.n_threads;
|
||||||
ctx_params.no_perf = false;
|
ctx_params.no_perf = false;
|
||||||
ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL;
|
ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL;
|
||||||
|
|
||||||
|
@ -868,7 +936,7 @@ vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string&
|
||||||
return inputs;
|
return inputs;
|
||||||
}
|
}
|
||||||
// to-do: provide sampling params
|
// to-do: provide sampling params
|
||||||
vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int modality ) {
|
vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int max_steps, int max_duration, int modality ) {
|
||||||
// NAR-len demasking
|
// NAR-len demasking
|
||||||
std::vector<token_t> output_tokens;
|
std::vector<token_t> output_tokens;
|
||||||
if ( modality == MODALITY_NAR_LEN ) {
|
if ( modality == MODALITY_NAR_LEN ) {
|
||||||
|
@ -876,7 +944,7 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
|
||||||
int len = 0;
|
int len = 0;
|
||||||
if ( !len ) {
|
if ( !len ) {
|
||||||
inputs.task = "len";
|
inputs.task = "len";
|
||||||
output_tokens = generate( ctx, inputs, 5, INFERENCE_MODE_LEN );
|
output_tokens = generate( ctx, inputs, 5, INFERENCE_MODE_LEN, ctx->params.verbose );
|
||||||
{
|
{
|
||||||
int digit = 1;
|
int digit = 1;
|
||||||
for (auto it = output_tokens.rbegin(); it < output_tokens.rend(); ++it) {
|
for (auto it = output_tokens.rbegin(); it < output_tokens.rend(); ++it) {
|
||||||
|
@ -885,7 +953,7 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// cap for now
|
// cap for now
|
||||||
if ( len <= 0 || len > MAX_DURATION ) len = MAX_DURATION;
|
if ( len <= 0 || len > max_duration ) len = max_duration;
|
||||||
}
|
}
|
||||||
// fill with mask tokens
|
// fill with mask tokens
|
||||||
inputs.resp.resize(1);
|
inputs.resp.resize(1);
|
||||||
|
@ -897,7 +965,7 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
|
||||||
inputs.task = "tts";
|
inputs.task = "tts";
|
||||||
for ( auto l = 0; l < 8; ++l ) {
|
for ( auto l = 0; l < 8; ++l ) {
|
||||||
inputs.rvq_l = l;
|
inputs.rvq_l = l;
|
||||||
output_tokens = generate( ctx, inputs, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
|
output_tokens = generate( ctx, inputs, max_steps, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR, ctx->params.verbose );
|
||||||
if ( l == 0 ) inputs.resp.clear();
|
if ( l == 0 ) inputs.resp.clear();
|
||||||
inputs.resp.emplace_back( output_tokens );
|
inputs.resp.emplace_back( output_tokens );
|
||||||
}
|
}
|
||||||
|
@ -906,7 +974,7 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
|
||||||
inputs.task = "tts";
|
inputs.task = "tts";
|
||||||
for ( auto l = 0; l < 8; ++l ) {
|
for ( auto l = 0; l < 8; ++l ) {
|
||||||
inputs.rvq_l = l;
|
inputs.rvq_l = l;
|
||||||
output_tokens = generate( ctx, inputs, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
|
output_tokens = generate( ctx, inputs, l == 0 ? max_duration : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR, ctx->params.verbose );
|
||||||
inputs.resp.emplace_back( output_tokens );
|
inputs.resp.emplace_back( output_tokens );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -926,21 +994,22 @@ int main( int argc, char** argv ) {
|
||||||
// to-do: parse CLI args
|
// to-do: parse CLI args
|
||||||
|
|
||||||
vall_e_context_params_t params;
|
vall_e_context_params_t params;
|
||||||
params.model_path = "./data/vall_e.gguf";
|
vall_e_args_t args;
|
||||||
params.encodec_path = "./data/encodec.bin";
|
|
||||||
params.gpu_layers = N_GPU_LAYERS;
|
if ( !vall_e_args_parse( argc, argv, params, args ) ) {
|
||||||
params.cpu_threads = N_THREADS;
|
fprintf(stderr, "%s: failed to parse arguments\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
vall_e_context_t* ctx = vall_e_load( params );
|
vall_e_context_t* ctx = vall_e_load( params );
|
||||||
|
if ( !ctx || !ctx->llama.model || !ctx->llama.ctx || !ctx->encodec.ctx ) {
|
||||||
|
fprintf(stderr, "%s: failed to initialize vall_e.cpp\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
std::string text = "Hello world.";
|
auto inputs = vall_e_prepare_inputs( ctx, args.text, args.prompt_path, args.language );
|
||||||
std::string prompt_path = "./data/prom.wav";
|
auto output_audio_codes = vall_e_generate( ctx, inputs, args.max_steps, args.max_duration, args.modality );
|
||||||
std::string output_path = "./data/resp.wav";
|
write_audio_to_disk( decode_audio( ctx->encodec.ctx, output_audio_codes ), args.output_path );
|
||||||
std::string language = "en";
|
|
||||||
int modality = MODALITY_NAR_LEN;
|
|
||||||
|
|
||||||
auto inputs = vall_e_prepare_inputs( ctx, text, prompt_path, language );
|
|
||||||
auto output_audio_codes = vall_e_generate( ctx, inputs, modality );
|
|
||||||
write_audio_to_disk( decode_audio( ctx->encodec.ctx, output_audio_codes ), output_path );
|
|
||||||
|
|
||||||
vall_e_free( ctx );
|
vall_e_free( ctx );
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ const int MODALITY_NAR_LEN = 1;
|
||||||
const int MAX_DURATION = 75 * 12;
|
const int MAX_DURATION = 75 * 12;
|
||||||
const int CTX_SIZE = 2048;
|
const int CTX_SIZE = 2048;
|
||||||
const int N_THREADS = 8;
|
const int N_THREADS = 8;
|
||||||
const int N_GPU_LAYERS = 0;
|
const int N_GPU_LAYERS = 99;
|
||||||
|
|
||||||
typedef llama_token token_t;
|
typedef llama_token token_t;
|
||||||
typedef std::vector<std::vector<token_t>> vall_e_audio_codes_t;
|
typedef std::vector<std::vector<token_t>> vall_e_audio_codes_t;
|
||||||
|
@ -86,13 +86,22 @@ struct merge_entry_t {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct vall_e_context_params_t {
|
struct vall_e_context_params_t {
|
||||||
std::string model_path;
|
std::string model_path = "./data/vall_e.gguf";
|
||||||
std::string encodec_path;
|
std::string encodec_path = "./data/encodec.bin";
|
||||||
int32_t gpu_layers = N_GPU_LAYERS;
|
int32_t gpu_layers = N_GPU_LAYERS;
|
||||||
int32_t cpu_threads = N_THREADS;
|
int32_t n_threads = N_THREADS;
|
||||||
int32_t ctx_size = CTX_SIZE;
|
int32_t ctx_size = CTX_SIZE;
|
||||||
bool verbose = false;
|
bool verbose = false;
|
||||||
};
|
};
|
||||||
|
struct vall_e_args_t {
|
||||||
|
std::string text = "Hello world.";
|
||||||
|
std::string prompt_path = "./data/prom.wav";
|
||||||
|
std::string output_path = "./data/resp.wav";
|
||||||
|
std::string language = "en";
|
||||||
|
int modality = MODALITY_NAR_LEN;
|
||||||
|
int max_steps = 30;
|
||||||
|
int max_duration = 75 * 12;
|
||||||
|
};
|
||||||
// stores everything needed for vall_e.cpp
|
// stores everything needed for vall_e.cpp
|
||||||
struct vall_e_context_t {
|
struct vall_e_context_t {
|
||||||
vall_e_context_params_t params;
|
vall_e_context_params_t params;
|
||||||
|
@ -151,6 +160,8 @@ int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( io_map_t& inputs_map, c
|
||||||
void VALL_E_API vall_e_inputs_map_init( io_map_t&, llama_model* model );
|
void VALL_E_API vall_e_inputs_map_init( io_map_t&, llama_model* model );
|
||||||
|
|
||||||
// context management
|
// context management
|
||||||
|
void VALL_E_API vall_e_print_usage( char** argv, const vall_e_context_params_t& params, const vall_e_args_t& args );
|
||||||
|
bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args );
|
||||||
vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params );
|
vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params );
|
||||||
vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& lang );
|
vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& lang );
|
||||||
vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int modality = MODALITY_NAR_LEN );
|
// runs inference for the prepared inputs and returns the generated audio codes;
// max_steps is forwarded to the `nar-len` demasking pass, max_duration caps the predicted length (`nar-len`) or the AR token budget (`ar+nar`)
vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int max_steps, int max_duration, int modality = MODALITY_NAR_LEN );
|
||||||
|
|
Loading…
Reference in New Issue
Block a user