This commit is contained in:
mrq 2024-12-23 23:42:44 -06:00
parent 532200de2a
commit 2b4d783299
2 changed files with 57 additions and 39 deletions

View File

@ -33,7 +33,7 @@ Run `make`.
* [x] load audio from disk
* [x] encode audio
* [x] sum embeddings for the `prom` and prior `resp`s
* [ ] working `AR` output
* [x] working `AR` output
* [x] `AR` sampling
* [ ] working `NAR-len` output
* [x] `NAR-len` sampling
@ -41,4 +41,4 @@ Run `make`.
* [x] `NAR` sampling
* [x] decode audio to disk
* [ ] a functional CLI
* [ ] actually make it work
* [ ] actually make it work

View File

@ -339,7 +339,7 @@ std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits ) {
float denom = 0.0f;
for ( auto i = 0; i < n_logits; ++i ) {
expd[i] = exp( logits[i] );
expd[i] = expf( logits[i] );
denom += expd[i];
}
// to-do: assert denom != 0.0f
@ -493,7 +493,11 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
// update model's output heads / causal mode
llama_set_output_head( model, io.head );
llama_set_causal_attn( ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
// to-do: figure this out......
{
llama_set_causal_attn( ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
// *const_cast<bool*>(&model->hparams.causal_attn) = true; // force set this
}
std::vector<llama_token> output_tokens;
const auto t_main_start = ggml_time_us();
@ -557,14 +561,14 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
float sampling_cfg_strength = timestep * cfg_strength;
float noise_p = cos( timestep * PI * 0.5f );
float remask_p = 0.0f; // 0.5f / steps;
float remask_p = 0.5f / steps;
int32_t n_masked_tokens = (noise_p + remask_p) * seq_len;
if ( n_masked_tokens < 1 ) {
n_masked_tokens = 1;
}
if ( n_masked_tokens > n_outputs ) {
n_masked_tokens = n_outputs;
if ( n_masked_tokens > (n_outputs - step) ) {
n_masked_tokens = (n_outputs - step);
}
// masked mask
@ -582,7 +586,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
is_masked[idx] = true;
}
if ( verbose ) print_tokens( output_tokens, "Masked tokens:" );
if ( verbose ) print_tokens( output_tokens, "Masked tokens: " );
// update batch
// to-do: only update the embeddings instead
@ -602,10 +606,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
llama_kv_cache_clear(ctx); // necessary for many reasons
// copy null probabilities
std::vector<float> null_logits(n_outputs * n_vocab, 0.0f);
// to-do: copy once
for ( auto idx = 0; idx < n_outputs; ++idx ) {
memcpy( &null_logits[idx * n_vocab], llama_get_logits_ith( ctx, null_batch.n_tokens - n_outputs + idx ), sizeof(float) * n_vocab );
}
memcpy( null_logits.data(), llama_get_logits( ctx ), sizeof(float) * n_vocab * n_outputs );
// decode
if ( llama_decode(ctx, batch) ) {
@ -623,23 +624,32 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
llama_sampler_chain_add(smpl, llama_sampler_init_temp (sampling_temperature));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
// to-do: figure out why all logits are the same for each token......
// "reverse" iterate from backwards indexing
auto* logits = llama_get_logits( ctx );
/*
// perform CFG sampling
for ( auto i = 0; i < n_vocab * n_outputs; ++i ) {
logits[i] = null_logit[i] + (logits[i] - null_logit[i]) * cfg_strength;
}
*/
for ( auto idx = 0; idx < n_outputs; ++idx ) {
// skip if not masked
if ( !is_masked[idx] )
if ( !is_masked[idx] ) {
scores[idx] = 1.0f;
continue;
// ensures only tokens within our designated range are used
auto* logits = llama_get_logits_ith( ctx, batch.n_tokens - n_outputs + idx );
auto* null_logit = &null_logits[idx];
}
auto* logit = &logits[idx * n_vocab];
auto* null_logit = &null_logits[idx * n_vocab];
// perform softmax before modifying logits
std::vector<float> softmaxed = soft_max( n_vocab, logits );
std::vector<float> softmaxed = soft_max( n_vocab, logit );
// perform CFG sampling
for ( auto i = 0; i < n_vocab; ++i ) {
logits[i] = null_logit[i] + (logits[i] - null_logit[i]) * cfg_strength;
logit[i] = null_logit[i] + (logit[i] - null_logit[i]) * cfg_strength;
}
// sample ith token
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx );
// store token if it was masked
@ -654,23 +664,34 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
}
} else if ( mode == INFERENCE_MODE_NAR ) {
// to-do: assert n_outputs == input.resp[rvq_l-1].size()
output_tokens.reserve(n_outputs);
output_tokens.clear();
output_tokens.resize(n_outputs);
// do one step on many tokens
if ( llama_decode(ctx, batch) ) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return output_tokens;
}
llama_kv_cache_clear(ctx); // necessary for many reasons
// to-do: figure out why all logits are the same for each token......
// "reverse" iterate from backwards indexing
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(1));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(1.0, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (1.0));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
for ( auto idx = 0; idx < n_outputs; ++idx ) {
// sample ith token
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx);
// store token
output_tokens.emplace_back(t);
output_tokens[idx] = t;
}
if ( verbose ) print_tokens( output_tokens );
llama_sampler_free(smpl);
}
const auto t_main_end = ggml_time_us();
@ -738,15 +759,12 @@ int main( int argc, char** argv ) {
// initialize the sampler
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
llama_sampler * smpl_ar = llama_sampler_chain_init(sparams);
llama_sampler * smpl_nar = llama_sampler_chain_init(sparams);
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_k(0));
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_p(1.0, 1));
llama_sampler_chain_add(smpl_ar, llama_sampler_init_temp (1.0));
llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130));
llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy());
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(0));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(1.0, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (1.0));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl);
if (!ectx) {
@ -790,10 +808,10 @@ int main( int argc, char** argv ) {
// NAR-len demasking
if ( modality == MODALITY_NAR_LEN ) {
// inference len
int len = 75;
int len = 0;
if ( !len ) {
input.task = "len";
output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, INFERENCE_MODE_LEN );
output_tokens = generate( ctx, model, smpl, input, io_map, 5, INFERENCE_MODE_LEN );
{
int digit = 1;
for (int i = output_tokens.size() - 1; i >= 0; i--) {
@ -815,7 +833,8 @@ int main( int argc, char** argv ) {
input.task = "tts";
for ( auto l = 0; l < 8; ++l ) {
input.rvq_l = l;
output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
output_tokens = generate( ctx, model, smpl, input, io_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
if ( l == 0 ) input.resp.clear();
input.resp.emplace_back( output_tokens );
}
// AR+NAR
@ -823,7 +842,7 @@ int main( int argc, char** argv ) {
input.task = "tts";
for ( auto l = 0; l < 8; ++l ) {
input.rvq_l = l;
output_tokens = generate( ctx, model, l == 0 ? smpl_ar : smpl_nar, input, io_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
output_tokens = generate( ctx, model, smpl, input, io_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
input.resp.emplace_back( output_tokens );
}
}
@ -835,8 +854,7 @@ int main( int argc, char** argv ) {
// cleanup
encodec_free(ectx);
llama_sampler_free(smpl_nar);
llama_sampler_free(smpl_ar);
llama_sampler_free(smpl);
llama_free(ctx);