mrq 2025-04-05 18:20:46 -05:00
parent 4a909ceff8
commit b6692ce3de
2 changed files with 17 additions and 24 deletions

View File

@@ -386,32 +386,22 @@ std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<to
 }
 std::vector<float> soft_max( int n_logits, const float* logits ) {
-	std::vector<float> res( n_logits, 0.0f );
-	std::vector<float> expd( n_logits, 0.0f );
+	std::vector<float> res(n_logits, 0.0f);
 	float denom = 0.0f;
-	for ( auto i = 0; i < n_logits; ++i ) {
-		expd[i] = expf( logits[i] );
-		denom += expd[i];
-	}
-	// to-do: assert denom != 0.0f
-	for ( auto i = 0; i < n_logits; ++i ) {
-		res[i] = expd[i] / denom;
-	}
-	return res;
-}
-
-std::vector<float> log_soft_max( int n_logits, const float* logits ) {
-	std::vector<float> res( n_logits, 0.0f );
-	float denom = 0.0f;
-	for ( auto i = 0; i < n_logits; ++i ) {
-		denom += logits[i];
-	}
-	// to-do: assert denom != 0.0f
-	for ( auto i = 0; i < n_logits; ++i ) {
-		res[i] = logits[i] / denom;
-	}
+	float max_logit = logits[0];
+	for (int i = 1; i < n_logits; ++i) {
+		max_logit = std::max(max_logit, logits[i]);
+	}
+	for (int i = 0; i < n_logits; ++i) {
+		res[i] = std::exp(logits[i] - max_logit);
+		denom += res[i];
+	}
+	float inv_denom = 1.0f / denom;
+	for (int i = 0; i < n_logits; ++i) {
+		res[i] *= inv_denom;
+	}
 	return res;
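
The rewrite above is the standard numerically stable softmax: softmax is shift-invariant, so subtracting the largest logit before exponentiating keeps every exp() argument at or below zero, avoiding float overflow (a naive expf(logits[i]) overflows once a logit passes roughly 88). The removed log_soft_max divided raw logits by their plain sum, which computes neither a softmax nor its log, presumably why it was dropped rather than fixed. A minimal standalone sketch of the stable form; the main() driver here is illustrative, not part of the commit:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Shift-invariant softmax: softmax(x) == softmax(x - c) for any constant c,
// so subtracting max(x) changes nothing but keeps exp() from overflowing.
std::vector<float> soft_max( int n_logits, const float* logits ) {
	std::vector<float> res( n_logits, 0.0f );
	float max_logit = logits[0];
	for ( int i = 1; i < n_logits; ++i ) max_logit = std::max( max_logit, logits[i] );
	float denom = 0.0f;
	for ( int i = 0; i < n_logits; ++i ) {
		res[i] = std::exp( logits[i] - max_logit );
		denom += res[i];
	}
	// multiply by the reciprocal once instead of dividing n_logits times
	float inv_denom = 1.0f / denom;
	for ( int i = 0; i < n_logits; ++i ) res[i] *= inv_denom;
	return res;
}

int main() {
	// logits this large overflow the old expf() path; the shifted form is fine
	std::vector<float> logits = { 1000.0f, 1001.0f, 1002.0f };
	auto probs = soft_max( (int)logits.size(), logits.data() );
	for ( auto p : probs ) std::printf( "%f\n", p ); // ~0.090, ~0.245, ~0.665
}
```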
@@ -661,7 +651,7 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
 	std::vector<score_t> sorted_scores( n_outputs );
 	for ( auto i = 0; i < n_outputs; ++i ) sorted_scores[i] = { i, scores[i] };
 	std::sort(sorted_scores.begin(), sorted_scores.end());
-	std::reverse(sorted_scores.begin(), sorted_scores.end());
+	// std::reverse(sorted_scores.begin(), sorted_scores.end());
 
 	// and top-k pick the worst scores
 	for ( auto i = 0; i < n_masked_tokens; ++i ) {
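
With std::reverse gone, sorted_scores stays in ascending order, so the top-k loop now walks the k lowest-scoring entries rather than the k highest; under the 1.0 - p scoring introduced later in this commit, lowest score means highest sampled-token probability. A sketch of the selection as it behaves after this change, assuming score_t is an index/value pair ordered by value (the struct isn't shown in this hunk):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// hypothetical stand-in for the repo's score_t, assumed to order by value
struct score_t {
	int   idx;
	float value;
	bool operator<( const score_t& other ) const { return value < other.value; }
};

int main() {
	std::vector<float> scores = { 0.9f, 0.1f, 0.5f, 0.3f };
	int n_outputs = (int)scores.size();
	std::vector<score_t> sorted_scores( n_outputs );
	for ( int i = 0; i < n_outputs; ++i ) sorted_scores[i] = { i, scores[i] };
	std::sort( sorted_scores.begin(), sorted_scores.end() ); // ascending, no reverse
	int n_masked_tokens = 2;
	for ( int i = 0; i < n_masked_tokens; ++i )
		std::printf( "pick idx %d (score %.1f)\n", sorted_scores[i].idx, sorted_scores[i].value );
	// picks idx 1 (0.1) then idx 3 (0.3): the two lowest scores
}
```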
@@ -717,6 +707,7 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
 	for ( auto idx = 0; idx < n_outputs; ++idx ) {
 		// skip if not masked
 		if ( !is_masked[idx] ) {
+			scores[idx] = 0.0;
 			continue;
 		}
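
The added assignment pins unmasked positions to a known score instead of letting them keep whatever value an earlier demasking iteration left behind; presumably this keeps stale scores from leaking into the sort above. A toy pass showing the reset-on-skip pattern:

```cpp
#include <cstdio>
#include <vector>

int main() {
	// toy demasking pass: only masked positions get rescored
	std::vector<bool>  is_masked = { true, false, true };
	std::vector<float> scores    = { 0.7f, 0.9f, 0.2f }; // leftovers from an earlier step
	for ( size_t idx = 0; idx < scores.size(); ++idx ) {
		if ( !is_masked[idx] ) {
			scores[idx] = 0.0f; // reset, otherwise the stale 0.9 survives into the next sort
			continue;
		}
		scores[idx] = 0.5f; // stand-in for the real confidence computation
	}
	for ( auto s : scores ) std::printf( "%.1f ", s ); // 0.5 0.0 0.5
	std::printf( "\n" );
}
```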
@@ -750,15 +741,17 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
 		// update score if it was masked
 		// this is actually wrong
-		// scores[idx] = 1.0f - softmaxed[t]; // invert so we pick the worst tokens later
+		scores[idx] = 1.0 - softmaxed[t]; // invert so we pick the worst tokens later
 		// this seems to work better
+		/*
 		float entropy = 0.f;
 		for (int v = 0; v < n_vocab; ++v ) {
 			float p = softmaxed[v];
 			if (p > 0) entropy -= p * std::log(p + 1e-9);
 		}
 		scores[idx] = entropy / std::log(n_vocab); // normalize [0..1]
+		*/
 	}
 	llama_sampler_free(smpl);
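
This flips the per-position confidence score back from normalized entropy to 1 - p(sampled token). Both land in [0, 1] and rise as the model gets less certain, but 1 - p reads only the probability of the token actually sampled, while the entropy version summarizes the whole distribution. A side-by-side sketch with made-up distributions:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// score used after this commit: one minus the probability assigned
// to the token the sampler actually picked
float score_sampled( const std::vector<float>& probs, int sampled ) {
	return 1.0f - probs[sampled];
}

// score this commit comments out: Shannon entropy of the distribution,
// divided by log(n_vocab) so it lands in [0, 1]
float score_entropy( const std::vector<float>& probs ) {
	float entropy = 0.0f;
	for ( float p : probs )
		if ( p > 0.0f ) entropy -= p * std::log( p + 1e-9f );
	return entropy / std::log( (float)probs.size() );
}

int main() {
	std::vector<float> peaked = { 0.90f, 0.05f, 0.05f };
	std::vector<float> flat   = { 0.34f, 0.33f, 0.33f };
	std::printf( "1-p:     %.3f vs %.3f\n", score_sampled( peaked, 0 ), score_sampled( flat, 0 ) );
	std::printf( "entropy: %.3f vs %.3f\n", score_entropy( peaked ),   score_entropy( flat ) );
	// the peaked (confident) distribution scores lower on both measures
}
```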

View File

@@ -157,7 +157,7 @@ def _load_model(device="cuda", backend=None, dtype=None):
 	if not backend:
 		backend = cfg.audio_backend
 
-	if ERRORED_BACKENDS[backend]:
+	if ERRORED_BACKENDS.get(backend, None):
 		raise ERRORED_BACKENDS[backend]
 
 	if cfg.inference.amp:
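
On the Python side, dict.get(backend, None) returns a falsy None when the backend has no recorded error (or isn't tracked at all), where the old subscript raised KeyError; the raise line below it can keep the plain subscript, since it only runs once the key is known to exist. The same defensive-lookup pattern expressed in C++ for comparison; the names and backend strings here are illustrative:

```cpp
#include <cstdio>
#include <map>
#include <string>

int main() {
	// analogue of ERRORED_BACKENDS.get(backend, None): find() never throws,
	// unlike at(), which mirrors a raising dict subscript
	std::map<std::string, std::string> errored_backends = { { "dac", "import failed" } };
	for ( std::string backend : { "dac", "vocos" } ) {
		auto it = errored_backends.find( backend );
		if ( it != errored_backends.end() )
			std::printf( "%s previously errored: %s\n", backend.c_str(), it->second.c_str() );
		else
			std::printf( "%s loads fine\n", backend.c_str() );
	}
}
```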