more work on vall_e.cpp (need to resolve why the embeddings (and maybe the weights as a whole) are different from the base model)

This commit is contained in:
mrq 2024-12-23 20:36:40 -06:00
parent 6ecdb715b6
commit f62f99b8de
4 changed files with 361 additions and 175 deletions

View File

@ -16,12 +16,13 @@ Run `make`.
[`encodec.cpp`](https://github.com/e-c-k-e-r/encodec.cpp) requires updating its GGML copy to the latest version, which requires a few lines to get the CPU backend working.
[`llama.cpp`](https://github.com/e-c-k-e-r/llama.cpp) *might* not require any modifications, but implementing `LLM_ARCH_VALL_E` requires some surgery.
[`llama.cpp`](https://github.com/e-c-k-e-r/llama.cpp) should only need one modification: ensuring that a non-causal attention mask can be used; everything else necessary can be hacked together with clever tricks.
## To-Do
* [x] converted model to GGUF
* [ ] convert it without modifying any of the existing code, as the tokenizer requires some care
* [ ] *actually* convert the model properly, as the embeddings differ from the real model
* [x] basic framework
* [x] load the quantized model
* [x] orchestrate the required embeddings
@ -35,13 +36,10 @@ Run `make`.
* [x] sum embeddings for the `prom` and prior `resp`s
* [ ] working `AR` output
* [x] `AR` sampling
* currently need a model that didn't regress with the `AR:0:0` output
* [ ] working `NAR-len` output
* [x] `NAR-len` sampling
* need to assert that a non-causal mask is used
* [ ] working `NAR` output
* [x] `NAR` sampling
* need to assert that a non-causal mask is used (see the sketch after this list)
* [x] decode audio to disk
* [ ] a functional CLI
* [ ] actually make it work
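
A minimal sketch of what that assertion amounts to, assuming the stock `llama_set_causal_attn` API: only the `AR` level should decode with a causal mask, while the `NAR` / `NAR-len` levels need full (non-causal) attention. `mode` / `INFERENCE_MODE_AR` here are this project's own identifiers.

```cpp
// sketch only: flip the mask per decode mode before calling llama_decode
bool causal = (mode == INFERENCE_MODE_AR);
llama_set_causal_attn( ctx, causal );
// to-do: the non-causal path still needs GGML_ASSERT(mask->ne[0] == a->ne[0]) resolved
```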

View File

@ -0,0 +1,163 @@
#pragma once
#include "llama-vocab.h"
#include <array>
/* Begin cringe so I can access the model's tok_embd */
// it needs to be copied so the struct layout is exactly as it is under llama.cpp
#define LLAMA_MAX_LAYERS 512
#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
enum e_model {
MODEL_UNKNOWN,
};
enum llm_arch {
LLM_ARCH_UNKNOWN,
};
struct llama_hparams_posnet {
uint32_t n_embd;
uint32_t n_layer;
};
struct llama_hparams_convnext {
uint32_t n_embd;
uint32_t n_layer;
};
struct llama_hparams {
bool vocab_only;
bool rope_finetuned;
bool use_par_res;
bool swin_norm;
uint32_t n_vocab = 0;
uint32_t n_ctx_train; // context size the model was trained on
uint32_t n_embd;
uint32_t n_embd_features = 0;
uint32_t n_layer;
uint32_t n_rot;
uint32_t n_swa = 0; // sliding window attention (SWA)
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0;
uint32_t n_expert_used = 0;
uint32_t n_vocab_type = 0; // for BERT-style token types
uint32_t n_rel_attn_bkts = 0;
// for WavTokenizer
struct llama_hparams_posnet posnet;
struct llama_hparams_convnext convnext;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
uint32_t n_layer_dense_lead = 0;
uint32_t n_lora_q = 0;
uint32_t n_lora_kv = 0;
uint32_t n_ff_exp = 0;
uint32_t n_ff_shexp = 0;
uint32_t n_expert_shared = 0;
float expert_weights_scale = 0.0;
float f_norm_eps;
float f_norm_rms_eps;
float f_norm_group_eps;
uint32_t n_norm_groups;
float f_attn_logit_softcapping = 50.0f;
float f_final_logit_softcapping = 30.0f;
// for RWKV
uint32_t rescale_every_n_layers = 0;
uint32_t time_mix_extra_dim = 0;
uint32_t time_decay_extra_dim = 0;
uint32_t wkv_head_size = 0;
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_scale_train;
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul;
int rope_sections[4];
// for State Space Models
uint32_t ssm_d_conv = 0;
uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
bool ssm_dt_b_c_rms = false;
float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f;
// Additional scale factors (Granite/Granite MoE)
float f_residual_scale = 0.0f;
float f_embedding_scale = 0.0f;
float f_attention_scale = 0.0f;
bool causal_attn = true;
bool use_alibi = false;
bool attn_soft_cap = false;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
};
struct llama_model {
e_model type = MODEL_UNKNOWN;
llm_arch arch = LLM_ARCH_UNKNOWN;
llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
std::string name = "n/a";
llama_hparams hparams = {};
llama_vocab vocab;
struct ggml_tensor * tok_embd = nullptr;
struct ggml_tensor * type_embd = nullptr;
struct ggml_tensor * pos_embd = nullptr;
struct ggml_tensor * tok_norm = nullptr;
struct ggml_tensor * tok_norm_b = nullptr;
struct ggml_tensor * output_norm = nullptr;
struct ggml_tensor * output_norm_b = nullptr;
struct ggml_tensor * output = nullptr;
struct ggml_tensor * output_b = nullptr;
struct ggml_tensor * output_norm_enc = nullptr;
// classifier
struct ggml_tensor * cls = nullptr;
struct ggml_tensor * cls_b = nullptr;
struct ggml_tensor * cls_out = nullptr;
struct ggml_tensor * cls_out_b = nullptr;
struct ggml_tensor * conv1d = nullptr;
struct ggml_tensor * conv1d_b = nullptr;
};
/* BEGIN VALL-E SPECIFIC HELPERS */
struct ggml_tensor * llama_get_embedding_weights(struct llama_model * model) {
return model->tok_embd;
}
struct ggml_tensor * llama_get_output_head_tensor(struct llama_model * model ) {
return model->output;
}
void llama_set_output_head(struct llama_model * model, struct ggml_tensor* tensor ) {
// set the output tensor
model->output = tensor;
// required to properly output logits
*const_cast<uint32_t*>(&model->hparams.n_vocab) = tensor->ne[1];
}
/* END VALL-E SPECIFIC HELPERS */
/* End cringe code */
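/* A minimal, hypothetical usage sketch of the helpers above (not part of this header): the
   embedding table only needs to be read once, while the output head is swapped per decode
   level so the logits only cover that level's classifier. */
static void example_swap_output_head( struct llama_model * model, struct ggml_tensor * level_head ) {
struct ggml_tensor * tok_embd = llama_get_embedding_weights( model ); // raw input embedding matrix
struct ggml_tensor * old_head = llama_get_output_head_tensor( model ); // currently active output head
llama_set_output_head( model, level_head ); // also rewrites hparams.n_vocab to level_head->ne[1]
// ... run llama_decode() for this level ...
llama_set_output_head( model, old_head ); // restore the original head afterwards
(void) tok_embd;
}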

View File

@ -7,7 +7,7 @@
#include <iostream>
#include <algorithm>
ranges_t io_ranges[] = {
io_t io_ranges[] = {
{ "text", 0, 256, 9, },
{ "rvq_l", 256, 264, -1, },
{ "lang", 264, 270, -1, },
@ -25,15 +25,15 @@ ranges_t io_ranges[] = {
{ "prom|6", 6436, 7460, -1, },
{ "prom|7", 7460, 8484, -1, },
{ "resps|AR:0 8484, 9509, 0,:0", },
{ "resps|NAR:0 9509, 10533, 1,:1", },
{ "resps|NAR:1: 10533, 11557, 2,2", },
{ "resps|NAR:2: 11557, 12581, 3,3", },
{ "resps|NAR:3: 12581, 13605, 4,4", },
{ "resps|NAR:4: 13605, 14629, 5,5", },
{ "resps|NAR:5: 14629, 15653, 6,6", },
{ "resps|NAR:6: 15653, 16677, 7,7", },
{ "resps|NAR:0: 16677, 17702, 8,0", },
{ "resps|AR:0:0", 8484, 9509, 0 },
{ "resps|NAR:0:1", 9509, 10533, 1 },
{ "resps|NAR:1:2", 10533, 11557, 2 },
{ "resps|NAR:2:3", 11557, 12581, 3 },
{ "resps|NAR:3:4", 12581, 13605, 4 },
{ "resps|NAR:4:5", 13605, 14629, 5 },
{ "resps|NAR:5:6", 14629, 15653, 6 },
{ "resps|NAR:6:7", 15653, 16677, 7 },
{ "resps|NAR:0:0", 16677, 17702, 8 },
};
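// A small worked example of what the table above encodes (sketch only; these helpers are
// hypothetical and not part of the codebase): each entry owns rows [start, end) of the merged
// vocab, so in the single-vocab (LLAMA arch) path a level-local token t sits at row start + t,
// and the last row of a range doubles as that range's stop token.
// e.g. "resps|AR:0:0" spans rows 8484..9509, so its local stop token is 9509 - 8484 - 1 = 1024.
static inline uint32_t io_local_to_global( uint32_t start, uint32_t local_token ) { return start + local_token; }
static inline llama_token io_stop_token( uint32_t start, uint32_t end ) { return (llama_token)( end - start - 1 ); }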
std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor ) {
@ -51,6 +51,43 @@ std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor ) {
return res;
}
ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim ) {
// to-do: implement other dim
if ( start < 0 ) start = tensor->ne[1] + start;
if ( end < 0 ) end = tensor->ne[1] + end;
ggml_tensor* res = new ggml_tensor();
memcpy( res, tensor, sizeof(ggml_tensor) );
res->op = GGML_OP_VIEW;
res->src[0] = tensor;
res->data = (char*) res->data + res->nb[1] * start; // byte offset to the first requested row (avoids arithmetic on void*)
res->ne[1] = end - start;
for (int i = 2; i < GGML_MAX_DIMS; i++) {
res->nb[i] = res->nb[i - 1] * res->ne[i - 1];
}
return res;
}
ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_context* ctx, struct ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim ) {
// to-do: implement other dim
if ( start < 0 ) start = tensor->ne[1] + start;
if ( end < 0 ) end = tensor->ne[1] + end;
ggml_tensor* res = ggml_view_2d( ctx, tensor, tensor->ne[0], end - start, tensor->nb[1], tensor->nb[1] * start );
/*
printf("%p: %i | %i | %i | %i || %p: %i | %i | %i | %i\n",
tensor->data, tensor->ne[0], tensor->ne[1], tensor->nb[1], tensor->nb[2],
res->data, res->ne[0], res->ne[1], res->nb[1], res->nb[2]
);
*/
return res;
}
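// Sketch of the byte-offset arithmetic both views above rely on (illustrative only, assuming a
// contiguous F32 tensor resident in host memory): ne[0] is the embedding width, nb[1] is the byte
// stride between rows, so row `start + i` of the source becomes row `i` of a view taken at byte
// offset nb[1] * start with ne[1] = end - start.
static inline const float* example_row_ptr( const struct ggml_tensor* t, int64_t row ) {
return (const float*) ((const char*) t->data + row * t->nb[1]);
}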
struct ggml_tensor * VALL_E_API vall_e_get_prom_embds( llama_vall_e_userdata& userdata, int32_t idx ) {
return userdata.prom_embds[idx];
@ -63,81 +100,100 @@ struct ggml_tensor * VALL_E_API vall_e_get_aux_embds( llama_vall_e_userdata& us
}
const embeddings_t& VALL_E_API vall_e_inputs_map_get_embeddings( inputs_map_t& inputs_map, const std::string& name ) {
return inputs_map.embds[name];
const io_t& VALL_E_API vall_e_inputs_map_get( io_map_t& io_map, const std::string& name ) {
return io_map.io[name];
}
const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( inputs_map_t& inputs_map, const std::string& name ) {
return inputs_map.embds[name].embds.data();
const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( io_map_t& io_map, const std::string& name ) {
return io_map.io[name].embds.data();
}
int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( inputs_map_t& inputs_map, const std::string& name ) {
return inputs_map.embds[name].range.classifier_idx;
int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( io_map_t& io_map, const std::string& name ) {
return io_map.io[name].head_idx;
}
void VALL_E_API vall_e_inputs_map_init( inputs_map_t& inputs_map, llama_model* model ) {
void VALL_E_API vall_e_inputs_map_init( io_map_t& io_map, llama_model* model ) {
auto n_embd = llama_n_embd( model );
auto n_vocab = llama_n_vocab( model );
inputs_map.n_embd = n_embd;
inputs_map.n_vocab = n_vocab;
io_map.n_embd = n_embd;
io_map.n_vocab = n_vocab;
auto& userdata = *llama_get_vall_e_userdata( model );
int32_t ctx_size = 24 * 2 * ggml_tensor_overhead(); // 24 embeddings + 24 output heads (generous; this should only really be needed for the output heads, since embeddings are handled manually)
struct ggml_init_params params = {
/*.mem_size =*/ ctx_size,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
io_map.ctx = ggml_init(params);
// to-do: figure out a nicer way to do this
#if LLAMA_CPP_USE_VALL_E_ARCH
inputs_map.embds["text"] = { n_embd, 0, { "text", 0, 0, 9, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 0)) };
inputs_map.embds["rvq_l"] = { n_embd, 0, { "rvq_l", 0, 0, -1, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 1)) };
inputs_map.embds["lang"] = { n_embd, 0, { "lang", 0, 0, -1, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 2)) };
inputs_map.embds["task"] = { n_embd, 0, { "task", 0, 0, -1, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 3)) };
inputs_map.embds["len"] = { n_embd, 0, { "len", 0, 0, 10, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 4)) };
inputs_map.embds["tone"] = { n_embd, 0, { "tone", 0, 0, -1, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 5)) };
inputs_map.embds["sep"] = { n_embd, 0, { "sep", 0, 0, -1, }, read_2d_tensor(vall_e_get_aux_embds(userdata, 6)) };
auto& userdata = *llama_get_vall_e_userdata( model );
inputs_map.embds["prom|0"] = { n_embd, 0, { "prom|0", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 0)) };
inputs_map.embds["prom|1"] = { n_embd, 0, { "prom|1", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 1)) };
inputs_map.embds["prom|2"] = { n_embd, 0, { "prom|2", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 2)) };
inputs_map.embds["prom|3"] = { n_embd, 0, { "prom|3", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 3)) };
inputs_map.embds["prom|4"] = { n_embd, 0, { "prom|4", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 4)) };
inputs_map.embds["prom|5"] = { n_embd, 0, { "prom|5", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 5)) };
inputs_map.embds["prom|6"] = { n_embd, 0, { "prom|6", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 6)) };
inputs_map.embds["prom|7"] = { n_embd, 0, { "prom|7", 0, 0, -1, }, read_2d_tensor(vall_e_get_prom_embds(userdata, 7)) };
for ( auto& entry : io_ranges ) {
io_map.io[entry.name] = entry;
io_map.io[entry.name].n_embd = n_embd;
io_map.io[entry.name].n_vocab = entry.end - entry.start;
io_map.io[entry.name].start = 0;
io_map.io[entry.name].end = 0;
io_map.io[entry.name].head = entry.head_idx < 0 ? NULL : userdata.heads[entry.head_idx];
}
io_map.io["text"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 0));
io_map.io["rvq_l"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 1));
io_map.io["lang"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 2));
io_map.io["task"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 3));
io_map.io["len"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 4));
io_map.io["tone"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 5));
io_map.io["sep"].embds = read_2d_tensor(vall_e_get_aux_embds(userdata, 6));
io_map.io["prom|0"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 0));
io_map.io["prom|1"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 1));
io_map.io["prom|2"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 2));
io_map.io["prom|3"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 3));
io_map.io["prom|4"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 4));
io_map.io["prom|5"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 5));
io_map.io["prom|6"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 6));
io_map.io["prom|7"].embds = read_2d_tensor(vall_e_get_prom_embds(userdata, 7));
inputs_map.embds["resps|AR:0:0"] = { n_embd, 0, { "resps|AR:0:0", 0, 0, 0, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 0)) };
inputs_map.embds["resps|NAR:0:1"] = { n_embd, 0, { "resps|NAR:0:1", 0, 0, 1, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 1)) };
inputs_map.embds["resps|NAR:1:2"] = { n_embd, 0, { "resps|NAR:1:2", 0, 0, 2, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 2)) };
inputs_map.embds["resps|NAR:2:3"] = { n_embd, 0, { "resps|NAR:2:3", 0, 0, 3, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 3)) };
inputs_map.embds["resps|NAR:3:4"] = { n_embd, 0, { "resps|NAR:3:4", 0, 0, 4, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 4)) };
inputs_map.embds["resps|NAR:4:5"] = { n_embd, 0, { "resps|NAR:4:5", 0, 0, 5, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 5)) };
inputs_map.embds["resps|NAR:5:6"] = { n_embd, 0, { "resps|NAR:5:6", 0, 0, 6, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 6)) };
inputs_map.embds["resps|NAR:6:7"] = { n_embd, 0, { "resps|NAR:6:7", 0, 0, 7, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 7)) };
inputs_map.embds["resps|NAR:0:0"] = { n_embd, 0, { "resps|NAR:0:0", 0, 0, 8, }, read_2d_tensor(vall_e_get_resp_embds(userdata, 8)) };
io_map.io["resps|AR:0:0"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 0));
io_map.io["resps|NAR:0:1"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 1));
io_map.io["resps|NAR:1:2"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 2));
io_map.io["resps|NAR:2:3"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 3));
io_map.io["resps|NAR:3:4"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 4));
io_map.io["resps|NAR:4:5"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 5));
io_map.io["resps|NAR:5:6"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 6));
io_map.io["resps|NAR:6:7"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 7));
io_map.io["resps|NAR:0:0"].embds = read_2d_tensor(vall_e_get_resp_embds(userdata, 8));
// update values
for ( auto& pair : inputs_map.embds ) {
auto& k = pair.first;
auto& v = pair.second;
auto& embds = v.embds;
v.n_vocab = embds.size() / n_embd;
v.range.end = v.n_vocab;
for ( auto& entry : io_ranges ) {
for ( auto i = 0; i < 32; ++i ) printf("%s: %i: %f\n", entry.name.c_str(), i, io_map.io[entry.name].embds[i] );
}
#else
#if LLAMA_CPP_EXTENDED
auto* tensor = llama_get_embedding_weights( model );
#else
auto* tensor = model->tok_embd;
#endif
auto* embds = llama_get_embedding_weights( model );
auto* heads = llama_get_output_head_tensor( model );
// prepare slices
std::vector<float> raw_embeddings = read_2d_tensor( tensor );
for ( auto& range : io_ranges ) {
inputs_map.embds[range.name] = {
n_embd,
range.end - range.start,
range,
std::vector<float>( raw_embeddings.data() + range.start, raw_embeddings.data() + range.end )
};
// std::vector<float> raw_embeddings = read_2d_tensor( embds );
for ( auto& entry : io_ranges ) {
io_map.io[entry.name] = entry;
io_map.io[entry.name].n_embd = n_embd;
io_map.io[entry.name].n_vocab = entry.end - entry.start;
io_map.io[entry.name].embds = read_2d_tensor(view_2d_tensor( io_map.ctx, embds, entry.start, entry.end ));
io_map.io[entry.name].head = entry.head_idx < 0 ? NULL : view_2d_tensor( io_map.ctx, heads, entry.start, entry.end );
// these two differ after the first embedding and I don't know why.........
/*
auto raw_embd = std::vector<float>( raw_embeddings.data() + entry.start * n_embd, raw_embeddings.data() + entry.end * n_embd );
auto sliced_embd = read_2d_tensor( embd_tensor );
io_map.io[entry.name].embds = raw_embd;
for ( auto i = 0; i < 32; ++i ) printf("%s: %i: %f == %f \n", entry.name.c_str(), i, raw_embd[i], sliced_embd[i] );
*/
}
#endif
}
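// A hypothetical debugging helper (sketch only, not part of the codebase) for the mismatch noted
// in the comments above: dump the first few values of every io entry so they can be diffed against
// the reference model's *_emb.embeddings.*.weight / classifiers.proj.*.weight tensors.
void VALL_E_API vall_e_inputs_map_dump( io_map_t& io_map ) {
for ( auto& pair : io_map.io ) {
auto& entry = pair.second;
printf( "%s:", entry.name.c_str() );
for ( auto i = 0; i < 8 && i < (int) entry.embds.size(); ++i ) printf( " %f", entry.embds[i] );
printf( "\n" );
}
}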
@ -323,38 +379,38 @@ std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits ) {
return res;
}
void VALL_E_API fill_batch( llama_batch& batch, input_t& input, inputs_map_t& inputs_map, int mode ) {
void VALL_E_API fill_batch( llama_batch& batch, input_t& input, io_map_t& io_map, int mode ) {
// keeps track of the position for each sequence
size_t pos = 0;
auto n_embd = inputs_map.n_embd;
auto n_embd = io_map.n_embd;
const float* text_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "text");
const float* rvq_l_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "rvq_l");
const float* lang_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "lang");
const float* task_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "task");
const float* len_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "len");
const float* tone_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "tone");
const float* sep_embds = vall_e_inputs_map_get_embeddings_p(inputs_map, "sep");
const float* text_embds = vall_e_inputs_map_get_embeddings_p(io_map, "text");
const float* rvq_l_embds = vall_e_inputs_map_get_embeddings_p(io_map, "rvq_l");
const float* lang_embds = vall_e_inputs_map_get_embeddings_p(io_map, "lang");
const float* task_embds = vall_e_inputs_map_get_embeddings_p(io_map, "task");
const float* len_embds = vall_e_inputs_map_get_embeddings_p(io_map, "len");
const float* tone_embds = vall_e_inputs_map_get_embeddings_p(io_map, "tone");
const float* sep_embds = vall_e_inputs_map_get_embeddings_p(io_map, "sep");
const float* prom_embds[] = {
vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|0"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|1"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|2"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|3"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|4"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|5"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|6"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "prom|7"),
vall_e_inputs_map_get_embeddings_p(io_map, "prom|0"),
vall_e_inputs_map_get_embeddings_p(io_map, "prom|1"),
vall_e_inputs_map_get_embeddings_p(io_map, "prom|2"),
vall_e_inputs_map_get_embeddings_p(io_map, "prom|3"),
vall_e_inputs_map_get_embeddings_p(io_map, "prom|4"),
vall_e_inputs_map_get_embeddings_p(io_map, "prom|5"),
vall_e_inputs_map_get_embeddings_p(io_map, "prom|6"),
vall_e_inputs_map_get_embeddings_p(io_map, "prom|7"),
};
const float* resp_embds[] = {
vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|AR:0:0"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:0:1"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:1:2"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:2:3"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:3:4"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:4:5"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:5:6"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:6:7"),
vall_e_inputs_map_get_embeddings_p(inputs_map, "resps|NAR:0:0"),
vall_e_inputs_map_get_embeddings_p(io_map, "resps|AR:0:0"),
vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:0:1"),
vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:1:2"),
vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:2:3"),
vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:3:4"),
vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:4:5"),
vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:5:6"),
vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:6:7"),
vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:0:0"),
};
// insert text tokens
@ -394,20 +450,13 @@ void VALL_E_API fill_batch( llama_batch& batch, input_t& input, inputs_map_t& in
}
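// The batch above carries raw embedding rows rather than token ids; a minimal sketch of what the
// token-to-embedding mapping might look like (assuming `embds` is laid out as n_vocab rows of
// n_embd floats, mirroring the map_embeddings declaration in the header; not the actual implementation):
std::vector<std::vector<float>> map_embeddings_sketch( const std::vector<llama_token>& tokens, int n_embd, const float* embds ) {
std::vector<std::vector<float>> res( tokens.size() );
for ( size_t i = 0; i < tokens.size(); ++i ) {
res[i] = std::vector<float>( embds + (size_t) tokens[i] * n_embd, embds + (size_t) (tokens[i] + 1) * n_embd );
}
return res;
}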
// generation code, should handle all modalities easily
std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, inputs_map_t& inputs_map, int max_tokens, int mode, bool verbose ) {
int rvq_l = input.rvq_l;
llama_token stop_token = -1;
int n_decode = 0; // number of tokens decoded
int n_outputs = 0; // number of output tokens to expect
int n_vocab = 0;
int n_embd = 0;
std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, io_map_t& io_map, int max_tokens, int mode, bool verbose ) {
bool causal = true; // sample autoregressively or not
const float* embds = NULL; // embeddings to map output tokens through
ranges_t range; // I/O range
int n_outputs = 0; // number of output tokens to expect
// create batch (targeting embeddings instead of tokens)
llama_batch batch = llama_batch_init( CTX_SIZE, inputs_map.n_embd, CTX_SIZE );
fill_batch( batch, input, inputs_map, mode );
llama_batch batch = llama_batch_init( CTX_SIZE, io_map.n_embd, CTX_SIZE );
fill_batch( batch, input, io_map, mode );
// determine how many outputs we need
for ( auto i = 0; i < batch.n_tokens; ++i ) {
@ -438,7 +487,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
"resps|NAR:5:6",
"resps|NAR:6:7",
};
embd_name = k_embds[rvq_l];
embd_name = k_embds[input.rvq_l];
// duration inferencing mode
} else if ( mode == INFERENCE_MODE_LEN ) {
embd_name = "len";
@ -447,22 +496,18 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
embd_name = "resps|NAR:0:0";
}
auto& embeddings = vall_e_inputs_map_get_embeddings(inputs_map, embd_name);
range = embeddings.range;
embds = embeddings.embds.data();
n_embd = embeddings.n_embd;
n_vocab = embeddings.n_vocab;
stop_token = range.end - range.start - 1;
auto& io = vall_e_inputs_map_get(io_map, embd_name);
const float* embds = io.embds.data();
printf("Generating in %s (%i) mode (%i:%i) (%i)\n", embd_name.c_str(), range.classifier_idx, range.start, range.end, stop_token);
int32_t n_embd = io.n_embd;
int32_t n_vocab = io.n_vocab;
llama_token stop_token = io.end - io.start - 1;
printf("Generating in %s (%i) mode (%i:%i) (%i)\n", embd_name.c_str(), io.head_idx, io.start, io.end, stop_token);
// update model's output heads / causal mode
#if LLAMA_CPP_USE_VALL_E_ARCH
auto& userdata = *llama_get_vall_e_userdata( model );
llama_set_output_head( model, userdata.heads[range.classifier_idx] );
#endif
llama_set_causal_attn( ctx, causal );
// to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
llama_set_output_head( model, io.head );
llama_set_causal_attn( ctx, causal ); // to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
std::vector<llama_token> output_tokens;
const auto t_main_start = ggml_time_us();
@ -480,13 +525,6 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
return output_tokens;
}
// ensures only tokens within our designated range are used
#if !LLAMA_CPP_USE_VALL_E_ARCH
auto* logits = llama_get_logits_ith( ctx, -1 );
for ( auto i = 0; i < inputs_map.n_vocab; ++i ) {
if ( i < range.start || i >= range.end ) logits[i] = -INFINITY;
}
#endif
// sample token
auto t = llama_sampler_sample(smpl, ctx, -1);
@ -498,7 +536,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
// store token
output_tokens.emplace_back(t);
// update batch with token
batch_add( batch, t, inputs_map.n_embd, embds, output_tokens.size(), true );
batch_add( batch, t, io_map.n_embd, embds, output_tokens.size(), true );
if ( verbose ) {
printf("%i, ", t);
fflush(stdout);
@ -527,7 +565,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
null_input.phn = {1, 2}; // <bos></eos>
null_input.resp.resize(1);
llama_batch null_batch = llama_batch_init( CTX_SIZE, inputs_map.n_embd, CTX_SIZE );
llama_batch null_batch = llama_batch_init( CTX_SIZE, io_map.n_embd, CTX_SIZE );
// token scores to reference for masking
std::vector<float> scores(n_outputs, 1.0);
@ -567,11 +605,11 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
// to-do: only update the embeddings instead
batch.n_tokens = 0;
input.resp[0] = output_tokens;
fill_batch( batch, input, inputs_map, mode );
fill_batch( batch, input, io_map, mode );
// update null batch
null_input.resp[0] = output_tokens;
null_batch.n_tokens = 0;
fill_batch( null_batch, input, inputs_map, mode );
fill_batch( null_batch, input, io_map, mode );
// to-do: update sampling temperature
@ -602,11 +640,6 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
auto* logits = llama_get_logits_ith( ctx, batch.n_tokens - n_outputs + idx );
auto* null_logit = &null_logits[idx];
#if !LLAMA_CPP_USE_VALL_E_ARCH
for ( auto i = 0; i < inputs_map.n_vocab; ++i ) {
if ( i < range.start || i >= range.end ) logits[i] = -INFINITY;
}
#endif
// perform softmax before modifying logits
std::vector<float> softmaxed = soft_max( n_vocab, logits );
@ -645,14 +678,6 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
fflush(stdout);
}
for ( auto idx = 0; idx < n_outputs; ++idx ) {
// ensures only tokens within our designated range are used
#if !LLAMA_CPP_USE_VALL_E_ARCH
auto* logits = llama_get_logits_ith( ctx, batch.n_tokens - n_outputs + idx );
for ( auto i = 0; i < inputs_map.n_vocab; ++i ) {
if ( i < range.start || i >= range.end ) logits[i] = -INFINITY;
}
#endif
// sample ith token
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx);
@ -674,7 +699,7 @@ std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* m
if ( verbose ) {
printf("\n");
fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
__func__, (int) output_tokens.size(), (t_main_end - t_main_start) / 1000000.0f, output_tokens.size() / ((t_main_end - t_main_start) / 1000000.0f));
fprintf(stderr, "\n");
llama_perf_sampler_print(smpl);
@ -692,7 +717,7 @@ int main( int argc, char** argv ) {
int32_t ngl = 0;
int modality = MODALITY_NAR_LEN;
input_t input{};
inputs_map_t inputs_map{};
io_map_t io_map{};
// input.phonemes = "hˈɛloː ʋˈɔrlt";
input.phn = {1,22,111,100,4,37,115,169,11,2}; // <bos>hˈɛloː ʋˈɔrlt</eos>
@ -725,7 +750,6 @@ int main( int argc, char** argv ) {
ctx_params.no_perf = false;
ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL;
// create two contexts, one that's causal and one that isn't, because pain
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@ -765,7 +789,7 @@ int main( int argc, char** argv ) {
auto n_vocab = llama_n_vocab( model );
// grab input embeddings
vall_e_inputs_map_init( inputs_map, model );
vall_e_inputs_map_init( io_map, model );
// tokenize phonemes
// to-do: make this work, the vocab does not work
@ -787,10 +811,10 @@ int main( int argc, char** argv ) {
// NAR-len demasking
if ( modality == MODALITY_NAR_LEN ) {
// inference len
int len = 290;
int len = 0;
if ( !len ) {
input.task = "len";
output_tokens = generate( ctx, model, smpl_nar, input, inputs_map, 5, INFERENCE_MODE_LEN );
output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, INFERENCE_MODE_LEN );
{
int digit = 1;
for (int i = output_tokens.size() - 1; i >= 0; i--) {
@ -812,7 +836,7 @@ int main( int argc, char** argv ) {
input.task = "tts";
for ( auto l = 0; l < 8; ++l ) {
input.rvq_l = l;
output_tokens = generate( ctx, model, smpl_nar, input, inputs_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
output_tokens = generate( ctx, model, smpl_nar, input, io_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
input.resp.emplace_back( output_tokens );
}
// AR+NAR
@ -820,7 +844,7 @@ int main( int argc, char** argv ) {
input.task = "tts";
for ( auto l = 0; l < 8; ++l ) {
input.rvq_l = l;
output_tokens = generate( ctx, model, l == 0 ? smpl_ar : smpl_nar, input, inputs_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
output_tokens = generate( ctx, model, l == 0 ? smpl_ar : smpl_nar, input, io_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
input.resp.emplace_back( output_tokens );
}
}

View File

@ -12,11 +12,11 @@
// to-do: copy over import/export stuff from engine project (because I don't remember how I set it up in <uf/config.h>)
#define VALL_E_API
#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions
#define LLAMA_CPP_USE_VALL_E_ARCH 1 // whether the underlying llama.cpp is to use the VALL_E arch (or using LLAMA arch)
#define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
#define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch (or using LLAMA arch)
#if !LLAMA_CPP_EXTENDED
#include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
#include "llama_hack.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
#endif
// to-do: clean up spaghetti enums
@ -75,44 +75,45 @@ struct input_t {
[(16677, 17702), 'resps_emb.embeddings.8.weight', 'classifiers.proj.8.weight', '<|R|NAR|0:0|{id}|>']
*/
// handles all the cringe logic of slicing embeddings
struct ranges_t {
std::string name;
uint32_t start;
uint32_t end;
int32_t classifier_idx = -1;
};
// stores embeddings + metadata for an embedding range
struct embeddings_t {
struct io_t {
std::string name;
uint32_t start;
uint32_t end;
int32_t head_idx = -1;
int32_t n_embd = 0;
int32_t n_vocab = 0;
ranges_t range = {};
std::vector<float> embds = {};
ggml_tensor* head = NULL;
};
// stores the mappings between tokens, input embeddings, and output heads
struct inputs_map_t {
struct io_map_t {
// model's original params
int32_t n_embd = 0;
int32_t n_vocab = 0;
// mapping
std::unordered_map<std::string, embeddings_t> embds = {};
std::unordered_map<std::string, io_t> io = {};
// context to store slices
ggml_context* ctx = NULL;
};
// helper tensor functions
std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor );
ggml_tensor* VALL_E_API view_2d_tensor( ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 ); // cringe method to keep in my pocket
ggml_tensor* VALL_E_API view_2d_tensor( ggml_context* ctx, ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 );
std::vector<std::vector<float>> VALL_E_API map_embeddings( const std::vector<llama_token>& tokens, int n_embd, const float* embds );
std::vector<std::vector<float>> VALL_E_API sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM );
std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits );
// batch and inferencing
void VALL_E_API batch_add( llama_batch& batch, llama_token id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} );
void VALL_E_API fill_batch( llama_batch& batch, input_t& input, inputs_map_t& inputs_map, int mode );
std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, inputs_map_t& inputs_map, int max_tokens, int mode, bool verbose = true );
void VALL_E_API fill_batch( llama_batch& batch, input_t& input, io_map_t& inputs_map, int mode );
std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, io_map_t& inputs_map, int max_tokens, int mode, bool verbose = true );
// encodec helpers
bool VALL_E_API read_wav_from_disk( std::string in_path, std::vector<float>& audio_arr );
@ -121,10 +122,10 @@ std::vector<std::vector<int32_t>> VALL_E_API encode_audio_from_disk( struct enco
std::vector<float> VALL_E_API decode_audio( struct encodec_context* ectx, const std::vector<std::vector<int32_t>>& codes_2d );
// model-accessing helpers
const embeddings_t& VALL_E_API vall_e_inputs_map_get_embeddings( inputs_map_t& inputs_map, const std::string& name );
const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( inputs_map_t& inputs_map, const std::string& name );
int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( inputs_map_t& inputs_map, const std::string& name );
void VALL_E_API vall_e_inputs_map_init( inputs_map_t&, llama_model* model );
const io_t& VALL_E_API vall_e_inputs_map_get( io_map_t& io_map, const std::string& name );
const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( io_map_t& inputs_map, const std::string& name );
int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( io_map_t& inputs_map, const std::string& name );
void VALL_E_API vall_e_inputs_map_init( io_map_t&, llama_model* model );
struct ggml_tensor * VALL_E_API vall_e_get_prom_embds( llama_vall_e_userdata& userdata, int32_t idx );
struct ggml_tensor * VALL_E_API vall_e_get_resp_embds( llama_vall_e_userdata& userdata, int32_t idx );