From b6692ce3de4a2e2e431788db6cb2b4cca721df4a Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 5 Apr 2025 18:20:46 -0500 Subject: [PATCH] ugh --- vall_e.cpp/vall_e.cpp | 39 ++++++++++++++++----------------------- vall_e/emb/qnt.py | 2 +- 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/vall_e.cpp/vall_e.cpp b/vall_e.cpp/vall_e.cpp index a824544..db539a3 100644 --- a/vall_e.cpp/vall_e.cpp +++ b/vall_e.cpp/vall_e.cpp @@ -386,32 +386,22 @@ std::vector> sum_embeddings( const std::vector soft_max( int n_logits, const float* logits ) { - std::vector res( n_logits, 0.0f ); - std::vector expd( n_logits, 0.0f ); + std::vector res(n_logits, 0.0f); float denom = 0.0f; - for ( auto i = 0; i < n_logits; ++i ) { - expd[i] = expf( logits[i] ); - denom += expd[i]; - } - // to-do: assert denom != 0.0f - for ( auto i = 0; i < n_logits; ++i ) { - res[i] = expd[i] / denom; + float max_logit = logits[0]; + for (int i = 1; i < n_logits; ++i) { + max_logit = std::max(max_logit, logits[i]); } - return res; -} - -std::vector log_soft_max( int n_logits, const float* logits ) { - std::vector res( n_logits, 0.0f ); - float denom = 0.0f; - - for ( auto i = 0; i < n_logits; ++i ) { - denom += logits[i]; + for (int i = 0; i < n_logits; ++i) { + res[i] = std::exp(logits[i] - max_logit); + denom += res[i]; } - // to-do: assert denom != 0.0f - for ( auto i = 0; i < n_logits; ++i ) { - res[i] = logits[i] / denom; + + float inv_denom = 1.0f / denom; + for (int i = 0; i < n_logits; ++i) { + res[i] *= inv_denom; } return res; @@ -661,7 +651,7 @@ std::vector generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i std::vector sorted_scores( n_outputs ); for ( auto i = 0; i < n_outputs; ++i ) sorted_scores[i] = { i, scores[i] }; std::sort(sorted_scores.begin(), sorted_scores.end()); - std::reverse(sorted_scores.begin(), sorted_scores.end()); + // std::reverse(sorted_scores.begin(), sorted_scores.end()); // and top-k pick the worst scores for ( auto i = 0; i < n_masked_tokens; ++i ) { @@ -717,6 +707,7 @@ std::vector generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i for ( auto idx = 0; idx < n_outputs; ++idx ) { // skip if not masked if ( !is_masked[idx] ) { + scores[idx] = 0.0; continue; } @@ -750,15 +741,17 @@ std::vector generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i // update score if it was masked // this is actually wrong - // scores[idx] = 1.0f - softmaxed[t]; // invert so we pick the worst tokens later + scores[idx] = 1.0 - softmaxed[t]; // invert so we pick the worst tokens later // this seems to work better + /* float entropy = 0.f; for (int v = 0; v < n_vocab; ++v ) { float p = softmaxed[v]; if (p > 0) entropy -= p * std::log(p + 1e-9); } scores[idx] = entropy / std::log(n_vocab); // normalize [0–1] + */ } llama_sampler_free(smpl); diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 0248fea..0d7af5e 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -157,7 +157,7 @@ def _load_model(device="cuda", backend=None, dtype=None): if not backend: backend = cfg.audio_backend - if ERRORED_BACKENDS[backend]: + if ERRORED_BACKENDS.get(backend, None): raise ERRORED_BACKENDS[backend] if cfg.inference.amp: