nvm fixed

mrq 2024-12-23 22:23:43 -06:00
parent f62f99b8de
commit 532200de2a
3 changed files with 120 additions and 138 deletions

View File

@ -2,7 +2,7 @@
This is an implementation that makes use of [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [encodec.cpp](https://github.com/PABannier/encodec.cpp).
At the moment it's ***very*** barebones as I try and wrestle with `llama.cpp`'s API without needing to modify its code.
At the moment it's ***very*** much a work in progress.
## Build
@ -14,15 +14,14 @@ Run `make`.
### Required Modifications
[`encodec.cpp`](https://github.com/e-c-k-e-r/encodec.cpp) requires updating its GGML copy to the latest version, which requires a few lines to get the CPU backend working.
[`encodec.cpp`](https://github.com/PABannier/encodec.cpp) requires updating its GGML copy to the latest version, which needs a few lines to get the CPU backend working (per my [fork](https://github.com/e-c-k-e-r/encodec.cpp)).
[`llama.cpp`](https://github.com/e-c-k-e-r/llama.cpp) only possible modification needs to ensure that a non-causal attention mask is used; everything necessary can be hacked together with clever tricks.
The only modification [`llama.cpp`](https://github.com/ggerganov/llama.cpp) possibly needs is ensuring that a non-causal attention mask is used; everything necessary can be hacked together with clever tricks.
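One plausible route for that trick (a minimal sketch assuming stock `llama.cpp`; not necessarily what this repo ends up doing): the public API already exposes a switch for the attention mask, so non-causal passes can request full attention without patching the library.

```cpp
#include <llama.h>

// sketch: AR passes keep the usual causal mask, NAR passes want full (non-causal) attention
void prepare_attention( llama_context* ctx, bool causal ) {
	llama_set_causal_attn( ctx, causal );
}
```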
## To-Do
* [x] converted model to GGUF
* [ ] convert it without modifying any of the existing code, as the tokenizer requires some care
* [ ] *actually* convert the model properly, as the embeddings differ from the real model
* [x] basic framework
* [x] load the quantized model
* [x] orchestrate the required embeddings

View File

@ -40,10 +40,9 @@ std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor ) {
size_t size = tensor->ne[0] * tensor->ne[1];
std::vector<float> res( size );
auto* qtype = ggml_get_type_traits(tensor->type);
// dequantize if needed
if ( ggml_is_quantized(tensor->type) ) {
qtype->to_float(tensor->data, res.data(), res.size());
auto* type_trait = ggml_get_type_traits(tensor->type);
if ( type_trait->to_float ) {
type_trait->to_float(tensor->data, res.data(), res.size());
} else {
memcpy( res.data(), tensor->data, res.size() * sizeof(float) );
}
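For orientation, a hypothetical caller of the helper above (`tok_embd` and `token_id` are illustrative names, not part of this tree): the returned vector is the row-major table, so one embedding row is an `ne[0]`-sized slice.

```cpp
// hypothetical usage: flatten a 2D (possibly quantized) embedding table,
// then slice out the row for a single token
std::vector<float> embedding_row( struct ggml_tensor* tok_embd, int64_t token_id ) {
	std::vector<float> table = read_2d_tensor( tok_embd ); // shape [n_embd, n_vocab], row-major
	int64_t n_embd = tok_embd->ne[0];
	return std::vector<float>( table.begin() + token_id * n_embd,
	                           table.begin() + (token_id + 1) * n_embd );
}
```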
@ -78,27 +77,16 @@ ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_context* ctx, struct ggml_te
ggml_tensor* res = ggml_view_2d( ctx, tensor, tensor->ne[0], end - start, tensor->nb[1], tensor->nb[1] * start );
/*
printf("%p: %i | %i | %i | %i || %p: %i | %i | %i | %i\n",
tensor->data, tensor->ne[0], tensor->ne[1], tensor->nb[1], tensor->nb[2],
res->data, res->ne[0], res->ne[1], res->nb[1], res->nb[2]
);
*/
return res;
}
struct ggml_tensor * VALL_E_API vall_e_get_prom_embds( llama_vall_e_userdata& userdata, int32_t idx ) {
return userdata.prom_embds[idx];
void VALL_E_API print_tokens( const std::vector<llama_token>& tokens, const std::string& prefix ) {
printf("%s[", prefix.c_str());
for ( auto i = 0; i < tokens.size(); ++i ) {
printf("%i%s", tokens[i], i + 1 < tokens.size() ? ", " : "");
}
printf("]\n");
}
struct ggml_tensor * VALL_E_API vall_e_get_resp_embds( llama_vall_e_userdata& userdata, int32_t idx ) {
return userdata.resp_embds[idx];
}
struct ggml_tensor * VALL_E_API vall_e_get_aux_embds( llama_vall_e_userdata& userdata, int32_t idx ) {
return userdata.aux_embds[idx];
}
const io_t& VALL_E_API vall_e_inputs_map_get( io_map_t& io_map, const std::string& name ) {
return io_map.io[name];
@ -140,43 +128,37 @@ void VALL_E_API vall_e_inputs_map_init( io_map_t& io_map, llama_model* model ) {
io_map.io[entry.name].head = entry.head_idx < 0 ? NULL : userdata.heads[entry.head_idx];
}
io_map.io["text"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 0));
io_map.io["rvq_l"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 1));
io_map.io["lang"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 2));
io_map.io["task"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 3));
io_map.io["len"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 4));
io_map.io["tone"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 5));
io_map.io["sep"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 6));
io_map.io["text"].embds = read_2d_tensor(userdata.aux_embds[0]);
io_map.io["rvq_l"].embds = read_2d_tensor(userdata.aux_embds[1]);
io_map.io["lang"].embds = read_2d_tensor(userdata.aux_embds[2]);
io_map.io["task"].embds = read_2d_tensor(userdata.aux_embds[3]);
io_map.io["len"].embds = read_2d_tensor(userdata.aux_embds[4]);
io_map.io["tone"].embds = read_2d_tensor(userdata.aux_embds[5]);
io_map.io["sep"].embds = read_2d_tensor(userdata.aux_embds[6]);
io_map.io["prom|0"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 0));
io_map.io["prom|1"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 1));
io_map.io["prom|2"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 2));
io_map.io["prom|3"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 3));
io_map.io["prom|4"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 4));
io_map.io["prom|5"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 5));
io_map.io["prom|6"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 6));
io_map.io["prom|7"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 7));
io_map.io["prom|0"].embds = read_2d_tensor(userdata.prom_embds[0]);
io_map.io["prom|1"].embds = read_2d_tensor(userdata.prom_embds[1]);
io_map.io["prom|2"].embds = read_2d_tensor(userdata.prom_embds[2]);
io_map.io["prom|3"].embds = read_2d_tensor(userdata.prom_embds[3]);
io_map.io["prom|4"].embds = read_2d_tensor(userdata.prom_embds[4]);
io_map.io["prom|5"].embds = read_2d_tensor(userdata.prom_embds[5]);
io_map.io["prom|6"].embds = read_2d_tensor(userdata.prom_embds[6]);
io_map.io["prom|7"].embds = read_2d_tensor(userdata.prom_embds[7]);
io_map.io["resps|AR:0:0"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 0));
io_map.io["resps|NAR:0:1"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 1));
io_map.io["resps|NAR:1:2"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 2));
io_map.io["resps|NAR:2:3"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 3));
io_map.io["resps|NAR:3:4"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 4));
io_map.io["resps|NAR:4:5"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 5));
io_map.io["resps|NAR:5:6"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 6));
io_map.io["resps|NAR:6:7"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 7));
io_map.io["resps|NAR:0:0"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 8));
for ( auto& entry : io_ranges ) {
for ( auto i = 0; i < 32; ++i ) printf("%s: %i: %f\n", entry.name.c_str(), i, io_map.io[entry.name].embds[i] );
}
io_map.io["resps|AR:0:0"].embds = read_2d_tensor(userdata.resp_embds[0]);
io_map.io["resps|NAR:0:1"].embds = read_2d_tensor(userdata.resp_embds[1]);
io_map.io["resps|NAR:1:2"].embds = read_2d_tensor(userdata.resp_embds[2]);
io_map.io["resps|NAR:2:3"].embds = read_2d_tensor(userdata.resp_embds[3]);
io_map.io["resps|NAR:3:4"].embds = read_2d_tensor(userdata.resp_embds[4]);
io_map.io["resps|NAR:4:5"].embds = read_2d_tensor(userdata.resp_embds[5]);
io_map.io["resps|NAR:5:6"].embds = read_2d_tensor(userdata.resp_embds[6]);
io_map.io["resps|NAR:6:7"].embds = read_2d_tensor(userdata.resp_embds[7]);
io_map.io["resps|NAR:0:0"].embds = read_2d_tensor(userdata.resp_embds[8]);
#else
auto* embds = llama_get_embedding_weights( model );
auto* heads = llama_get_output_head_tensor( model );
// prepare slices
// std::vector<float> raw_embeddings = read_2d_tensor( embds );
for ( auto& entry : io_ranges ) {
io_map.io[entry.name] = entry;
@ -184,16 +166,6 @@ void VALL_E_API vall_e_inputs_map_init( io_map_t& io_map, llama_model* model ) {
io_map.io[entry.name].n_vocab = entry.end - entry.start;
io_map.io[entry.name].embds = read_2d_tensor(view_2d_tensor( io_map.ctx, embds, entry.start, entry.end ));
io_map.io[entry.name].head = entry.head_idx < 0 ? NULL : view_2d_tensor( io_map.ctx, heads, entry.start, entry.end );
// these two differ after the first embedding and I don't know why.........
/*
auto raw_embd = std::vector<float>( raw_embeddings.data() + entry.start * n_embd, raw_embeddings.data() + entry.end * n_embd );
auto sliced_embd = read_2d_tensor( embd_tensor );
io_map.io[entry.name].embds = raw_embd;
for ( auto i = 0; i < 32; ++i ) printf("%s: %i: %f == %f \n", entry.name.c_str(), i, raw_embd[i], sliced_embd[i] );
*/
}
#endif
}
@ -228,9 +200,6 @@ void VALL_E_API batch_add( llama_batch& batch, llama_token id, int n_embd, const
for (size_t i = 0; i < seq_ids.size(); ++i) batch.seq_id[batch.n_tokens][i] = seq_ids[i];
batch.logits[batch.n_tokens] = output ? 1 : 0;
// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens, id, batch.pos[batch.n_tokens], &batch.embd[batch.n_tokens], batch.logits[batch.n_tokens] );
// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens, id, pos, embds, output );
batch.n_tokens++;
}
// reads a waveform from disk
@ -283,13 +252,13 @@ std::vector<std::vector<int32_t>> VALL_E_API encode_audio_from_disk( struct enco
std::vector<float> wavform;
if(!read_wav_from_disk(path, wavform)) {
printf("%s: error during reading wav file\n", __func__);
fprintf(stderr, "%s: error during reading wav file\n", __func__);
return {};
}
// compress audio
if (!encodec_compress_audio(ectx, wavform.data(), wavform.size(), 1)) {
printf("%s: error during compression \n", __func__);
fprintf(stderr, "%s: error during compression \n", __func__);
return {};
}
@ -325,7 +294,7 @@ std::vector<float> VALL_E_API decode_audio( struct encodec_context* ectx, const
// decompress audio
if (!encodec_decompress_audio(ectx, codes.data(), codes.size(), 1)) {
printf("%s: error during decompression\n", __func__);
fprintf(stderr, "%s: error during decompression\n", __func__);
return {};
}
@ -366,10 +335,27 @@ std::vector<std::vector<float>> VALL_E_API sum_embeddings( const std::vector<std
std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits ) {
std::vector<float> res( n_logits, 0.0f );
std::vector<float> expd( n_logits, 0.0f );
float denom = 0.0f;
for ( auto i = 0; i < n_logits; ++i ) {
denom += exp( logits[i] );
expd[i] = exp( logits[i] );
denom += expd[i];
}
// to-do: assert denom != 0.0f
for ( auto i = 0; i < n_logits; ++i ) {
res[i] = expd[i] / denom;
}
return res;
}
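One side note on the reworked accumulation: `exp()` on raw logits can overflow once logits get large. A common numerically stable variant (a sketch for reference, not a claim about this tree's behavior) subtracts the max logit first, which leaves the softmax unchanged.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// softmax(x) == softmax(x - max(x)); subtracting the max keeps exp() in range
std::vector<float> soft_max_stable( int n_logits, const float* logits ) {
	std::vector<float> res( n_logits, 0.0f );
	float max_logit = *std::max_element( logits, logits + n_logits );
	float denom = 0.0f;
	for ( int i = 0; i < n_logits; ++i ) {
		res[i] = expf( logits[i] - max_logit );
		denom += res[i];
	}
	for ( int i = 0; i < n_logits; ++i ) res[i] /= denom;
	return res;
}
```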
std::vector<float> VALL_E_API log_soft_max( int n_logits, const float* logits ) {
std::vector<float> res( n_logits, 0.0f );
float denom = 0.0f;
for ( auto i = 0; i < n_logits; ++i ) {
denom += logits[i];
}
// to-do: assert denom != 0.0f
for ( auto i = 0; i < n_logits; ++i ) {
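The hunk cuts `log_soft_max` off here; as shown, its denominator accumulates raw logits rather than exponentials. For comparison, the textbook log-softmax (a sketch of the usual definition, not necessarily what the remainder of this function computes; same headers as the sketch above):

```cpp
// log_softmax(x)_i = x_i - max(x) - log( sum_j exp(x_j - max(x)) )
std::vector<float> log_soft_max_reference( int n_logits, const float* logits ) {
	std::vector<float> res( n_logits, 0.0f );
	float max_logit = *std::max_element( logits, logits + n_logits );
	float sum_exp = 0.0f;
	for ( int i = 0; i < n_logits; ++i ) sum_exp += expf( logits[i] - max_logit );
	float log_denom = max_logit + logf( sum_exp );
	for ( int i = 0; i < n_logits; ++i ) res[i] = logits[i] - log_denom;
	return res;
}
```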
@ -503,7 +489,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
int32_t n_vocab = io.n_vocab;
llama_token stop_token = io.end - io.start - 1;
printf("Generating in %s (%i) mode (%i:%i) (%i)\n", embd_name.c_str(), io.head_idx, io.start, io.end, stop_token);
if ( verbose ) printf("Generating in %s (%i) mode (%i:%i) (%i)\n", embd_name.c_str(), io.head_idx, io.start, io.end, stop_token);
// update model's output heads / causal mode
llama_set_output_head( model, io.head );
@ -515,15 +501,12 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
// if INFERENCE_MODE_AR || INFERENCE_MODE_LEN
if ( causal ) {
output_tokens.reserve(max_tokens);
if ( verbose ) {
printf("[");
fflush(stdout);
}
while ( output_tokens.size() < max_tokens ) {
if ( llama_decode(ctx, batch) ) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return output_tokens;
}
llama_kv_cache_clear(ctx); // necessary for many reasons
// sample token
auto t = llama_sampler_sample(smpl, ctx, -1);
@ -537,21 +520,15 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
output_tokens.emplace_back(t);
// update batch with token
batch_add( batch, t, io_map.n_embd, embds, output_tokens.size(), true );
if ( verbose ) {
printf("%i, ", t);
fflush(stdout);
}
}
if ( verbose ) {
printf("]\n");
fflush(stdout);
if ( verbose ) print_tokens( output_tokens );
}
} else if ( mode == INFERENCE_MODE_NAR_DEMASK ) {
// to-do: assert n_outputs == input.resp[rvq_l-1].size()
const llama_token MASK_TOKEN = 1024; // token value for masking
const float PI = 3.141592653589793f;
// to-do: derive from sampling arguments
int32_t steps = 30; // number of demasking steps
int32_t steps = 10; // number of demasking steps
int32_t seq_len = n_outputs;
float temperature = 1.5f;
float cfg_strength = 2.5f;
@ -572,34 +549,40 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
// do one step on many tokens
for ( auto step = 0; step < steps; ++step ) {
if ( verbose ) {
printf("[%i/%i] [", step, steps);
fflush(stdout);
}
float timestep = (step+1) / steps; // to-do: align with torch.linspace
float timestep = ((float)step) / steps; // to-do: align with torch.linspace
float annealing = 1.0f - timestep;
float noise_p = cos( timestep * PI * 0.5f );
float remask_p = 0.5f / steps;
int32_t n_masked_tokens = std::min(int(noise_p * seq_len), 1);
float sampling_temperature = temperature * annealing;
float sampling_cfg_strength = timestep * cfg_strength;
std::vector<bool> is_masked(n_outputs, false);
std::vector<int32_t> masked_indices;
masked_indices.reserve(n_masked_tokens);
std::vector<float> sorted = scores;
std::sort(sorted.begin(), sorted.end());
masked_indices.insert( masked_indices.end(), sorted.begin(), sorted.begin() + n_masked_tokens );
float noise_p = cos( timestep * PI * 0.5f );
float remask_p = 0.0f; // 0.5f / steps;
int32_t n_masked_tokens = (noise_p + remask_p) * seq_len;
if ( n_masked_tokens < 1 ) {
n_masked_tokens = 1;
}
if ( n_masked_tokens > n_outputs ) {
n_masked_tokens = n_outputs;
}
// masked mask
std::vector<bool> is_masked(n_outputs, false);
// sort previous scores
std::vector<score_t> sorted_scores( n_outputs );
for ( auto i = 0; i < n_outputs; ++i ) sorted_scores[i] = { i, scores[i] };
std::sort(sorted_scores.begin(), sorted_scores.end());
// and top-k pick the worst scores
for ( auto i = 0; i < n_masked_tokens; ++i ) {
auto idx = sorted_scores[i].idx;
// mask off tokens
for ( auto& idx : masked_indices ) {
output_tokens[idx] = MASK_TOKEN;
is_masked[idx] = true;
}
// update token mask
for ( auto i = 0; i < n_outputs; ++i ) {
is_masked[i] = output_tokens[i] == MASK_TOKEN;
}
if ( verbose ) print_tokens( output_tokens, "Masked tokens:" );
// update batch
// to-do: only update the embeddings instead
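Pulled out of the loop, the reworked re-masking schedule boils down to a cosine curve over the step index: re-mask the `cos(t·π/2)` least-confident fraction of the sequence, clamped to `[1, n_outputs]` (with `remask_p` currently zeroed out). A standalone sketch of that arithmetic, with hypothetical names:

```cpp
#include <cmath>

// how many of the lowest-scored tokens to re-mask at a given demasking step
int n_tokens_to_mask( int step, int steps, int seq_len ) {
	float timestep = float(step) / steps;                          // ramps 0 -> ~1 over the schedule
	float noise_p  = cosf( timestep * 3.141592653589793f * 0.5f ); // decays 1 -> ~0, so fewer re-masks each step
	int n = int( noise_p * seq_len );
	if ( n < 1 ) n = 1;
	if ( n > seq_len ) n = seq_len;
	return n;
}
```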
@ -611,15 +594,14 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
null_batch.n_tokens = 0;
fill_batch( null_batch, input, io_map, mode );
// to-do: update sampling temperature
// cfg decode
if ( llama_decode(ctx, null_batch) ) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return output_tokens;
}
llama_kv_cache_clear(ctx); // necessary for many reasons
// copy null probabilities
std::vector<float> null_logits(n_outputs * n_vocab, -INFINITY);
std::vector<float> null_logits(n_outputs * n_vocab, 0.0f);
// to-do: copy once
for ( auto idx = 0; idx < n_outputs; ++idx ) {
memcpy( &null_logits[idx * n_vocab], llama_get_logits_ith( ctx, null_batch.n_tokens - n_outputs + idx ), sizeof(float) * n_vocab );
@ -630,6 +612,17 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return output_tokens;
}
llama_kv_cache_clear(ctx); // necessary for many reasons
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(0));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(1.0, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (sampling_temperature));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1130));
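// note: top_k(0) / top_p(1.0) leave the distribution effectively unfiltered; the dist sampler (fixed seed 1130) is what actually draws the token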
// to-do: figure out why all logits are the same for each token......
// "reverse" iterate from backwards indexing
for ( auto idx = 0; idx < n_outputs; ++idx ) {
@ -652,16 +645,12 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
// store token if it was masked
output_tokens[idx] = t;
// update score if it was masked
scores[idx] = 1.0f - softmaxed[t]; // invert so we pick the worst tokens later
if ( verbose ) {
printf("%i, ", t);
fflush(stdout);
}
}
if ( verbose ) {
printf("\n");
fflush(stdout);
scores[idx] = softmaxed[t]; // raw confidence; the lowest-scored tokens get re-masked on the next step
}
llama_sampler_free(smpl);
if ( verbose ) print_tokens( output_tokens );
}
} else if ( mode == INFERENCE_MODE_NAR ) {
// to-do: assert n_outputs == input.resp[rvq_l-1].size()
@ -671,27 +660,17 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return output_tokens;
}
llama_kv_cache_clear(ctx); // necessary for many reasons
// to-do: figure out why all logits are the same for each token......
// "reverse" iterate from backwards indexing
if ( verbose ) {
printf("[");
fflush(stdout);
}
for ( auto idx = 0; idx < n_outputs; ++idx ) {
// sample ith token
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx);
// store token
output_tokens.emplace_back(t);
if ( verbose ) {
printf("%i, ", t);
fflush(stdout);
}
}
if ( verbose ) {
printf("]\n");
fflush(stdout);
}
if ( verbose ) print_tokens( output_tokens );
}
const auto t_main_end = ggml_time_us();
@ -771,7 +750,7 @@ int main( int argc, char** argv ) {
struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl);
if (!ectx) {
printf("%s: error during loading model\n", __func__);
fprintf(stderr, "%s: error during loading model\n", __func__);
return 1;
}
@ -811,7 +790,7 @@ int main( int argc, char** argv ) {
// NAR-len demasking
if ( modality == MODALITY_NAR_LEN ) {
// inference len
int len = 0;
int len = 75;
if ( !len ) {
input.task = "len";
output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, INFERENCE_MODE_LEN );

View File

@ -101,10 +101,18 @@ struct io_map_t {
ggml_context* ctx = NULL;
};
struct score_t {
int32_t idx;
float value;
bool operator<( const score_t& that ) const { return this->value < that.value; }
};
// helper tensor functions
std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor );
ggml_tensor* VALL_E_API view_2d_tensor( ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 ); // cringe method to keep in my pocket
ggml_tensor* VALL_E_API view_2d_tensor( ggml_context* ctx, ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 );
void VALL_E_API print_tokens( const std::vector<llama_token>& tokens, const std::string& prefix = "Tokens: " );
std::vector<std::vector<float>> VALL_E_API map_embeddings( const std::vector<llama_token>& tokens, int n_embd, const float* embds );
std::vector<std::vector<float>> VALL_E_API sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM );
@ -125,8 +133,4 @@ std::vector<float> VALL_E_API decode_audio( struct encodec_context* ectx, const
const io_t& VALL_E_API vall_e_inputs_map_get_embeddings( io_map_t& inputs_map, const std::string& name );
const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( io_map_t& inputs_map, const std::string& name );
int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( io_map_t& inputs_map, const std::string& name );
void VALL_E_API vall_e_inputs_map_init( io_map_t&, llama_model* model );
struct ggml_tensor * VALL_E_API vall_e_get_prom_embds( llama_vall_e_userdata& userdata, int32_t idx );
struct ggml_tensor * VALL_E_API vall_e_get_resp_embds( llama_vall_e_userdata& userdata, int32_t idx );
struct ggml_tensor * VALL_E_API vall_e_get_aux_embds( llama_vall_e_userdata& userdata, int32_t idx );
void VALL_E_API vall_e_inputs_map_init( io_map_t&, llama_model* model );