added extremely barebones vall_e.cpp so I can stop having to juggle this file around so much

This commit is contained in:
mrq 2024-12-21 10:57:02 -06:00
parent 91caf00212
commit 5788db849b
5 changed files with 391 additions and 11 deletions


@@ -99,7 +99,12 @@
}
}
},
"decoder": null,
"decoder": {
"type": "ByteLevel",
"add_prefix_space": false,
"trim_offsets": false,
"use_regex": false
},
"model": {
"type": "BPE",
"dropout": null,
@@ -107,7 +112,7 @@
"continuing_subword_prefix": null,
"end_of_word_suffix": null,
"fuse_unk": false,
"byte_fallback": false,
"byte_fallback": true,
"ignore_merges": false,
"vocab": {
"<unk>": 0,

vall_e.cpp/README.md Normal file

@@ -0,0 +1,30 @@
# vall_e.cpp
This is an implementation that makes use of [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [encodec.cpp](https://github.com/PABannier/encodec.cpp).
At the moment it's ***very*** barebones as I try to wrestle with `llama.cpp`'s API without needing to modify its code.
## Build
Probably something like:
`g++ vall_e.cpp -I/path/to/llama.cpp/include/ -L/path/to/llama.cpp/ -lggml -lggml-base -lllama -o ./vall_e`
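Since `vall_e.cpp` includes `llama.cpp`'s internal `llama-vocab.h`, and `llama.h` pulls in `ggml.h`, a fuller invocation will probably also need those include paths and an `-L` pointing at wherever the libraries were actually built; the paths below are assumptions for a typical checkout:
`g++ vall_e.cpp -I/path/to/llama.cpp/include/ -I/path/to/llama.cpp/src/ -I/path/to/llama.cpp/ggml/include/ -L/path/to/llama.cpp/build/ -lllama -lggml -lggml-base -o ./vall_e`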
## To-Do
* [x] converted model to GGUF
  * [ ] convert it without modifying any of the existing code
* [x] basic framework
  * [x] load the model
  * [x] orchestrate the required embeddings
  * [x] juggle the output head / classifier properly
* [ ] phonemize text
* [ ] tokenize phonemes
* [ ] load audio from disk
* [ ] encode audio
* [ ] sum embeddings for the `prom` and prior `resp`s
* [ ] `NAR-len` demasking sampling
* [ ] `NAR` sampling
* [ ] decode audio to disk
* [ ] a functional CLI
* [ ] quantize the model (properly)

vall_e.cpp/vall_e.cpp Normal file

@@ -0,0 +1,336 @@
#include "llama-vocab.h"
#include "llama.h"
#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
#include <array>
#include <iostream>
/* Begin cringe so I can access the model's tok_embd */
// it needs to be copied so the struct layout is exactly as it is under llama.cpp
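// note: because these definitions are copied by hand, they have to match the exact llama.cpp revision
// being linked against, otherwise reading model->tok_embd through this layout is undefined behavior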
#define LLAMA_MAX_LAYERS 512
#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
enum e_model {
MODEL_UNKNOWN,
};
enum llm_arch {
LLM_ARCH_UNKNOWN,
};
struct llama_hparams_posnet {
uint32_t n_embd;
uint32_t n_layer;
};
struct llama_hparams_convnext {
uint32_t n_embd;
uint32_t n_layer;
};
struct llama_hparams {
bool vocab_only;
bool rope_finetuned;
bool use_par_res;
bool swin_norm;
uint32_t n_vocab = 0;
uint32_t n_ctx_train; // context size the model was trained on
uint32_t n_embd;
uint32_t n_embd_features = 0;
uint32_t n_layer;
uint32_t n_rot;
uint32_t n_swa = 0; // sliding window attention (SWA)
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0;
uint32_t n_expert_used = 0;
uint32_t n_vocab_type = 0; // for BERT-style token types
uint32_t n_rel_attn_bkts = 0;
// for WavTokenizer
struct llama_hparams_posnet posnet;
struct llama_hparams_convnext convnext;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
uint32_t n_layer_dense_lead = 0;
uint32_t n_lora_q = 0;
uint32_t n_lora_kv = 0;
uint32_t n_ff_exp = 0;
uint32_t n_ff_shexp = 0;
uint32_t n_expert_shared = 0;
float expert_weights_scale = 0.0;
float f_norm_eps;
float f_norm_rms_eps;
float f_norm_group_eps;
uint32_t n_norm_groups;
float f_attn_logit_softcapping = 50.0f;
float f_final_logit_softcapping = 30.0f;
// for RWKV
uint32_t rescale_every_n_layers = 0;
uint32_t time_mix_extra_dim = 0;
uint32_t time_decay_extra_dim = 0;
uint32_t wkv_head_size = 0;
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_scale_train;
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul;
int rope_sections[4];
// for State Space Models
uint32_t ssm_d_conv = 0;
uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
bool ssm_dt_b_c_rms = false;
float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f;
// Additional scale factors (Granite/Granite MoE)
float f_residual_scale = 0.0f;
float f_embedding_scale = 0.0f;
float f_attention_scale = 0.0f;
bool causal_attn = true;
bool use_alibi = false;
bool attn_soft_cap = false;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
};
struct llama_model {
e_model type = MODEL_UNKNOWN;
llm_arch arch = LLM_ARCH_UNKNOWN;
llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
std::string name = "n/a";
llama_hparams hparams = {};
llama_vocab vocab;
struct ggml_tensor * tok_embd = nullptr;
};
/* End cringe code */
// handles adding either a token OR the embedding of that token into the batch
// this really, really helps avoid needing to abuse the tokenizer
// to-do: handle summing
void batch_add( struct llama_batch& batch, llama_token id, int n_embd, float* embds, llama_pos pos, bool logits = true, const std::vector<llama_seq_id> & seq_ids = {0} ) {
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
if ( embds ) {
// batch.embd is laid out as [n_tokens x n_embd], so each token's embedding goes into its own row
for ( auto i = 0; i < n_embd; ++i ) {
batch.embd[batch.n_tokens * n_embd + i] = embds[id * n_embd + i];
}
} else {
batch.token[batch.n_tokens] = id;
}
batch.pos[batch.n_tokens] = pos;
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits[batch.n_tokens] = logits;
batch.n_tokens++;
}
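// example use: batch_add( batch, token_id, n_embd, some_embedding_table, position, want_logits );
// passing nullptr for embds falls back to feeding the raw token id instead of a precomputed embedding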
int main(int argc, char ** argv) {
bool is_ar = true;
// to-do: replace all of this with proper loading code
std::vector<llama_token> phoneme_tokens = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2};
llama_token lang_token = 0;
llama_token rvq_level_token = 0;
std::vector<std::vector<llama_token>> prompt_tokens = {
{780,835,835,835,339,395,798,537,537,537,537,222,76,989,548,65,705,375,261,375,297,503,529,571,707,346,464,862,148,496,574,115,115,438,934,339,865,876,63,40,779,461,602,794,10,220,398,869,639,705,869,917,705,893,215,705,869,938,439,175,139,506,375,529,297,705,651,238,962,461,195,441,377,581,473,795,644,626,459,981,767,670,696,73,779,257,408,1017,1019,133,133,1017,835,604,699,626,67,92,707,92,179,179,772,869,441,799,917,238,745,904,904,904,106,133,1019,1017,1017,395,883,87,519,594,1002,682,996,540,186,1019,430,202,347,889,61,92,542,297,67,669,571,707,346,67,359,571,707,669,604,25,1008,810,35,621,67,600,333,123,284,568,817,243,778,464,638,610,359,538,464,975,321,700,377,484,179,284,284,621,538,464,745,171,171,159,744,159,287,461,69,15,529,67,92,669,464,515,605,24,822,865,293,62,172,638,359,562,138,839,846,775,556,688,1006,917,297,312,148,331,496,646,67,314,15,705,131,855,662,287,172,85,538,519,762,450,391,609,643,778,80,287,794,794,115,785,794,461,699,519,932,522,652,262,508,902,932,932,391,769,18,507,90,442,762,610,610,669,605,310,855,56,989,863,195,464,604,257,904,632,786,951,461,239,195,878,771,146,481,146,481,434,643,917,280,67,464,115,744,744,115,115,115,819,709,63,368,359,519,996,616,464,996,616,519,762,917,841,772,568,954,600,422,893,592,464,626,86,143,615,171,744,744,196,115,821,415,521,799,654,839,644,473,592,953,523,855,738,855,876,876,1017,63,329},
};
std::vector<std::vector<llama_token>> response_tokens = {
{922,395,869,869,354,989,762,762,762,610,975,626,626,866,609,442,762,762,762,610,610,610,610,212,869,869,51,336,352,352,352,570,148,893,76,535,568,568,270,568,568,560,597,86,744,744,744,203,738,408,1019,700,707,92,707,464,744,171,171,159,196,192,697,261,261,568,638,605,904,904,779,832,570,519,223,459,459,459,459,90,90,570,700,53,372,621,610,869,473,869,917,654,473,917,893,654,644,384,558,911,864,521,1,19,665},
};
std::string model_path = "./vall_e/Vall_E-238M-F16.gguf";
// load dynamic backends
ggml_backend_load_all();
// initialize the model
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = 0;
llama_model* model = llama_load_model_from_file(model_path.c_str(), model_params);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
// initialize the context
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 22500;
ctx_params.n_batch = 22500;
ctx_params.no_perf = false;
ctx_params.attention_type = is_ar ? LLAMA_ATTENTION_TYPE_CAUSAL : LLAMA_ATTENTION_TYPE_NON_CAUSAL;
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1;
}
// initialize the sampler
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
// prepare batch
auto n_embd = llama_n_embd( model );
auto n_vocab = llama_n_vocab( model );
// float* embd = (float*) llama_get_embedding_weights( model )->data;
float* embds = (float*) (model->tok_embd->data);
// to-do: derive these offsets from the tokenizer itself
float* text_embds = embds + (0 * n_embd); // <bos>
float* rvq_level_embd = embds + (17666 * n_embd); // <|RVQ:0|>
float* len_embd = embds + (17674 * n_embd); // <|len:0|>
float* lang_embd = embds + (17686 * n_embd); // <|lang:en|>
float* task_embd = embds + (17692 * n_embd); // <|task:tts|>
float* sep_embd = embds + (17685 * n_embd); // <|sep|>
float* prom_embds[] = {
embds + ((256 + 1024 * 0) * n_embd), // <|P|0:0|>
embds + ((256 + 1024 * 1) * n_embd), // <|P|1:0|>
embds + ((256 + 1024 * 2) * n_embd), // <|P|2:0|>
embds + ((256 + 1024 * 3) * n_embd), // <|P|3:0|>
embds + ((256 + 1024 * 4) * n_embd), // <|P|4:0|>
embds + ((256 + 1024 * 5) * n_embd), // <|P|5:0|>
embds + ((256 + 1024 * 6) * n_embd), // <|P|6:0|>
embds + ((256 + 1024 * 7) * n_embd), // <|P|7:0|>
};
float* resps_embds[] = {
embds + (8448 * n_embd), // <|AR|0:0|>
embds + (9473 * n_embd), // <|NAR|0:0|>
embds + ((10498 + 1024 * 0) * n_embd), // <|NAR|0:1|>
embds + ((10498 + 1024 * 1) * n_embd), // <|NAR|1:2|>
embds + ((10498 + 1024 * 2) * n_embd), // <|NAR|2:3|>
embds + ((10498 + 1024 * 3) * n_embd), // <|NAR|3:4|>
embds + ((10498 + 1024 * 4) * n_embd), // <|NAR|4:5|>
embds + ((10498 + 1024 * 5) * n_embd), // <|NAR|5:6|>
embds + ((10498 + 1024 * 6) * n_embd), // <|NAR|6:7|>
};
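// each pointer above is just embds + token_id * n_embd, i.e. the first row of that token group within
// the stacked embedding table the conversion script writes out; the hardcoded ids assume this model's vocab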
llama_batch batch = llama_batch_init( ctx_params.n_ctx, n_embd, ctx_params.n_ctx );
{
// keeps track of the position for each sequence
size_t pos = 0;
// insert text tokens
for ( auto& id : phoneme_tokens ) {
batch_add( batch, id, n_embd, text_embds, pos++, false );
}
batch_add( batch, 0, n_embd, sep_embd, pos++, false );
pos = 0;
// insert lang token
batch_add( batch, lang_token, n_embd, lang_embd, pos++, false );
batch_add( batch, 0, n_embd, sep_embd, pos++, false );
pos = 0;
// insert rvq level token
batch_add( batch, rvq_level_token, n_embd, rvq_level_embd, pos++, false );
batch_add( batch, 0, n_embd, sep_embd, pos++, false );
pos = 0;
// insert prom tokens
// to-do: handle summing
for ( auto l = 0; l < prompt_tokens.size(); ++l ) {
for ( auto& id : prompt_tokens[l] ) {
batch_add( batch, id, n_embd, prom_embds[l], pos++, false );
}
}
batch_add( batch, 0, n_embd, sep_embd, pos++, is_ar );
pos = 0;
// fill in masked tokens
if ( !is_ar ) {
for ( auto i = 0; i < response_tokens[0].size(); ++i ) {
batch_add( batch, response_tokens[0][i], n_embd, resps_embds[1], pos++, true );
}
}
pos = 0;
}
// Decoding loop
const auto t_main_start = ggml_time_us();
int n_decode = 0;
// to-do: handle other levels
std::vector<llama_token> resps_tokens;
while ( resps_tokens.size() < 32 ) {
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
n_decode += 1;
// align to AR's classifier
// to-do: derive from tokenizer
int range[] = { 8448, 8448 + 1024 }; // { <|AR|0:0|>, <|AR|0:STOP|> }
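// mask out everything except the AR level-0 codebook (and its stop token) before sampling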
auto* logits = llama_get_logits_ith( ctx, -1 );
for ( auto i = 0; i < n_vocab; ++i ) {
if ( i < range[0] || i > range[1] ) {
logits[i] = -INFINITY;
}
}
// sample the next token
auto t = llama_sampler_sample(smpl, ctx, -1);
// is stop token
if ( t == 9472 ) { // <|AR|0:STOP|>
break;
}
char buf[256];
llama_token_to_piece( model, t, buf, sizeof(buf), 0, true );
printf("%s\n", buf );
// feed the sampled token back in, using its offset into the AR embedding table
batch_add( batch, t - range[0], n_embd, resps_embds[0], resps_tokens.size(), true );
resps_tokens.emplace_back(t);
}
printf("\n");
const auto t_main_end = ggml_time_us();
fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
fprintf(stderr, "\n");
llama_perf_sampler_print(smpl);
llama_perf_context_print(ctx);
fprintf(stderr, "\n");
llama_sampler_free(smpl);
llama_free(ctx);
llama_free_model(model);
return 0;
}


@@ -71,8 +71,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
"stt",
]
classifier_bias = False
embedding = torch.nn.Embedding( n_tokens, model_dim )
classifier = torch.nn.Linear( model_dim, n_tokens )
classifier = torch.nn.Linear( model_dim, n_tokens, bias=classifier_bias )
# to-do: ignore classifier for RVQ level 7
@@ -81,6 +83,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
token_end = l_tokens[0]
embedding.weight[token_start:token_end] = state_dict['module']['text_emb.weight']
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.9.weight']
if classifier_bias:
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.9.bias']
# tokenizer already has these tokens
@@ -102,6 +105,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
token_end += l_tokens[2] // 2
embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.0.weight']
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
if classifier_bias:
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
for t in range(n_audio_tokens):
tokenizer_vocab[f'<|AR|0:0|{t}|>'] = token_start + t
@@ -112,10 +116,11 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
token_end += l_tokens[2] // 2
embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight']
classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
if classifier_bias:
classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
for t in range(n_audio_tokens):
tokenizer_vocab[f'<NAR|0:0|{t}|>'] = token_start + t
tokenizer_vocab[f'<NAR|0:0|STOP|>'] = token_start + 1024
tokenizer_vocab[f'<|NAR|0:0|{t}|>'] = token_start + t
tokenizer_vocab[f'<|NAR|0:0|STOP|>'] = token_start + 1024
# inject NAR
token_start = token_end
@@ -125,6 +130,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
end = start + n_audio_tokens
embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight']
classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
if classifier_bias:
classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
for t in range(n_audio_tokens):
tokenizer_vocab[f'<|NAR|{l-1}:{l}|{t}|>'] = start + t
@@ -142,6 +148,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
token_end += l_tokens[5]
embedding.weight[token_start:token_end] = state_dict['module'][f'len_emb.weight']
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
if classifier_bias:
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
for t in range(n_len_tokens):
tokenizer_vocab[f'<|len:{t}|>'] = token_start + t
@@ -183,6 +190,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
classifier_dict = classifier.state_dict()
model_dict['model.embed_tokens.weight'] = embedding_dict['weight']
model_dict['lm_head.weight'] = classifier_dict['weight']
if classifier_bias:
model_dict['lm_head.bias'] = classifier_dict['bias']
# write files in an HF compatible way


@@ -509,6 +509,7 @@ class TTS():
else:
raise Exception("!")
# to-do: care about batching later
resps = resps_list[0]