ugh
This commit is contained in:
parent
532200de2a
commit
2b4d783299
|
@ -33,7 +33,7 @@ Run `make`.
|
||||||
* [x] load audio from disk
|
* [x] load audio from disk
|
||||||
* [x] encode audio
|
* [x] encode audio
|
||||||
* [x] sum embeddings for the `prom` and prior `resp`s
|
* [x] sum embeddings for the `prom` and prior `resp`s
|
||||||
* [ ] working `AR` output
|
* [x] working `AR` output
|
||||||
* [x] `AR` sampling
|
* [x] `AR` sampling
|
||||||
* [ ] working `NAR-len` output
|
* [ ] working `NAR-len` output
|
||||||
* [x] `NAR-len` sampling
|
* [x] `NAR-len` sampling
|
||||||
|
|
|
@ -339,7 +339,7 @@ std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits ) {
|
||||||
float denom = 0.0f;
|
float denom = 0.0f;
|
||||||
|
|
||||||
for ( auto i = 0; i < n_logits; ++i ) {
|
for ( auto i = 0; i < n_logits; ++i ) {
|
||||||
expd[i] = exp( logits[i] );
|
expd[i] = expf( logits[i] );
|
||||||
denom += expd[i];
|
denom += expd[i];
|
||||||
}
|
}
|
||||||
// to-do: assert denom != 0.0f
|
// to-do: assert denom != 0.0f
|
||||||
|
@ -493,7 +493,11 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
|
||||||
|
|
||||||
// update model's output heads / causal mode
|
// update model's output heads / causal mode
|
||||||
llama_set_output_head( model, io.head );
|
llama_set_output_head( model, io.head );
|
||||||
llama_set_causal_attn( ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
|
// to-do: figure this out......
|
||||||
|
{
|
||||||
|
llama_set_causal_attn( ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
|
||||||
|
// *const_cast<bool*>(&model->hparams.causal_attn) = true; // force set this
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<llama_token> output_tokens;
|
std::vector<llama_token> output_tokens;
|
||||||
const auto t_main_start = ggml_time_us();
|
const auto t_main_start = ggml_time_us();
|
||||||
|
@ -557,14 +561,14 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
|
||||||
float sampling_cfg_strength = timestep * cfg_strength;
|
float sampling_cfg_strength = timestep * cfg_strength;
|
||||||
|
|
||||||
float noise_p = cos( timestep * PI * 0.5f );
|
float noise_p = cos( timestep * PI * 0.5f );
|
||||||
float remask_p = 0.0f; // 0.5f / steps;
|
float remask_p = 0.5f / steps;
|
||||||
|
|
||||||
int32_t n_masked_tokens = (noise_p + remask_p) * seq_len;
|
int32_t n_masked_tokens = (noise_p + remask_p) * seq_len;
|
||||||
if ( n_masked_tokens < 1 ) {
|
if ( n_masked_tokens < 1 ) {
|
||||||
n_masked_tokens = 1;
|
n_masked_tokens = 1;
|
||||||
}
|
}
|
||||||
if ( n_masked_tokens > n_outputs ) {
|
if ( n_masked_tokens > (n_outputs - step) ) {
|
||||||
n_masked_tokens = n_outputs;
|
n_masked_tokens = (n_outputs - step);
|
||||||
}
|
}
|
||||||
|
|
||||||
// masked mask
|
// masked mask
|
||||||
|
@ -582,7 +586,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
|
||||||
is_masked[idx] = true;
|
is_masked[idx] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if ( verbose ) print_tokens( output_tokens, "Masked tokens:" );
|
if ( verbose ) print_tokens( output_tokens, "Masked tokens: " );
|
||||||
|
|
||||||
// update batch
|
// update batch
|
||||||
// to-do: only update the embeddings instead
|
// to-do: only update the embeddings instead
|
||||||
|
@ -602,10 +606,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
|
||||||
llama_kv_cache_clear(ctx); // necessary for many reasons
|
llama_kv_cache_clear(ctx); // necessary for many reasons
|
||||||
// copy null probabilities
|
// copy null probabilities
|
||||||
std::vector<float> null_logits(n_outputs * n_vocab, 0.0f);
|
std::vector<float> null_logits(n_outputs * n_vocab, 0.0f);
|
||||||
// to-do: copy once
|
memcpy( null_logits.data(), llama_get_logits( ctx ), sizeof(float) * n_vocab * n_outputs );
|
||||||
for ( auto idx = 0; idx < n_outputs; ++idx ) {
|
|
||||||
memcpy( &null_logits[idx * n_vocab], llama_get_logits_ith( ctx, null_batch.n_tokens - n_outputs + idx ), sizeof(float) * n_vocab );
|
|
||||||
}
|
|
||||||
|
|
||||||
// decode
|
// decode
|
||||||
if ( llama_decode(ctx, batch) ) {
|
if ( llama_decode(ctx, batch) ) {
|
||||||
|
@ -623,23 +624,32 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_temp (sampling_temperature));
|
llama_sampler_chain_add(smpl, llama_sampler_init_temp (sampling_temperature));
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
|
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
|
||||||
|
|
||||||
// to-do: figure out why all logits are the same for each token......
|
auto* logits = llama_get_logits( ctx );
|
||||||
// "reverse" iterate from backwards indexing
|
|
||||||
|
/*
|
||||||
|
// perform CFG sampling
|
||||||
|
for ( auto i = 0; i < n_vocab * n_outputs; ++i ) {
|
||||||
|
logits[i] = null_logit[i] + (logits[i] - null_logit[i]) * cfg_strength;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
for ( auto idx = 0; idx < n_outputs; ++idx ) {
|
for ( auto idx = 0; idx < n_outputs; ++idx ) {
|
||||||
// skip if not masked
|
// skip if not masked
|
||||||
if ( !is_masked[idx] )
|
if ( !is_masked[idx] ) {
|
||||||
|
scores[idx] = 1.0f;
|
||||||
continue;
|
continue;
|
||||||
// ensures only tokens within our designated range are used
|
}
|
||||||
auto* logits = llama_get_logits_ith( ctx, batch.n_tokens - n_outputs + idx );
|
|
||||||
auto* null_logit = &null_logits[idx];
|
auto* logit = &logits[idx * n_vocab];
|
||||||
|
auto* null_logit = &null_logits[idx * n_vocab];
|
||||||
|
|
||||||
// perform softmax before modifying logits
|
// perform softmax before modifying logits
|
||||||
std::vector<float> softmaxed = soft_max( n_vocab, logits );
|
std::vector<float> softmaxed = soft_max( n_vocab, logit );
|
||||||
|
|
||||||
// perform CFG sampling
|
|
||||||
for ( auto i = 0; i < n_vocab; ++i ) {
|
for ( auto i = 0; i < n_vocab; ++i ) {
|
||||||
logits[i] = null_logit[i] + (logits[i] - null_logit[i]) * cfg_strength;
|
logit[i] = null_logit[i] + (logit[i] - null_logit[i]) * cfg_strength;
|
||||||
}
|
}
|
||||||
|
|
||||||
// sample ith token
|
// sample ith token
|
||||||
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx );
|
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx );
|
||||||
// store token if it was masked
|
// store token if it was masked
|
||||||
|
@ -654,23 +664,34 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
|
||||||
}
|
}
|
||||||
} else if ( mode == INFERENCE_MODE_NAR ) {
|
} else if ( mode == INFERENCE_MODE_NAR ) {
|
||||||
// to-do: assert n_outputs == input.resp[rvq_l-1].size()
|
// to-do: assert n_outputs == input.resp[rvq_l-1].size()
|
||||||
output_tokens.reserve(n_outputs);
|
output_tokens.clear();
|
||||||
|
output_tokens.resize(n_outputs);
|
||||||
// do one step on many tokens
|
// do one step on many tokens
|
||||||
if ( llama_decode(ctx, batch) ) {
|
if ( llama_decode(ctx, batch) ) {
|
||||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||||
return output_tokens;
|
return output_tokens;
|
||||||
}
|
}
|
||||||
llama_kv_cache_clear(ctx); // necessary for many reasons
|
llama_kv_cache_clear(ctx); // necessary for many reasons
|
||||||
// to-do: figure out why all logits are the same for each token......
|
|
||||||
// "reverse" iterate from backwards indexing
|
auto sparams = llama_sampler_chain_default_params();
|
||||||
|
sparams.no_perf = false;
|
||||||
|
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||||
|
|
||||||
|
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(1));
|
||||||
|
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(1.0, 1));
|
||||||
|
llama_sampler_chain_add(smpl, llama_sampler_init_temp (1.0));
|
||||||
|
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
|
||||||
|
|
||||||
for ( auto idx = 0; idx < n_outputs; ++idx ) {
|
for ( auto idx = 0; idx < n_outputs; ++idx ) {
|
||||||
// sample ith token
|
// sample ith token
|
||||||
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx);
|
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx);
|
||||||
|
|
||||||
// store token
|
// store token
|
||||||
output_tokens.emplace_back(t);
|
output_tokens[idx] = t;
|
||||||
}
|
}
|
||||||
if ( verbose ) print_tokens( output_tokens );
|
if ( verbose ) print_tokens( output_tokens );
|
||||||
|
|
||||||
|
llama_sampler_free(smpl);
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto t_main_end = ggml_time_us();
|
const auto t_main_end = ggml_time_us();
|
||||||
|
@ -738,15 +759,12 @@ int main( int argc, char** argv ) {
|
||||||
// initialize the sampler
|
// initialize the sampler
|
||||||
auto sparams = llama_sampler_chain_default_params();
|
auto sparams = llama_sampler_chain_default_params();
|
||||||
sparams.no_perf = false;
|
sparams.no_perf = false;
|
||||||
llama_sampler * smpl_ar = llama_sampler_chain_init(sparams);
|
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||||
llama_sampler * smpl_nar = llama_sampler_chain_init(sparams);
|
|
||||||
|
|
||||||
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_k(0));
|
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(0));
|
||||||
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_p(1.0, 1));
|
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(1.0, 1));
|
||||||
llama_sampler_chain_add(smpl_ar, llama_sampler_init_temp (1.0));
|
llama_sampler_chain_add(smpl, llama_sampler_init_temp (1.0));
|
||||||
llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130));
|
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
|
||||||
|
|
||||||
llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy());
|
|
||||||
|
|
||||||
struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl);
|
struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl);
|
||||||
if (!ectx) {
|
if (!ectx) {
|
||||||
|
@ -790,10 +808,10 @@ int main( int argc, char** argv ) {
|
||||||
// NAR-len demasking
|
// NAR-len demasking
|
||||||
if ( modality == MODALITY_NAR_LEN ) {
|
if ( modality == MODALITY_NAR_LEN ) {
|
||||||
// inference len
|
// inference len
|
||||||
int len = 75;
|
int len = 0;
|
||||||
if ( !len ) {
|
if ( !len ) {
|
||||||
input.task = "len";
|
input.task = "len";
|
||||||
output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, INFERENCE_MODE_LEN );
|
output_tokens = generate( ctx, model, smpl, input, io_map, 5, INFERENCE_MODE_LEN );
|
||||||
{
|
{
|
||||||
int digit = 1;
|
int digit = 1;
|
||||||
for (int i = output_tokens.size() - 1; i >= 0; i--) {
|
for (int i = output_tokens.size() - 1; i >= 0; i--) {
|
||||||
|
@ -815,7 +833,8 @@ int main( int argc, char** argv ) {
|
||||||
input.task = "tts";
|
input.task = "tts";
|
||||||
for ( auto l = 0; l < 8; ++l ) {
|
for ( auto l = 0; l < 8; ++l ) {
|
||||||
input.rvq_l = l;
|
input.rvq_l = l;
|
||||||
output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
|
output_tokens = generate( ctx, model, smpl, input, io_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
|
||||||
|
if ( l == 0 ) input.resp.clear();
|
||||||
input.resp.emplace_back( output_tokens );
|
input.resp.emplace_back( output_tokens );
|
||||||
}
|
}
|
||||||
// AR+NAR
|
// AR+NAR
|
||||||
|
@ -823,7 +842,7 @@ int main( int argc, char** argv ) {
|
||||||
input.task = "tts";
|
input.task = "tts";
|
||||||
for ( auto l = 0; l < 8; ++l ) {
|
for ( auto l = 0; l < 8; ++l ) {
|
||||||
input.rvq_l = l;
|
input.rvq_l = l;
|
||||||
output_tokens = generate( ctx, model, l == 0 ? smpl_ar : smpl_nar, input, io_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
|
output_tokens = generate( ctx, model, smpl, input, io_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
|
||||||
input.resp.emplace_back( output_tokens );
|
input.resp.emplace_back( output_tokens );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -835,8 +854,7 @@ int main( int argc, char** argv ) {
|
||||||
// cleanup
|
// cleanup
|
||||||
encodec_free(ectx);
|
encodec_free(ectx);
|
||||||
|
|
||||||
llama_sampler_free(smpl_nar);
|
llama_sampler_free(smpl);
|
||||||
llama_sampler_free(smpl_ar);
|
|
||||||
|
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user