This commit is contained in:
mrq 2025-04-05 10:27:07 -05:00
parent 0ede3bfc12
commit 44260f7445
4 changed files with 106 additions and 52 deletions

View File

@ -24,6 +24,7 @@ Run `make`.
## To-Do
* [ ] fix regressions that appeared for whatever reason
* it seems to be related to the demasking step: low step counts produce fine output, while higher step counts degrade it
* [x] converted model to GGUF
* [x] convert it without modifying any of the existing code, as the tokenizer requires some care
* [x] basic framework

View File

@ -98,7 +98,7 @@ ggml_tensor* view_2d_tensor( struct ggml_tensor* tensor, int32_t start, int32_t
ggml_tensor* res = new ggml_tensor();
memcpy( res, tensor, sizeof(ggml_tensor) );
res->op = GGML_OP_VIEW;
res->op = GGML_OP_VIEW;
res->src[0] = tensor;
res->data += res->nb[1] * start;
@ -128,6 +128,29 @@ void print_tokens( const std::vector<token_t>& tokens, const std::string& prefix
}
printf("]\n");
}
// Prints a float vector to stdout as "<prefix>[v0, v1, ...]" followed by a newline.
// Mirrors the formatting of print_tokens() for debug output.
void print_floats( const std::vector<float>& v, const std::string& prefix ) {
	printf("%s[", prefix.c_str());
	// size_t index: `auto i = 0` deduced int and triggered a signed/unsigned
	// comparison against v.size()
	for ( size_t i = 0; i < v.size(); ++i ) {
		printf("%f%s", v[i], i + 1 < v.size() ? ", " : "");
	}
	printf("]\n");
}
// Population standard deviation of data[0..n-1] (two-pass: mean, then variance).
// Used for CFG logit rescaling. Returns 0.0f for an empty or null input;
// the previous version divided by zero and returned NaN for n == 0.
float calculate_std(const float* data, size_t n) {
	if (data == nullptr || n == 0) return 0.0f;
	float mean = 0.0f;
	for (size_t i = 0; i < n; i++) mean += data[i];
	mean /= n;
	float variance = 0.0f;
	for (size_t i = 0; i < n; i++) {
		float diff = data[i] - mean;
		variance += diff * diff;
	}
	variance /= n;
	// std::sqrt picks the float overload; bare sqrt(double) forced a
	// float->double->float round-trip
	return std::sqrt(variance);
}
const io_t& vall_e_inputs_map_get( io_map_t& io_map, const std::string& name ) {
return io_map.io[name];
@ -532,7 +555,7 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
// to-do: figure this out......
{
llama_set_causal_attn( ctx->llama.ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
// *const_cast<bool*>(&model->hparams.causal_attn) = true; // force set this
*const_cast<bool*>(&ctx->llama.model->hparams.causal_attn) = true; // force set this
}
std::vector<token_t> output_tokens;
@ -545,7 +568,7 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
llama_sampler * smpl = llama_sampler_chain_init(sparams);
if ( mode == INFERENCE_MODE_LEN ) {
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
} else {
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(0));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(1.0, 1));
@ -582,11 +605,19 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
// to-do: assert n_outputs == inputs.resp[rvq_l-1].size()
const token_t MASK_TOKEN = 1024; // token value for masking
const float PI = 3.141592653589793f;
// to-do: derive from sampling arguments
int32_t steps = max_tokens;
int32_t seq_len = n_outputs;
float temperature = 1.5f;
int32_t top_k = 0;
float top_p = 1.0;
float temperature = 1.0f;
float cfg_strength = 2.5f;
float start_noise = 0.0f;
float end_noise = 1.0f;
bool annealed_sampling = true;
bool remasking = true;
float cfg_rescale = 0.75f;
// fill with masked tokens
output_tokens.clear();
@ -604,15 +635,17 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
// do one step on many tokens
for ( auto step = 0; step < steps; ++step ) {
float timestep = ((float)step) / steps; // to-do: align with torch.linspace
float t_norm = static_cast<float>(step) / static_cast<float>(steps - 1);
float timestep = start_noise + (end_noise - start_noise) * t_norm;
//float timestep = start_noise + (end_noise - start_noise) * ((float)step / steps);
float annealing = 1.0f - timestep;
float sampling_temperature = temperature * annealing;
float sampling_cfg_strength = timestep * cfg_strength;
float sampling_temperature = annealed_sampling ? temperature * annealing : temperature;
float sampling_cfg_strength = annealed_sampling ? timestep * cfg_strength : cfg_strength;
float noise_p = cos( timestep * PI * 0.5f );
float remask_p = 0.5f / steps;
float remask_p = remasking ? 0.5f / steps : 0.0f;
int32_t n_masked_tokens = (noise_p + remask_p) * seq_len;
if ( n_masked_tokens < 1 ) {
@ -671,10 +704,14 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
sparams.no_perf = false;
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(20));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (sampling_temperature));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (LLAMA_DEFAULT_SEED));
if ( sampling_temperature == 0 ) {
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
} else {
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(top_p, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (sampling_temperature));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (LLAMA_DEFAULT_SEED));
}
auto* logits = llama_get_logits( ctx->llama.ctx );
for ( auto idx = 0; idx < n_outputs; ++idx ) {
@ -689,8 +726,21 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
// perform softmax before modifying logits
std::vector<float> softmaxed = soft_max( n_vocab, logit );
for ( auto i = 0; i < n_vocab; ++i ) {
logit[i] = null_logit[i] + (logit[i] - null_logit[i]) * cfg_strength;
std::vector<float> summed(n_vocab);
for (int i = 0; i < n_vocab; i++) {
summed[i] = null_logit[i] + (logit[i] - null_logit[i]) * sampling_cfg_strength;
}
if (cfg_rescale > 0) {
float pos_std = calculate_std(logit, n_vocab);
float summed_std = calculate_std(summed.data(), n_vocab);
float factor = cfg_rescale * (pos_std / summed_std) + (1 - cfg_rescale);
for (int i = 0; i < n_vocab; i++) {
logit[i] = summed[i] * factor;
}
} else {
memcpy(logit, summed.data(), n_vocab * sizeof(float));
}
// sample ith token
@ -821,39 +871,39 @@ std::vector<token_t> phonemize( vall_e_context_t* ctx, const std::string& text,
}
// Prints command-line usage/help for the vall_e executable to stderr,
// substituting the current defaults from `params` (runtime settings) and
// `args` (inference settings).
// NOTE(review): the previous revision contained this option list twice
// back-to-back (a merge/diff artifact), printing the whole help text twice;
// the duplicate copy has been removed.
// NOTE(review): "-t" is documented for both --threads and --text — confirm
// which short flag vall_e_args_parse() actually honors and rename one.
void vall_e_print_usage( char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
	fprintf(stderr, "usage: %s [options]\n", argv[0]);
	fprintf(stderr, "\n");
	fprintf(stderr, "options:\n");
	fprintf(stderr, " -h, --help Show this help message and exit\n");
	fprintf(stderr, " -t N, --threads N\n");
	fprintf(stderr, " Number of threads to use during computation (default: %d)\n", params.n_threads);
	fprintf(stderr, " -ngl N, --n-gpu-layers N\n");
	fprintf(stderr, " Number of layers to offload to the GPU (default: %d)\n", params.gpu_layers);
	fprintf(stderr, " -ctx N, --context-size N\n");
	fprintf(stderr, " Max context size (default: %d)\n", params.ctx_size);
	fprintf(stderr, " -v, --verbose\n");
	fprintf(stderr, " Verbose output (default: %d)\n", params.verbose);
	fprintf(stderr, " -m FNAME, --model FNAME\n");
	fprintf(stderr, " VALL-E model path (default: %s)\n", params.model_path.c_str());
	fprintf(stderr, " -em FNAME, --encodec-model FNAME\n");
	fprintf(stderr, " Encodec model path (default: %s)\n", params.encodec_path.c_str());
	fprintf(stderr, " -t TEXT, --text TEXT\n");
	fprintf(stderr, " Input text prompt (default: %s)\n", args.text.c_str());
	fprintf(stderr, " -l TEXT, --language TEXT\n");
	fprintf(stderr, " Language for input text / output response (default: %s)\n", args.language.c_str());
	fprintf(stderr, " -ts TASK, --task TASK\n");
	fprintf(stderr, " Inferencing task (default: %s, accepts ['tts', 'stt', 'ns', 'sr'])\n", args.task.c_str());
	fprintf(stderr, " -mode MODE, --modality MODE\n");
	fprintf(stderr, " Modality for inferencing (default: %s, accepts ['ar+nar', 'nar-len'])\n", args.modality == MODALITY_NAR_LEN ? "nar-len" : "ar+nar");
	fprintf(stderr, " -ms N, --max-steps N\n");
	fprintf(stderr, " Max steps for `nar-len` (default: %i)\n", args.max_steps);
	fprintf(stderr, " -md N, --max-duration N\n");
	fprintf(stderr, " Max duration of the audio (default: %i)\n", args.max_duration);
	fprintf(stderr, " -i FNAME, --input FNAME\n");
	fprintf(stderr, " Input prompt wav (default: %s)\n", args.prompt_path.c_str());
	fprintf(stderr, " -o FNAME, --output FNAME\n");
	fprintf(stderr, " Output audio wav (default: %s)\n", args.output_path.c_str());
	fprintf(stderr, "\n");
}
bool vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
for ( int i = 1; i < argc; i++ ) {

View File

@ -339,7 +339,7 @@ class AR_NAR(Base):
null_prom = [ None for _ in range(batch_size) ]
iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc=f"NAR Masked Level {level}", disable=disable_tqdm)
for timestep in iterator:
for step, timestep in enumerate(iterator):
# update previous list of tokens
prev_list = resps_list
# ramp down over time
@ -348,8 +348,10 @@ class AR_NAR(Base):
noise_p = math.cos( timestep * math.pi * 0.5 )
# proportion of tokens to remask
remask_p = 1.0 / (max_steps * 2) if remasking else 0
mask_p = noise_p + remask_p
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
# normal masking
if vc_list is None or timestep >= vc_threshold:
# mask off inputs

View File

@ -275,7 +275,7 @@ class AR_NAR_V2(Base_V2):
null_prom = [ None for _ in range(batch_size) ]
iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm)
for timestep in iterator:
for step, timestep in enumerate(iterator):
# update previous list of tokens
prev_list = resps_list
# ramp down over time
@ -284,8 +284,9 @@ class AR_NAR_V2(Base_V2):
noise_p = math.cos( timestep * math.pi * 0.5 )
# proportion of tokens to remask
remask_p = 1.0 / (max_steps * 2) if remasking else 0
mask_p = noise_p + remask_p
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=0 ).indices for score, seq_len in zip(scores, len_list) ]
masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=0 ).indices for score, seq_len in zip(scores, len_list) ]
# normal masking
# mask off inputs