From 2b4d783299355d97ad87030951632b34f2fc2b74 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 23 Dec 2024 23:42:44 -0600
Subject: [PATCH] ugh

---
 vall_e.cpp/README.md  |  4 +-
 vall_e.cpp/vall_e.cpp | 92 ++++++++++++++++++++++++++-----------------
 2 files changed, 57 insertions(+), 39 deletions(-)

diff --git a/vall_e.cpp/README.md b/vall_e.cpp/README.md
index 914783a..03c3cc1 100644
--- a/vall_e.cpp/README.md
+++ b/vall_e.cpp/README.md
@@ -33,7 +33,7 @@ Run `make`.
 * [x] load audio from disk
 * [x] encode audio
 * [x] sum embeddings for the `prom` and prior `resp`s
-* [ ] working `AR` output
+* [x] working `AR` output
 * [x] `AR` sampling
 * [ ] working `NAR-len` output
 * [x] `NAR-len` sampling
@@ -41,4 +41,4 @@ Run `make`.
 * [x] `NAR` sampling
 * [x] decode audio to disk
 * [ ] a functional CLI
-* [ ] actually make it work
+* [ ] actually make it work
\ No newline at end of file
diff --git a/vall_e.cpp/vall_e.cpp b/vall_e.cpp/vall_e.cpp
index aa279cd..84a6db6 100644
--- a/vall_e.cpp/vall_e.cpp
+++ b/vall_e.cpp/vall_e.cpp
@@ -339,7 +339,7 @@ std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits ) {
 	float denom = 0.0f;
 
 	for ( auto i = 0; i < n_logits; ++i ) {
-		expd[i] = exp( logits[i] );
+		expd[i] = expf( logits[i] );
 		denom += expd[i];
 	}
 	// to-do: assert denom != 0.0f
@@ -493,7 +493,11 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
 
 	// update model's output heads / causal mode
 	llama_set_output_head( model, io.head );
-	llama_set_causal_attn( ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
+	// to-do: figure this out......
+	{
+		llama_set_causal_attn( ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
+		// *const_cast<bool*>(&model->hparams.causal_attn) = true; // force set this
+	}
 
 	std::vector<llama_token> output_tokens;
 	const auto t_main_start = ggml_time_us();
@@ -557,14 +561,14 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
 
 			float sampling_cfg_strength = timestep * cfg_strength;
 			float noise_p = cos( timestep * PI * 0.5f );
-			float remask_p = 0.0f; // 0.5f / steps;
+			float remask_p = 0.5f / steps;
 
 			int32_t n_masked_tokens = (noise_p + remask_p) * seq_len;
 			if ( n_masked_tokens < 1 ) {
 				n_masked_tokens = 1;
 			}
-			if ( n_masked_tokens > n_outputs ) {
-				n_masked_tokens = n_outputs;
+			if ( n_masked_tokens > (n_outputs - step) ) {
+				n_masked_tokens = (n_outputs - step);
 			}
 
 			// masked mask
@@ -582,7 +586,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
 				is_masked[idx] = true;
 			}
 
-			if ( verbose ) print_tokens( output_tokens, "Masked tokens:" );
+			if ( verbose ) print_tokens( output_tokens, "Masked tokens: " );
 
 			// update batch
 			// to-do: only update the embeddings instead
@@ -602,10 +606,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
 			llama_kv_cache_clear(ctx); // necessary for many reasons
 			// copy null probabilities
 			std::vector<float> null_logits(n_outputs * n_vocab, 0.0f);
-			// to-do: copy once
-			for ( auto idx = 0; idx < n_outputs; ++idx ) {
-				memcpy( &null_logits[idx * n_vocab], llama_get_logits_ith( ctx, null_batch.n_tokens - n_outputs + idx ), sizeof(float) * n_vocab );
-			}
+			memcpy( null_logits.data(), llama_get_logits( ctx ), sizeof(float) * n_vocab * n_outputs );
 
 			// decode
 			if ( llama_decode(ctx, batch) ) {
@@ -623,23 +624,32 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
 			llama_sampler_chain_add(smpl, llama_sampler_init_temp (sampling_temperature));
 			llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
 
-			// to-do: figure out why all logits are the same for each token......
-			// "reverse" iterate from backwards indexing
+			auto* logits = llama_get_logits( ctx );
+
+			/*
+			// perform CFG sampling
+			for ( auto i = 0; i < n_vocab * n_outputs; ++i ) {
+				logits[i] = null_logits[i] + (logits[i] - null_logits[i]) * cfg_strength;
+			}
+			*/
+
 			for ( auto idx = 0; idx < n_outputs; ++idx ) {
 				// skip if not masked
-				if ( !is_masked[idx] )
+				if ( !is_masked[idx] ) {
+					scores[idx] = 1.0f;
 					continue;
-				// ensures only tokens within our designated range are used
-				auto* logits = llama_get_logits_ith( ctx, batch.n_tokens - n_outputs + idx );
-				auto* null_logit = &null_logits[idx];
+				}
+
+				auto* logit = &logits[idx * n_vocab];
+				auto* null_logit = &null_logits[idx * n_vocab];
 
 				// perform softmax before modifying logits
-				std::vector<float> softmaxed = soft_max( n_vocab, logits );
+				std::vector<float> softmaxed = soft_max( n_vocab, logit );
 
-				// perform CFG sampling
 				for ( auto i = 0; i < n_vocab; ++i ) {
-					logits[i] = null_logit[i] + (logits[i] - null_logit[i]) * cfg_strength;
+					logit[i] = null_logit[i] + (logit[i] - null_logit[i]) * cfg_strength;
 				}
+
 				// sample ith token
 				auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx );
 				// store token if it was masked
@@ -654,23 +664,34 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
 		}
 	} else if ( mode == INFERENCE_MODE_NAR ) {
 		// to-do: assert n_outputs == input.resp[rvq_l-1].size()
-		output_tokens.reserve(n_outputs);
+		output_tokens.clear();
+		output_tokens.resize(n_outputs);
 		// do one step on many tokens
 		if ( llama_decode(ctx, batch) ) {
 			fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
 			return output_tokens;
 		}
 		llama_kv_cache_clear(ctx); // necessary for many reasons
-		// to-do: figure out why all logits are the same for each token......
-		// "reverse" iterate from backwards indexing
+
+		auto sparams = llama_sampler_chain_default_params();
+		sparams.no_perf = false;
+		llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+		llama_sampler_chain_add(smpl, llama_sampler_init_top_k(1));
+		llama_sampler_chain_add(smpl, llama_sampler_init_top_p(1.0, 1));
+		llama_sampler_chain_add(smpl, llama_sampler_init_temp (1.0));
+		llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
+
		for ( auto idx = 0; idx < n_outputs; ++idx ) {
 			// sample ith token
 			auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx);
 			// store token
-			output_tokens.emplace_back(t);
+			output_tokens[idx] = t;
 		}
 
 		if ( verbose ) print_tokens( output_tokens );
+
+		llama_sampler_free(smpl);
 	}
 
 	const auto t_main_end = ggml_time_us();
@@ -738,15 +759,12 @@ int main( int argc, char** argv ) {
 	// initialize the sampler
 	auto sparams = llama_sampler_chain_default_params();
 	sparams.no_perf = false;
-	llama_sampler * smpl_ar = llama_sampler_chain_init(sparams);
-	llama_sampler * smpl_nar = llama_sampler_chain_init(sparams);
+	llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
-	llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_k(0));
-	llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_p(1.0, 1));
-	llama_sampler_chain_add(smpl_ar, llama_sampler_init_temp (1.0));
-	llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130));
-
-	llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy());
+	llama_sampler_chain_add(smpl, llama_sampler_init_top_k(0));
+	llama_sampler_chain_add(smpl, llama_sampler_init_top_p(1.0, 1));
+	llama_sampler_chain_add(smpl, llama_sampler_init_temp (1.0));
+	llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
 
 	struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl);
 	if (!ectx) {
@@ -790,10 +808,10 @@ int main( int argc, char** argv ) {
 	// NAR-len demasking
 	if ( modality == MODALITY_NAR_LEN ) {
 		// inference len
-		int len = 75;
+		int len = 0;
 		if ( !len ) {
 			input.task = "len";
-			output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, INFERENCE_MODE_LEN );
+			output_tokens = generate( ctx, model, smpl, input, io_map, 5, INFERENCE_MODE_LEN );
 			{
 				int digit = 1;
 				for (int i = output_tokens.size() - 1; i >= 0; i--) {
@@ -815,7 +833,8 @@ int main( int argc, char** argv ) {
 		input.task = "tts";
 		for ( auto l = 0; l < 8; ++l ) {
 			input.rvq_l = l;
-			output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
+			output_tokens = generate( ctx, model, smpl, input, io_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
+			if ( l == 0 ) input.resp.clear();
 			input.resp.emplace_back( output_tokens );
 		}
 	// AR+NAR
@@ -823,7 +842,7 @@ int main( int argc, char** argv ) {
 		input.task = "tts";
 		for ( auto l = 0; l < 8; ++l ) {
 			input.rvq_l = l;
-			output_tokens = generate( ctx, model, l == 0 ? smpl_ar : smpl_nar, input, io_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
+			output_tokens = generate( ctx, model, smpl, input, io_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
 			input.resp.emplace_back( output_tokens );
 		}
 	}
@@ -835,8 +854,7 @@ int main( int argc, char** argv ) {
 
 	// cleanup
 	encodec_free(ectx);
 
-	llama_sampler_free(smpl_nar);
-	llama_sampler_free(smpl_ar);
+	llama_sampler_free(smpl);
 
 	llama_free(ctx);
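
A note on the `soft_max` hunk: switching `exp` to `expf` keeps the math in single precision, but the neighboring `// to-do: assert denom != 0.0f` points at the real hazard, and a max-subtracted softmax removes it. A minimal standalone sketch, not part of the patch (the `soft_max_stable` name is invented here):

#include <cmath>
#include <vector>
#include <algorithm>

// Numerically stable softmax: subtracting the max logit keeps expf() from
// overflowing, and the max term contributes expf(0) == 1, so the denominator
// can never vanish. Assumes n_logits > 0.
static std::vector<float> soft_max_stable( int n_logits, const float* logits ) {
	float max_logit = *std::max_element( logits, logits + n_logits );
	std::vector<float> probs( n_logits );
	float denom = 0.0f;
	for ( int i = 0; i < n_logits; ++i ) {
		probs[i] = expf( logits[i] - max_logit );
		denom += probs[i];
	}
	for ( int i = 0; i < n_logits; ++i ) {
		probs[i] /= denom;
	}
	return probs;
}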
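The `remask_p` hunk drives the NAR-len demasking budget: a cosine term that decays as `timestep` runs from 0 to 1, plus a flat remask term spread across all `steps`, with the new clamp capping the budget at `n_outputs - step` instead of `n_outputs` so each successive step leaves more positions finalized. The same arithmetic as a pure function, sketched under those assumptions (the `masked_token_budget` name and the guard keeping the upper bound at least 1 are additions here):

#include <cmath>
#include <cstdint>
#include <algorithm>

// Tokens to keep masked at a given demasking step, per the patch's schedule:
// cos(t * pi/2) decays from 1 to 0 over the schedule, remask_p re-masks a
// small constant fraction, and the result is clamped to [1, n_outputs - step].
static int32_t masked_token_budget( float timestep, int step, int steps, int32_t seq_len, int32_t n_outputs ) {
	const float PI = 3.14159265358979323846f;
	float noise_p  = cosf( timestep * PI * 0.5f );
	float remask_p = 0.5f / steps;
	int32_t n  = int32_t( (noise_p + remask_p) * seq_len );
	int32_t hi = std::max<int32_t>( n_outputs - step, 1 ); // keep the clamp well-formed on late steps
	return std::min( std::max<int32_t>( n, 1 ), hi );
}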
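The per-token loop in the demasking hunk is the usual classifier-free-guidance blend: guided logits are the null (unconditioned) logits plus the conditioned-minus-null difference scaled by the guidance strength. Factored out as a standalone helper for reference (illustrative only; `blend_cfg` is not a name from the patch):

// CFG blend, matching the in-loop update
//   logit[i] = null_logit[i] + (logit[i] - null_logit[i]) * cfg_strength;
// strength == 1.0f reproduces the conditioned logits unchanged, 0.0f collapses
// to the null logits, and values > 1.0f push further away from the null pass.
static void blend_cfg( float* logit, const float* null_logit, int n_vocab, float strength ) {
	for ( int i = 0; i < n_vocab; ++i ) {
		logit[i] = null_logit[i] + ( logit[i] - null_logit[i] ) * strength;
	}
}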