ugh
This commit is contained in:
parent
532200de2a
commit
2b4d783299
|
@ -33,7 +33,7 @@ Run `make`.
|
|||
* [x] load audio from disk
|
||||
* [x] encode audio
|
||||
* [x] sum embeddings for the `prom` and prior `resp`s
|
||||
* [ ] working `AR` output
|
||||
* [x] working `AR` output
|
||||
* [x] `AR` sampling
|
||||
* [ ] working `NAR-len` output
|
||||
* [x] `NAR-len` sampling
|
||||
|
|
|
@ -339,7 +339,7 @@ std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits ) {
|
|||
float denom = 0.0f;
|
||||
|
||||
for ( auto i = 0; i < n_logits; ++i ) {
|
||||
expd[i] = exp( logits[i] );
|
||||
expd[i] = expf( logits[i] );
|
||||
denom += expd[i];
|
||||
}
|
||||
// to-do: assert denom != 0.0f
|
||||
|
@ -493,7 +493,11 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
|
|||
|
||||
// update model's output heads / causal mode
|
||||
llama_set_output_head( model, io.head );
|
||||
llama_set_causal_attn( ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
|
||||
// to-do: figure this out......
|
||||
{
|
||||
llama_set_causal_attn( ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
|
||||
// *const_cast<bool*>(&model->hparams.causal_attn) = true; // force set this
|
||||
}
|
||||
|
||||
std::vector<llama_token> output_tokens;
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
@ -557,14 +561,14 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
|
|||
float sampling_cfg_strength = timestep * cfg_strength;
|
||||
|
||||
float noise_p = cos( timestep * PI * 0.5f );
|
||||
float remask_p = 0.0f; // 0.5f / steps;
|
||||
float remask_p = 0.5f / steps;
|
||||
|
||||
int32_t n_masked_tokens = (noise_p + remask_p) * seq_len;
|
||||
if ( n_masked_tokens < 1 ) {
|
||||
n_masked_tokens = 1;
|
||||
}
|
||||
if ( n_masked_tokens > n_outputs ) {
|
||||
n_masked_tokens = n_outputs;
|
||||
if ( n_masked_tokens > (n_outputs - step) ) {
|
||||
n_masked_tokens = (n_outputs - step);
|
||||
}
|
||||
|
||||
// masked mask
|
||||
|
@ -582,7 +586,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
|
|||
is_masked[idx] = true;
|
||||
}
|
||||
|
||||
if ( verbose ) print_tokens( output_tokens, "Masked tokens:" );
|
||||
if ( verbose ) print_tokens( output_tokens, "Masked tokens: " );
|
||||
|
||||
// update batch
|
||||
// to-do: only update the embeddings instead
|
||||
|
@ -602,10 +606,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
|
|||
llama_kv_cache_clear(ctx); // necessary for many reasons
|
||||
// copy null probabilities
|
||||
std::vector<float> null_logits(n_outputs * n_vocab, 0.0f);
|
||||
// to-do: copy once
|
||||
for ( auto idx = 0; idx < n_outputs; ++idx ) {
|
||||
memcpy( &null_logits[idx * n_vocab], llama_get_logits_ith( ctx, null_batch.n_tokens - n_outputs + idx ), sizeof(float) * n_vocab );
|
||||
}
|
||||
memcpy( null_logits.data(), llama_get_logits( ctx ), sizeof(float) * n_vocab * n_outputs );
|
||||
|
||||
// decode
|
||||
if ( llama_decode(ctx, batch) ) {
|
||||
|
@ -623,23 +624,32 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
|
|||
llama_sampler_chain_add(smpl, llama_sampler_init_temp (sampling_temperature));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
|
||||
|
||||
// to-do: figure out why all logits are the same for each token......
|
||||
// "reverse" iterate from backwards indexing
|
||||
auto* logits = llama_get_logits( ctx );
|
||||
|
||||
/*
|
||||
// perform CFG sampling
|
||||
for ( auto i = 0; i < n_vocab * n_outputs; ++i ) {
|
||||
logits[i] = null_logit[i] + (logits[i] - null_logit[i]) * cfg_strength;
|
||||
}
|
||||
*/
|
||||
|
||||
for ( auto idx = 0; idx < n_outputs; ++idx ) {
|
||||
// skip if not masked
|
||||
if ( !is_masked[idx] )
|
||||
if ( !is_masked[idx] ) {
|
||||
scores[idx] = 1.0f;
|
||||
continue;
|
||||
// ensures only tokens within our designated range are used
|
||||
auto* logits = llama_get_logits_ith( ctx, batch.n_tokens - n_outputs + idx );
|
||||
auto* null_logit = &null_logits[idx];
|
||||
}
|
||||
|
||||
auto* logit = &logits[idx * n_vocab];
|
||||
auto* null_logit = &null_logits[idx * n_vocab];
|
||||
|
||||
// perform softmax before modifying logits
|
||||
std::vector<float> softmaxed = soft_max( n_vocab, logits );
|
||||
std::vector<float> softmaxed = soft_max( n_vocab, logit );
|
||||
|
||||
// perform CFG sampling
|
||||
for ( auto i = 0; i < n_vocab; ++i ) {
|
||||
logits[i] = null_logit[i] + (logits[i] - null_logit[i]) * cfg_strength;
|
||||
logit[i] = null_logit[i] + (logit[i] - null_logit[i]) * cfg_strength;
|
||||
}
|
||||
|
||||
// sample ith token
|
||||
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx );
|
||||
// store token if it was masked
|
||||
|
@ -654,23 +664,34 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
|
|||
}
|
||||
} else if ( mode == INFERENCE_MODE_NAR ) {
|
||||
// to-do: assert n_outputs == input.resp[rvq_l-1].size()
|
||||
output_tokens.reserve(n_outputs);
|
||||
output_tokens.clear();
|
||||
output_tokens.resize(n_outputs);
|
||||
// do one step on many tokens
|
||||
if ( llama_decode(ctx, batch) ) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return output_tokens;
|
||||
}
|
||||
llama_kv_cache_clear(ctx); // necessary for many reasons
|
||||
// to-do: figure out why all logits are the same for each token......
|
||||
// "reverse" iterate from backwards indexing
|
||||
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
sparams.no_perf = false;
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(1));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(1.0, 1));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_temp (1.0));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
|
||||
|
||||
for ( auto idx = 0; idx < n_outputs; ++idx ) {
|
||||
// sample ith token
|
||||
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx);
|
||||
|
||||
// store token
|
||||
output_tokens.emplace_back(t);
|
||||
output_tokens[idx] = t;
|
||||
}
|
||||
if ( verbose ) print_tokens( output_tokens );
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
}
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
@ -738,15 +759,12 @@ int main( int argc, char** argv ) {
|
|||
// initialize the sampler
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
sparams.no_perf = false;
|
||||
llama_sampler * smpl_ar = llama_sampler_chain_init(sparams);
|
||||
llama_sampler * smpl_nar = llama_sampler_chain_init(sparams);
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_k(0));
|
||||
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_p(1.0, 1));
|
||||
llama_sampler_chain_add(smpl_ar, llama_sampler_init_temp (1.0));
|
||||
llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130));
|
||||
|
||||
llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy());
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(0));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(1.0, 1));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_temp (1.0));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
|
||||
|
||||
struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl);
|
||||
if (!ectx) {
|
||||
|
@ -790,10 +808,10 @@ int main( int argc, char** argv ) {
|
|||
// NAR-len demasking
|
||||
if ( modality == MODALITY_NAR_LEN ) {
|
||||
// inference len
|
||||
int len = 75;
|
||||
int len = 0;
|
||||
if ( !len ) {
|
||||
input.task = "len";
|
||||
output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, INFERENCE_MODE_LEN );
|
||||
output_tokens = generate( ctx, model, smpl, input, io_map, 5, INFERENCE_MODE_LEN );
|
||||
{
|
||||
int digit = 1;
|
||||
for (int i = output_tokens.size() - 1; i >= 0; i--) {
|
||||
|
@ -815,7 +833,8 @@ int main( int argc, char** argv ) {
|
|||
input.task = "tts";
|
||||
for ( auto l = 0; l < 8; ++l ) {
|
||||
input.rvq_l = l;
|
||||
output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
|
||||
output_tokens = generate( ctx, model, smpl, input, io_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
|
||||
if ( l == 0 ) input.resp.clear();
|
||||
input.resp.emplace_back( output_tokens );
|
||||
}
|
||||
// AR+NAR
|
||||
|
@ -823,7 +842,7 @@ int main( int argc, char** argv ) {
|
|||
input.task = "tts";
|
||||
for ( auto l = 0; l < 8; ++l ) {
|
||||
input.rvq_l = l;
|
||||
output_tokens = generate( ctx, model, l == 0 ? smpl_ar : smpl_nar, input, io_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
|
||||
output_tokens = generate( ctx, model, smpl, input, io_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
|
||||
input.resp.emplace_back( output_tokens );
|
||||
}
|
||||
}
|
||||
|
@ -835,8 +854,7 @@ int main( int argc, char** argv ) {
|
|||
// cleanup
|
||||
encodec_free(ectx);
|
||||
|
||||
llama_sampler_free(smpl_nar);
|
||||
llama_sampler_free(smpl_ar);
|
||||
llama_sampler_free(smpl);
|
||||
|
||||
llama_free(ctx);
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user