From f62f99b8de4b90565de8b36567113d8ed5f75b4a Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 23 Dec 2024 20:36:40 -0600 Subject: [PATCH] more work on vall_e.cpp (need to resolve why the embeddings (and maybe the weights as a whole) are different from the base model) --- vall_e.cpp/README.md | 6 +- vall_e.cpp/include/llama_hack.h | 163 ++++++++++++++++ vall_e.cpp/vall_e.cpp | 320 +++++++++++++++++--------------- vall_e.cpp/vall_e.h | 47 ++--- 4 files changed, 361 insertions(+), 175 deletions(-) create mode 100644 vall_e.cpp/include/llama_hack.h diff --git a/vall_e.cpp/README.md b/vall_e.cpp/README.md index 673d5ea..916df18 100644 --- a/vall_e.cpp/README.md +++ b/vall_e.cpp/README.md @@ -16,12 +16,13 @@ Run `make`. [`encodec.cpp`](https://github.com/e-c-k-e-r/encodec.cpp) requires updating its GGML copy to the latest version, which requires a few lines to get the CPU backend working. -[`llama.cpp`](https://github.com/e-c-k-e-r/llama.cpp) *might* not require any modifications, but implementing `LLM_ARCH_VALL_E` requires some surgery. +[`llama.cpp`](https://github.com/e-c-k-e-r/llama.cpp) only possible modification needs to ensure that a non-causal attention mask is used; everything necessary can be hacked together with clever tricks. ## To-Do * [x] converted model to GGUF * [ ] convert it without modifying any of the existing code, as the tokenizer requires some care + * [ ] *actually* convert the model properly, as the embeddings differ from the real model * [x] basic framework * [x] load the quantized model * [x] orchestrate the required embeddings @@ -35,13 +36,10 @@ Run `make`. * [x] sum embeddings for the `prom` and prior `resp`s * [ ] working `AR` output * [x] `AR` sampling - * currently need a model that didn't regress with the `AR:0:0` output * [ ] working `NAR-len` output * [x] `NAR-len` sampling - * need to assert that a non-causal mask is used * [ ] working `NAR` output * [x] `NAR` sampling - * need to assert that a non-causal mask is used * [x] decode audio to disk * [ ] a functional CLI * [ ] actually make it work diff --git a/vall_e.cpp/include/llama_hack.h b/vall_e.cpp/include/llama_hack.h new file mode 100644 index 0000000..e1a07cd --- /dev/null +++ b/vall_e.cpp/include/llama_hack.h @@ -0,0 +1,163 @@ +#pragma once + +#include "llama-vocab.h" +#include + +/* Begin cringe so I can access the model's tok_embd */ +// it needs to be copied so the struct layout is exactly as it is under llama.cpp +#define LLAMA_MAX_LAYERS 512 +#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2 + +enum e_model { + MODEL_UNKNOWN, +}; + +enum llm_arch { + LLM_ARCH_UNKNOWN, +}; + +struct llama_hparams_posnet { + uint32_t n_embd; + uint32_t n_layer; +}; + +struct llama_hparams_convnext { + uint32_t n_embd; + uint32_t n_layer; +}; + +struct llama_hparams { + bool vocab_only; + bool rope_finetuned; + bool use_par_res; + bool swin_norm; + + uint32_t n_vocab = 0; + uint32_t n_ctx_train; // context size the model was trained on + uint32_t n_embd; + uint32_t n_embd_features = 0; + uint32_t n_layer; + uint32_t n_rot; + uint32_t n_swa = 0; // sliding window attention (SWA) + uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads + uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head + uint32_t n_expert = 0; + uint32_t n_expert_used = 0; + uint32_t n_vocab_type = 0; // for BERT-style token types + uint32_t n_rel_attn_bkts = 0; + + // for WavTokenizer + struct llama_hparams_posnet posnet; + struct llama_hparams_convnext convnext; + + std::array n_head_arr; + std::array n_head_kv_arr; + std::array n_ff_arr; + + uint32_t n_layer_dense_lead = 0; + uint32_t n_lora_q = 0; + uint32_t n_lora_kv = 0; + uint32_t n_ff_exp = 0; + uint32_t n_ff_shexp = 0; + uint32_t n_expert_shared = 0; + float expert_weights_scale = 0.0; + + float f_norm_eps; + float f_norm_rms_eps; + float f_norm_group_eps; + + uint32_t n_norm_groups; + + float f_attn_logit_softcapping = 50.0f; + float f_final_logit_softcapping = 30.0f; + + // for RWKV + uint32_t rescale_every_n_layers = 0; + uint32_t time_mix_extra_dim = 0; + uint32_t time_decay_extra_dim = 0; + uint32_t wkv_head_size = 0; + + float rope_attn_factor = 1.0f; + float rope_freq_base_train; + float rope_freq_scale_train; + uint32_t n_ctx_orig_yarn; + float rope_yarn_log_mul; + int rope_sections[4]; + + // for State Space Models + uint32_t ssm_d_conv = 0; + uint32_t ssm_d_inner = 0; + uint32_t ssm_d_state = 0; + uint32_t ssm_dt_rank = 0; + bool ssm_dt_b_c_rms = false; + + float f_clamp_kqv = 0.0f; + float f_max_alibi_bias = 0.0f; + float f_logit_scale = 0.0f; + + // Additional scale factors (Granite/Granite MoE) + float f_residual_scale = 0.0f; + float f_embedding_scale = 0.0f; + float f_attention_scale = 0.0f; + + bool causal_attn = true; + bool use_alibi = false; + bool attn_soft_cap = false; + + // needed by encoder-decoder models (e.g. T5, FLAN-T5) + // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + llama_token dec_start_token_id = LLAMA_TOKEN_NULL; + + enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; + enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; + enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; +}; + +struct llama_model { + e_model type = MODEL_UNKNOWN; + llm_arch arch = LLM_ARCH_UNKNOWN; + llama_ftype ftype = LLAMA_FTYPE_ALL_F32; + + std::string name = "n/a"; + + llama_hparams hparams = {}; + llama_vocab vocab; + + struct ggml_tensor * tok_embd = nullptr; + struct ggml_tensor * type_embd = nullptr; + struct ggml_tensor * pos_embd = nullptr; + struct ggml_tensor * tok_norm = nullptr; + struct ggml_tensor * tok_norm_b = nullptr; + + struct ggml_tensor * output_norm = nullptr; + struct ggml_tensor * output_norm_b = nullptr; + struct ggml_tensor * output = nullptr; + struct ggml_tensor * output_b = nullptr; + struct ggml_tensor * output_norm_enc = nullptr; + + // classifier + struct ggml_tensor * cls = nullptr; + struct ggml_tensor * cls_b = nullptr; + struct ggml_tensor * cls_out = nullptr; + struct ggml_tensor * cls_out_b = nullptr; + + struct ggml_tensor * conv1d = nullptr; + struct ggml_tensor * conv1d_b = nullptr; +}; + +/* BEGIN VALL-E SPECIFIC HELPERS */ +struct ggml_tensor * llama_get_embedding_weights(struct llama_model * model) { + return model->tok_embd; +} +struct ggml_tensor * llama_get_output_head_tensor(struct llama_model * model ) { + return model->output; +} +void llama_set_output_head(struct llama_model * model, struct ggml_tensor* tensor ) { + // set the output tensor + model->output = tensor; + // required to properly output logits + *const_cast(&model->hparams.n_vocab) = tensor->ne[1]; +} +/* END VALL-E SPECIFIC HELPERS */ + +/* End cringe code */ \ No newline at end of file diff --git a/vall_e.cpp/vall_e.cpp b/vall_e.cpp/vall_e.cpp index 6d752c1..0fbdd0f 100644 --- a/vall_e.cpp/vall_e.cpp +++ b/vall_e.cpp/vall_e.cpp @@ -7,7 +7,7 @@ #include #include -ranges_t io_ranges[] = { +io_t io_ranges[] = { { "text", 0, 256, 9, }, { "rvq_l", 256, 264, -1, }, { "lang", 264, 270, -1, }, @@ -25,15 +25,15 @@ ranges_t io_ranges[] = { { "prom|6", 6436, 7460, -1, }, { "prom|7", 7460, 8484, -1, }, - { "resps|AR:0 8484, 9509, 0,:0", }, - { "resps|NAR:0 9509, 10533, 1,:1", }, - { "resps|NAR:1: 10533, 11557, 2,2", }, - { "resps|NAR:2: 11557, 12581, 3,3", }, - { "resps|NAR:3: 12581, 13605, 4,4", }, - { "resps|NAR:4: 13605, 14629, 5,5", }, - { "resps|NAR:5: 14629, 15653, 6,6", }, - { "resps|NAR:6: 15653, 16677, 7,7", }, - { "resps|NAR:0: 16677, 17702, 8,0", }, + { "resps|AR:0:0", 8484, 9509, 0 }, + { "resps|NAR:0:1", 9509, 10533, 1 }, + { "resps|NAR:1:2", 10533, 11557, 2 }, + { "resps|NAR:2:3", 11557, 12581, 3 }, + { "resps|NAR:3:4", 12581, 13605, 4 }, + { "resps|NAR:4:5", 13605, 14629, 5 }, + { "resps|NAR:5:6", 14629, 15653, 6 }, + { "resps|NAR:6:7", 15653, 16677, 7 }, + { "resps|NAR:0:0", 16677, 17702, 8 }, }; std::vector VALL_E_API read_2d_tensor( struct ggml_tensor* tensor ) { @@ -51,6 +51,43 @@ std::vector VALL_E_API read_2d_tensor( struct ggml_tensor* tensor ) { return res; } +ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim ) { + // to-do: implement other dim + if ( start < 0 ) start = tensor->ne[1] + start; + if ( end < 0 ) end = tensor->ne[1] + end; + + ggml_tensor* res = new ggml_tensor(); + memcpy( res, tensor, sizeof(ggml_tensor) ); + + res->op = GGML_OP_VIEW; + res->src[0] = tensor; + + res->data += res->nb[1] * start; + res->ne[1] = end - start; + + for (int i = 2; i < GGML_MAX_DIMS; i++) { + res->nb[i] = res->nb[i - 1] * res->ne[i - 1]; + } + + return res; +} +ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_context* ctx, struct ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim ) { + // to-do: implement other dim + if ( start < 0 ) start = tensor->ne[1] + start; + if ( end < 0 ) end = tensor->ne[1] + end; + + ggml_tensor* res = ggml_view_2d( ctx, tensor, tensor->ne[0], end - start, tensor->nb[1], tensor->nb[1] * start ); + + /* + printf("%p: %i | %i | %i | %i || %p: %i | %i | %i | %i\n", + tensor->data, tensor->ne[0], tensor->ne[1], tensor->nb[1], tensor->nb[2], + res->data, res->ne[0], res->ne[1], res->nb[1], res->nb[2] + ); + */ + + return res; +} + struct ggml_tensor * VALL_E_API vall_e_get_prom_embds( llama_vall_e_userdata& userdata, int32_t idx ) { return userdata.prom_embds[idx]; @@ -63,81 +100,100 @@ struct ggml_tensor * VALL_E_API vall_e_get_aux_embds( llama_vall_e_userdata& us } -const embeddings_t& VALL_E_API vall_e_inputs_map_get_embeddings( inputs_map_t& inputs_map, const std::string& name ) { - return inputs_map.embds[name]; +const io_t& VALL_E_API vall_e_inputs_map_get( io_map_t& io_map, const std::string& name ) { + return io_map.io[name]; } -const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( inputs_map_t& inputs_map, const std::string& name ) { - return inputs_map.embds[name].embds.data(); +const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( io_map_t& io_map, const std::string& name ) { + return io_map.io[name].embds.data(); } -int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( inputs_map_t& inputs_map, const std::string& name ) { - return inputs_map.embds[name].range.classifier_idx; +int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( io_map_t& io_map, const std::string& name ) { + return io_map.io[name].head_idx; } -void VALL_E_API vall_e_inputs_map_init( inputs_map_t& inputs_map, llama_model* model ) { +void VALL_E_API vall_e_inputs_map_init( io_map_t& io_map, llama_model* model ) { auto n_embd = llama_n_embd( model ); auto n_vocab = llama_n_vocab( model ); - inputs_map.n_embd = n_embd; - inputs_map.n_vocab = n_vocab; + io_map.n_embd = n_embd; + io_map.n_vocab = n_vocab; - auto& userdata = *llama_get_vall_e_userdata( model ); + int32_t ctx_size = 24 * 2 * ggml_tensor_overhead(); // 24 embeddings + 24 output heads (generous) (should only really need to do this for output heads since we manually handle embeddings) + struct ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + io_map.ctx = ggml_init(params); // to-do: figure a nicer way to do this #if LLAMA_CPP_USE_VALL_E_ARCH - inputs_map.embds["text"] = { n_embd, 0, { "text", 0, 0, 9, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 0)) }; - inputs_map.embds["rvq_l"] = { n_embd, 0, { "rvq_l", 0, 0, -1, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 1)) }; - inputs_map.embds["lang"] = { n_embd, 0, { "lang", 0, 0, -1, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 2)) }; - inputs_map.embds["task"] = { n_embd, 0, { "task", 0, 0, -1, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 3)) }; - inputs_map.embds["len"] = { n_embd, 0, { "len", 0, 0, 10, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 4)) }; - inputs_map.embds["tone"] = { n_embd, 0, { "tone", 0, 0, -1, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 5)) }; - inputs_map.embds["sep"] = { n_embd, 0, { "sep", 0, 0, -1, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 6)) }; + auto& userdata = *llama_get_vall_e_userdata( model ); - inputs_map.embds["prom|0"] = { n_embd, 0, { "prom|0", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 0)) }; - inputs_map.embds["prom|1"] = { n_embd, 0, { "prom|1", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 1)) }; - inputs_map.embds["prom|2"] = { n_embd, 0, { "prom|2", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 2)) }; - inputs_map.embds["prom|3"] = { n_embd, 0, { "prom|3", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 3)) }; - inputs_map.embds["prom|4"] = { n_embd, 0, { "prom|4", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 4)) }; - inputs_map.embds["prom|5"] = { n_embd, 0, { "prom|5", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 5)) }; - inputs_map.embds["prom|6"] = { n_embd, 0, { "prom|6", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 6)) }; - inputs_map.embds["prom|7"] = { n_embd, 0, { "prom|7", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 7)) }; + for ( auto& entry : io_ranges ) { + io_map.io[entry.name] = entry; + + io_map.io[entry.name].n_embd = n_embd; + io_map.io[entry.name].n_vocab = entry.end - entry.start; + io_map.io[entry.name].start = 0; + io_map.io[entry.name].end = 0; + io_map.io[entry.name].head = entry.head_idx < 0 ? NULL : userdata.heads[entry.head_idx]; + } + + io_map.io["text"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 0)); + io_map.io["rvq_l"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 1)); + io_map.io["lang"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 2)); + io_map.io["task"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 3)); + io_map.io["len"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 4)); + io_map.io["tone"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 5)); + io_map.io["sep"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 6)); + + io_map.io["prom|0"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 0)); + io_map.io["prom|1"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 1)); + io_map.io["prom|2"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 2)); + io_map.io["prom|3"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 3)); + io_map.io["prom|4"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 4)); + io_map.io["prom|5"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 5)); + io_map.io["prom|6"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 6)); + io_map.io["prom|7"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 7)); - inputs_map.embds["resps|AR:0:0"] = { n_embd, 0, { "resps|AR:0:0", 0, 0, 0, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 0)) }; - inputs_map.embds["resps|NAR:0:1"] = { n_embd, 0, { "resps|NAR:0:1", 0, 0, 1, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 1)) }; - inputs_map.embds["resps|NAR:1:2"] = { n_embd, 0, { "resps|NAR:1:2", 0, 0, 2, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 2)) }; - inputs_map.embds["resps|NAR:2:3"] = { n_embd, 0, { "resps|NAR:2:3", 0, 0, 3, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 3)) }; - inputs_map.embds["resps|NAR:3:4"] = { n_embd, 0, { "resps|NAR:3:4", 0, 0, 4, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 4)) }; - inputs_map.embds["resps|NAR:4:5"] = { n_embd, 0, { "resps|NAR:4:5", 0, 0, 5, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 5)) }; - inputs_map.embds["resps|NAR:5:6"] = { n_embd, 0, { "resps|NAR:5:6", 0, 0, 6, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 6)) }; - inputs_map.embds["resps|NAR:6:7"] = { n_embd, 0, { "resps|NAR:6:7", 0, 0, 7, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 7)) }; - inputs_map.embds["resps|NAR:0:0"] = { n_embd, 0, { "resps|NAR:0:0", 0, 0, 8, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 8)) }; + io_map.io["resps|AR:0:0"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 0)); + io_map.io["resps|NAR:0:1"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 1)); + io_map.io["resps|NAR:1:2"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 2)); + io_map.io["resps|NAR:2:3"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 3)); + io_map.io["resps|NAR:3:4"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 4)); + io_map.io["resps|NAR:4:5"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 5)); + io_map.io["resps|NAR:5:6"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 6)); + io_map.io["resps|NAR:6:7"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 7)); + io_map.io["resps|NAR:0:0"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 8)); - // update values - for ( auto& pair : inputs_map.embds ) { - auto& k = pair.first; - auto& v = pair.second; - auto& embds = v.embds; - v.n_vocab = embds.size() / n_embd; - v.range.end = v.n_vocab; + for ( auto& entry : io_ranges ) { + for ( auto i = 0; i < 32; ++i ) printf("%s: %i: %f\n", entry.name.c_str(), i, io_map.io[entry.name].embds[i] ); } #else - -#if LLAMA_CPP_EXTENDED - auto* tensor = llama_get_embedding_weights( model ); -#else - auto* tensor = model->tok_embd; -#endif + auto* embds = llama_get_embedding_weights( model ); + auto* heads = llama_get_output_head_tensor( model ); // prepare slices - std::vector raw_embeddings = read_2d_tensor( tensor ); - for ( auto& range : io_ranges ) { - inputs_map.embds[range.name] = { - n_embd, - range.end - range.start, - range, - std::vector( raw_embeddings.data() + range.start, raw_embeddings.data() + range.end ) - }; + // std::vector raw_embeddings = read_2d_tensor( embds ); + for ( auto& entry : io_ranges ) { + io_map.io[entry.name] = entry; + + io_map.io[entry.name].n_embd = n_embd; + io_map.io[entry.name].n_vocab = entry.end - entry.start; + io_map.io[entry.name].embds = read_2d_tensor(view_2d_tensor( io_map.ctx, embds, entry.start, entry.end )); + io_map.io[entry.name].head = entry.head_idx < 0 ? NULL : view_2d_tensor( io_map.ctx, heads, entry.start, entry.end ); + + // these two differ after the first embedding and I don't know why......... + /* + auto raw_embd = std::vector( raw_embeddings.data() + entry.start * n_embd, raw_embeddings.data() + entry.end * n_embd ); + auto sliced_embd = read_2d_tensor( embd_tensor ); + + io_map.io[entry.name].embds = raw_embd; + + for ( auto i = 0; i < 32; ++i ) printf("%s: %i: %f == %f \n", entry.name.c_str(), i, raw_embd[i], sliced_embd[i] ); + */ } #endif } @@ -323,38 +379,38 @@ std::vector VALL_E_API soft_max( int n_logits, const float* logits ) { return res; } -void VALL_E_API fill_batch( llama_batch& batch, input_t& input, inputs_map_t& inputs_map, int mode ) { +void VALL_E_API fill_batch( llama_batch& batch, input_t& input, io_map_t& io_map, int mode ) { // keeps track of the position for each sequence size_t pos = 0; - auto n_embd = inputs_map.n_embd; + auto n_embd = io_map.n_embd; - const float* text_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "text"); - const float* rvq_l_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "rvq_l"); - const float* lang_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "lang"); - const float* task_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "task"); - const float* len_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "len"); - const float* tone_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "tone"); - const float* sep_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "sep"); + const float* text_embds = vall_e_inputs_map_get_embeddings_p(io_map, "text"); + const float* rvq_l_embds = vall_e_inputs_map_get_embeddings_p(io_map, "rvq_l"); + const float* lang_embds = vall_e_inputs_map_get_embeddings_p(io_map, "lang"); + const float* task_embds = vall_e_inputs_map_get_embeddings_p(io_map, "task"); + const float* len_embds = vall_e_inputs_map_get_embeddings_p(io_map, "len"); + const float* tone_embds = vall_e_inputs_map_get_embeddings_p(io_map, "tone"); + const float* sep_embds = vall_e_inputs_map_get_embeddings_p(io_map, "sep"); const float* prom_embds[] = { - vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|0"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|1"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|2"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|3"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|4"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|5"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|6"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|7"), + vall_e_inputs_map_get_embeddings_p(io_map, "prom|0"), + vall_e_inputs_map_get_embeddings_p(io_map, "prom|1"), + vall_e_inputs_map_get_embeddings_p(io_map, "prom|2"), + vall_e_inputs_map_get_embeddings_p(io_map, "prom|3"), + vall_e_inputs_map_get_embeddings_p(io_map, "prom|4"), + vall_e_inputs_map_get_embeddings_p(io_map, "prom|5"), + vall_e_inputs_map_get_embeddings_p(io_map, "prom|6"), + vall_e_inputs_map_get_embeddings_p(io_map, "prom|7"), }; const float* resp_embds[] = { - vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|AR:0:0"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:0:1"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:1:2"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:2:3"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:3:4"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:4:5"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:5:6"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:6:7"), - vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:0:0"), + vall_e_inputs_map_get_embeddings_p(io_map, "resps|AR:0:0"), + vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:0:1"), + vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:1:2"), + vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:2:3"), + vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:3:4"), + vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:4:5"), + vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:5:6"), + vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:6:7"), + vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:0:0"), }; // insert text tokens @@ -394,20 +450,13 @@ void VALL_E_API fill_batch( llama_batch& batch, input_t& input, inputs_map_t& in } // generation code, should handle all modalities easily -std::vector VALL_E_API generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, inputs_map_t& inputs_map, int max_tokens, int mode, bool verbose ) { - int rvq_l = input.rvq_l; - llama_token stop_token = -1; - int n_decode = 0; // number of tokens decoded - int n_outputs = 0; // number of output tokens to expect - int n_vocab = 0; - int n_embd = 0; +std::vector VALL_E_API generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, io_map_t& io_map, int max_tokens, int mode, bool verbose ) { bool causal = true; // sample autoregressively or not - const float* embds = NULL; // embeddings to map output tokens through - ranges_t range; // I/O range + int n_outputs = 0; // number of output tokens to expect // create batch (targetting embeddings instead of tokens) - llama_batch batch = llama_batch_init( CTX_SIZE, inputs_map.n_embd, CTX_SIZE ); - fill_batch( batch, input, inputs_map, mode ); + llama_batch batch = llama_batch_init( CTX_SIZE, io_map.n_embd, CTX_SIZE ); + fill_batch( batch, input, io_map, mode ); // determine how many outputs we need for ( auto i = 0; i < batch.n_tokens; ++i ) { @@ -438,7 +487,7 @@ std::vector VALL_E_API generate( llama_context* ctx, llama_model* m "resps|NAR:5:6", "resps|NAR:6:7", }; - embd_name = k_embds[rvq_l]; + embd_name = k_embds[input.rvq_l]; // duration inferencing mode } else if ( mode == INFERENCE_MODE_LEN ) { embd_name = "len"; @@ -447,22 +496,18 @@ std::vector VALL_E_API generate( llama_context* ctx, llama_model* m embd_name = "resps|NAR:0:0"; } - auto& embeddings = vall_e_inputs_map_get_embeddings(inputs_map, embd_name); - range = embeddings.range; - embds = embeddings.embds.data(); - n_embd = embeddings.n_embd; - n_vocab = embeddings.n_vocab; - stop_token = range.end - range.start - 1; + auto& io = vall_e_inputs_map_get(io_map, embd_name); + const float* embds = io.embds.data(); - printf("Generating in %s (%i) mode (%i:%i) (%i)\n", embd_name.c_str(), range.classifier_idx, range.start, range.end, stop_token); + int32_t n_embd = io.n_embd; + int32_t n_vocab = io.n_vocab; + llama_token stop_token = io.end - io.start - 1; + + printf("Generating in %s (%i) mode (%i:%i) (%i)\n", embd_name.c_str(), io.head_idx, io.start, io.end, stop_token); // update model's output heads / causal mode -#if LLAMA_CPP_USE_VALL_E_ARCH - auto& userdata = *llama_get_vall_e_userdata( model ); - llama_set_output_head( model, userdata.heads[range.classifier_idx] ); -#endif - llama_set_causal_attn( ctx, causal ); - // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0]) + llama_set_output_head( model, io.head ); + llama_set_causal_attn( ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0]) std::vector output_tokens; const auto t_main_start = ggml_time_us(); @@ -480,13 +525,6 @@ std::vector VALL_E_API generate( llama_context* ctx, llama_model* m return output_tokens; } - // ensures only tokens within our designated range are used - #if !LLAMA_CPP_USE_VALL_E_ARCH - auto* logits = llama_get_logits_ith( ctx, -1 ); - for ( auto i = 0; i < inputs_map.n_vocab; ++i ) { - if ( i < range.start || i >= range.end ) logits[i] = -INFINITY; - } - #endif // sample token auto t = llama_sampler_sample(smpl, ctx, -1); @@ -498,7 +536,7 @@ std::vector VALL_E_API generate( llama_context* ctx, llama_model* m // store token output_tokens.emplace_back(t); // update batch with token - batch_add( batch, t, inputs_map.n_embd, embds, output_tokens.size(), true ); + batch_add( batch, t, io_map.n_embd, embds, output_tokens.size(), true ); if ( verbose ) { printf("%i, ", t); fflush(stdout); @@ -527,7 +565,7 @@ std::vector VALL_E_API generate( llama_context* ctx, llama_model* m null_input.phn = {1, 2}; // null_input.resp.resize(1); - llama_batch null_batch = llama_batch_init( CTX_SIZE, inputs_map.n_embd, CTX_SIZE ); + llama_batch null_batch = llama_batch_init( CTX_SIZE, io_map.n_embd, CTX_SIZE ); // token scores to reference for masking std::vector scores(n_outputs, 1.0); @@ -567,11 +605,11 @@ std::vector VALL_E_API generate( llama_context* ctx, llama_model* m // to-do: only update the embeddings instead batch.n_tokens = 0; input.resp[0] = output_tokens; - fill_batch( batch, input, inputs_map, mode ); + fill_batch( batch, input, io_map, mode ); // update null batch null_input.resp[0] = output_tokens; null_batch.n_tokens = 0; - fill_batch( null_batch, input, inputs_map, mode ); + fill_batch( null_batch, input, io_map, mode ); // to-do: update sampling temperature @@ -602,11 +640,6 @@ std::vector VALL_E_API generate( llama_context* ctx, llama_model* m auto* logits = llama_get_logits_ith( ctx, batch.n_tokens - n_outputs + idx ); auto* null_logit = &null_logits[idx]; - #if !LLAMA_CPP_USE_VALL_E_ARCH - for ( auto i = 0; i < inputs_map.n_vocab; ++i ) { - if ( i < range.start || i >= range.end ) logits[i] = -INFINITY; - } - #endif // perform softmax before modifying logits std::vector softmaxed = soft_max( n_vocab, logits ); @@ -645,14 +678,6 @@ std::vector VALL_E_API generate( llama_context* ctx, llama_model* m fflush(stdout); } for ( auto idx = 0; idx < n_outputs; ++idx ) { - // ensures only tokens within our designated range are used - #if !LLAMA_CPP_USE_VALL_E_ARCH - auto* logits = llama_get_logits_ith( ctx, batch.n_tokens - n_outputs + idx ); - for ( auto i = 0; i < inputs_map.n_vocab; ++i ) { - if ( i < range.start || i >= range.end ) logits[i] = -INFINITY; - } - #endif - // sample ith token auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx); @@ -674,7 +699,7 @@ std::vector VALL_E_API generate( llama_context* ctx, llama_model* m if ( verbose ) { printf("\n"); fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", - __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); + __func__, output_tokens.size(), (t_main_end - t_main_start) / 1000000.0f, output_tokens.size() / ((t_main_end - t_main_start) / 1000000.0f)); fprintf(stderr, "\n"); llama_perf_sampler_print(smpl); @@ -692,7 +717,7 @@ int main( int argc, char** argv ) { int32_t ngl = 0; int modality = MODALITY_NAR_LEN; input_t input{}; - inputs_map_t inputs_map{}; + io_map_t io_map{}; // input.phonemes = "hˈɛloː ʋˈɔrlt"; input.phn = {1,22,111,100,4,37,115,169,11,2}; // hˈɛloː ʋˈɔrlt @@ -725,7 +750,6 @@ int main( int argc, char** argv ) { ctx_params.no_perf = false; ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; - // create two contexts, one's that causally, the other that isn't, because pain llama_context* ctx = llama_new_context_with_model(model, ctx_params); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); @@ -765,7 +789,7 @@ int main( int argc, char** argv ) { auto n_vocab = llama_n_vocab( model ); // grab input embeddings - vall_e_inputs_map_init( inputs_map, model ); + vall_e_inputs_map_init( io_map, model ); // tokenize phonemes // to-do: make this work, the vocab does not work @@ -787,10 +811,10 @@ int main( int argc, char** argv ) { // NAR-len demasking if ( modality == MODALITY_NAR_LEN ) { // inference len - int len = 290; + int len = 0; if ( !len ) { input.task = "len"; - output_tokens = generate( ctx, model, smpl_nar, input, inputs_map, 5, INFERENCE_MODE_LEN ); + output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, INFERENCE_MODE_LEN ); { int digit = 1; for (int i = output_tokens.size() - 1; i >= 0; i--) { @@ -812,7 +836,7 @@ int main( int argc, char** argv ) { input.task = "tts"; for ( auto l = 0; l < 8; ++l ) { input.rvq_l = l; - output_tokens = generate( ctx, model, smpl_nar, input, inputs_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR ); + output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR ); input.resp.emplace_back( output_tokens ); } // AR+NAR @@ -820,7 +844,7 @@ int main( int argc, char** argv ) { input.task = "tts"; for ( auto l = 0; l < 8; ++l ) { input.rvq_l = l; - output_tokens = generate( ctx, model, l == 0 ? smpl_ar : smpl_nar, input, inputs_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR ); + output_tokens = generate( ctx, model, l == 0 ? smpl_ar : smpl_nar, input, io_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR ); input.resp.emplace_back( output_tokens ); } } diff --git a/vall_e.cpp/vall_e.h b/vall_e.cpp/vall_e.h index a592413..e31949d 100644 --- a/vall_e.cpp/vall_e.h +++ b/vall_e.cpp/vall_e.h @@ -12,11 +12,11 @@ // to-do: copy over import/export stuff from engine project (because I don't remember how I set it up in ) #define VALL_E_API -#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions -#define LLAMA_CPP_USE_VALL_E_ARCH 1 // whether the underlying llama.cpp is to use the VALL_E arch (or using LLAMA arch) +#define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions +#define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch (or using LLAMA arch) #if !LLAMA_CPP_EXTENDED - #include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd + #include "llama_hack.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd #endif // to-do: clean up spaghetti enums @@ -75,44 +75,45 @@ struct input_t { [(16677, 17702), 'resps_emb.embeddings.8.weight', 'classifiers.proj.8.weight', '<|R|NAR|0:0|{id}|>'] */ -// handles all the cringe logic of slicing embeddings -struct ranges_t { - std::string name; - - uint32_t start; - uint32_t end; - - int32_t classifier_idx = -1; -}; - // stores embeddings + metadata for an embedding range -struct embeddings_t { +struct io_t { + std::string name; + uint32_t start; + uint32_t end; + int32_t head_idx = -1; + int32_t n_embd = 0; int32_t n_vocab = 0; - ranges_t range = {}; std::vector embds = {}; + ggml_tensor* head = NULL; }; // stores the mappings between tokens, input embeddings, and output heads -struct inputs_map_t { +struct io_map_t { + // model's original params int32_t n_embd = 0; int32_t n_vocab = 0; // mapping - std::unordered_map embds = {}; + std::unordered_map io = {}; + // context to store slices + ggml_context* ctx = NULL; }; // helper tensor functions std::vector VALL_E_API read_2d_tensor( struct ggml_tensor* tensor ); +ggml_tensor* VALL_E_API view_2d_tensor( ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 ); // cringe method to keep in my pocket +ggml_tensor* VALL_E_API view_2d_tensor( ggml_context* ctx, ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 ); + std::vector> VALL_E_API map_embeddings( const std::vector& tokens, int n_embd, const float* embds ); std::vector> VALL_E_API sum_embeddings( const std::vector>& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM ); std::vector VALL_E_API soft_max( int n_logits, const float* logits ); // batch and inferencing void VALL_E_API batch_add( llama_batch& batch, llama_token id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector & seq_ids = {0} ); -void VALL_E_API fill_batch( llama_batch& batch, input_t& input, inputs_map_t& inputs_map, int mode ); -std::vector VALL_E_API generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, inputs_map_t& inputs_map, int max_tokens, int mode, bool verbose = true ); +void VALL_E_API fill_batch( llama_batch& batch, input_t& input, io_map_t& inputs_map, int mode ); +std::vector VALL_E_API generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, io_map_t& inputs_map, int max_tokens, int mode, bool verbose = true ); // encodec helpers bool VALL_E_API read_wav_from_disk( std::string in_path, std::vector& audio_arr ); @@ -121,10 +122,10 @@ std::vector> VALL_E_API encode_audio_from_disk( struct enco std::vector VALL_E_API decode_audio( struct encodec_context* ectx, const std::vector>& codes_2d ); // model-accessing helpers -const embeddings_t& VALL_E_API vall_e_inputs_map_get_embeddings( inputs_map_t& inputs_map, const std::string& name ); -const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( inputs_map_t& inputs_map, const std::string& name ); -int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( inputs_map_t& inputs_map, const std::string& name ); -void VALL_E_API vall_e_inputs_map_init( inputs_map_t&, llama_model* model ); +const io_t& VALL_E_API vall_e_inputs_map_get_embeddings( io_map_t& inputs_map, const std::string& name ); +const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( io_map_t& inputs_map, const std::string& name ); +int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( io_map_t& inputs_map, const std::string& name ); +void VALL_E_API vall_e_inputs_map_init( io_map_t&, llama_model* model ); struct ggml_tensor * VALL_E_API vall_e_get_prom_embds( llama_vall_e_userdata& userdata, int32_t idx ); struct ggml_tensor * VALL_E_API vall_e_get_resp_embds( llama_vall_e_userdata& userdata, int32_t idx );