This commit is contained in:
mrq 2024-12-23 23:42:44 -06:00
parent 532200de2a
commit 2b4d783299
2 changed files with 57 additions and 39 deletions

View File

@ -33,7 +33,7 @@ Run `make`.
* [x] load audio from disk * [x] load audio from disk
* [x] encode audio * [x] encode audio
* [x] sum embeddings for the `prom` and prior `resp`s * [x] sum embeddings for the `prom` and prior `resp`s
* [ ] working `AR` output * [x] working `AR` output
* [x] `AR` sampling * [x] `AR` sampling
* [ ] working `NAR-len` output * [ ] working `NAR-len` output
* [x] `NAR-len` sampling * [x] `NAR-len` sampling

View File

@ -339,7 +339,7 @@ std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits ) {
float denom = 0.0f; float denom = 0.0f;
for ( auto i = 0; i < n_logits; ++i ) { for ( auto i = 0; i < n_logits; ++i ) {
expd[i] = exp( logits[i] ); expd[i] = expf( logits[i] );
denom += expd[i]; denom += expd[i];
} }
// to-do: assert denom != 0.0f // to-do: assert denom != 0.0f
@ -493,7 +493,11 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
// update model's output heads / causal mode // update model's output heads / causal mode
llama_set_output_head( model, io.head ); llama_set_output_head( model, io.head );
llama_set_causal_attn( ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0]) // to-do: figure this out......
{
llama_set_causal_attn( ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
// *const_cast<bool*>(&model->hparams.causal_attn) = true; // force set this
}
std::vector<llama_token> output_tokens; std::vector<llama_token> output_tokens;
const auto t_main_start = ggml_time_us(); const auto t_main_start = ggml_time_us();
@ -557,14 +561,14 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
float sampling_cfg_strength = timestep * cfg_strength; float sampling_cfg_strength = timestep * cfg_strength;
float noise_p = cos( timestep * PI * 0.5f ); float noise_p = cos( timestep * PI * 0.5f );
float remask_p = 0.0f; // 0.5f / steps; float remask_p = 0.5f / steps;
int32_t n_masked_tokens = (noise_p + remask_p) * seq_len; int32_t n_masked_tokens = (noise_p + remask_p) * seq_len;
if ( n_masked_tokens < 1 ) { if ( n_masked_tokens < 1 ) {
n_masked_tokens = 1; n_masked_tokens = 1;
} }
if ( n_masked_tokens > n_outputs ) { if ( n_masked_tokens > (n_outputs - step) ) {
n_masked_tokens = n_outputs; n_masked_tokens = (n_outputs - step);
} }
// masked mask // masked mask
@ -582,7 +586,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
is_masked[idx] = true; is_masked[idx] = true;
} }
if ( verbose ) print_tokens( output_tokens, "Masked tokens:" ); if ( verbose ) print_tokens( output_tokens, "Masked tokens: " );
// update batch // update batch
// to-do: only update the embeddings instead // to-do: only update the embeddings instead
@ -602,10 +606,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
llama_kv_cache_clear(ctx); // necessary for many reasons llama_kv_cache_clear(ctx); // necessary for many reasons
// copy null probabilities // copy null probabilities
std::vector<float> null_logits(n_outputs * n_vocab, 0.0f); std::vector<float> null_logits(n_outputs * n_vocab, 0.0f);
// to-do: copy once memcpy( null_logits.data(), llama_get_logits( ctx ), sizeof(float) * n_vocab * n_outputs );
for ( auto idx = 0; idx < n_outputs; ++idx ) {
memcpy( &null_logits[idx * n_vocab], llama_get_logits_ith( ctx, null_batch.n_tokens - n_outputs + idx ), sizeof(float) * n_vocab );
}
// decode // decode
if ( llama_decode(ctx, batch) ) { if ( llama_decode(ctx, batch) ) {
@ -623,23 +624,32 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
llama_sampler_chain_add(smpl, llama_sampler_init_temp (sampling_temperature)); llama_sampler_chain_add(smpl, llama_sampler_init_temp (sampling_temperature));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130)); llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
// to-do: figure out why all logits are the same for each token...... auto* logits = llama_get_logits( ctx );
// "reverse" iterate from backwards indexing
/*
// perform CFG sampling
for ( auto i = 0; i < n_vocab * n_outputs; ++i ) {
logits[i] = null_logit[i] + (logits[i] - null_logit[i]) * cfg_strength;
}
*/
for ( auto idx = 0; idx < n_outputs; ++idx ) { for ( auto idx = 0; idx < n_outputs; ++idx ) {
// skip if not masked // skip if not masked
if ( !is_masked[idx] ) if ( !is_masked[idx] ) {
scores[idx] = 1.0f;
continue; continue;
// ensures only tokens within our designated range are used }
auto* logits = llama_get_logits_ith( ctx, batch.n_tokens - n_outputs + idx );
auto* null_logit = &null_logits[idx]; auto* logit = &logits[idx * n_vocab];
auto* null_logit = &null_logits[idx * n_vocab];
// perform softmax before modifying logits // perform softmax before modifying logits
std::vector<float> softmaxed = soft_max( n_vocab, logits ); std::vector<float> softmaxed = soft_max( n_vocab, logit );
// perform CFG sampling
for ( auto i = 0; i < n_vocab; ++i ) { for ( auto i = 0; i < n_vocab; ++i ) {
logits[i] = null_logit[i] + (logits[i] - null_logit[i]) * cfg_strength; logit[i] = null_logit[i] + (logit[i] - null_logit[i]) * cfg_strength;
} }
// sample ith token // sample ith token
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx ); auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx );
// store token if it was masked // store token if it was masked
@ -654,23 +664,34 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
} }
} else if ( mode == INFERENCE_MODE_NAR ) { } else if ( mode == INFERENCE_MODE_NAR ) {
// to-do: assert n_outputs == input.resp[rvq_l-1].size() // to-do: assert n_outputs == input.resp[rvq_l-1].size()
output_tokens.reserve(n_outputs); output_tokens.clear();
output_tokens.resize(n_outputs);
// do one step on many tokens // do one step on many tokens
if ( llama_decode(ctx, batch) ) { if ( llama_decode(ctx, batch) ) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return output_tokens; return output_tokens;
} }
llama_kv_cache_clear(ctx); // necessary for many reasons llama_kv_cache_clear(ctx); // necessary for many reasons
// to-do: figure out why all logits are the same for each token......
// "reverse" iterate from backwards indexing auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(1));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(1.0, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (1.0));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
for ( auto idx = 0; idx < n_outputs; ++idx ) { for ( auto idx = 0; idx < n_outputs; ++idx ) {
// sample ith token // sample ith token
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx); auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx);
// store token // store token
output_tokens.emplace_back(t); output_tokens[idx] = t;
} }
if ( verbose ) print_tokens( output_tokens ); if ( verbose ) print_tokens( output_tokens );
llama_sampler_free(smpl);
} }
const auto t_main_end = ggml_time_us(); const auto t_main_end = ggml_time_us();
@ -738,15 +759,12 @@ int main( int argc, char** argv ) {
// initialize the sampler // initialize the sampler
auto sparams = llama_sampler_chain_default_params(); auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false; sparams.no_perf = false;
llama_sampler * smpl_ar = llama_sampler_chain_init(sparams); llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler * smpl_nar = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_k(0)); llama_sampler_chain_add(smpl, llama_sampler_init_top_k(0));
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_p(1.0, 1)); llama_sampler_chain_add(smpl, llama_sampler_init_top_p(1.0, 1));
llama_sampler_chain_add(smpl_ar, llama_sampler_init_temp (1.0)); llama_sampler_chain_add(smpl, llama_sampler_init_temp (1.0));
llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130)); llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy());
struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl); struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl);
if (!ectx) { if (!ectx) {
@ -790,10 +808,10 @@ int main( int argc, char** argv ) {
// NAR-len demasking // NAR-len demasking
if ( modality == MODALITY_NAR_LEN ) { if ( modality == MODALITY_NAR_LEN ) {
// inference len // inference len
int len = 75; int len = 0;
if ( !len ) { if ( !len ) {
input.task = "len"; input.task = "len";
output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, INFERENCE_MODE_LEN ); output_tokens = generate( ctx, model, smpl, input, io_map, 5, INFERENCE_MODE_LEN );
{ {
int digit = 1; int digit = 1;
for (int i = output_tokens.size() - 1; i >= 0; i--) { for (int i = output_tokens.size() - 1; i >= 0; i--) {
@ -815,7 +833,8 @@ int main( int argc, char** argv ) {
input.task = "tts"; input.task = "tts";
for ( auto l = 0; l < 8; ++l ) { for ( auto l = 0; l < 8; ++l ) {
input.rvq_l = l; input.rvq_l = l;
output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR ); output_tokens = generate( ctx, model, smpl, input, io_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
if ( l == 0 ) input.resp.clear();
input.resp.emplace_back( output_tokens ); input.resp.emplace_back( output_tokens );
} }
// AR+NAR // AR+NAR
@ -823,7 +842,7 @@ int main( int argc, char** argv ) {
input.task = "tts"; input.task = "tts";
for ( auto l = 0; l < 8; ++l ) { for ( auto l = 0; l < 8; ++l ) {
input.rvq_l = l; input.rvq_l = l;
output_tokens = generate( ctx, model, l == 0 ? smpl_ar : smpl_nar, input, io_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR ); output_tokens = generate( ctx, model, smpl, input, io_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
input.resp.emplace_back( output_tokens ); input.resp.emplace_back( output_tokens );
} }
} }
@ -835,8 +854,7 @@ int main( int argc, char** argv ) {
// cleanup // cleanup
encodec_free(ectx); encodec_free(ectx);
llama_sampler_free(smpl_nar); llama_sampler_free(smpl);
llama_sampler_free(smpl_ar);
llama_free(ctx); llama_free(ctx);