diff --git a/data/tokenizer.json b/data/tokenizer.json index 8df396c..b2a344a 100644 --- a/data/tokenizer.json +++ b/data/tokenizer.json @@ -112,7 +112,7 @@ "continuing_subword_prefix": null, "end_of_word_suffix": null, "fuse_unk": false, - "byte_fallback": true, + "byte_fallback": false, "ignore_merges": false, "vocab": { "": 0, diff --git a/vall_e.cpp/README.md b/vall_e.cpp/README.md index 7035f8e..8478f0e 100644 --- a/vall_e.cpp/README.md +++ b/vall_e.cpp/README.md @@ -9,13 +9,24 @@ At the moment it's ***very*** barebones as I try and wrestle with `llama.cpp`'s Populate `./include/` with the `llama.cpp` and `encodec.cpp` headers. Populate `./libs/` with the compiled libraries of `llama.cpp` and `encodec.cpp`. -* `encodec.cpp` requires updating `ggml` to the latest version and doing a quick hack to make it work on the CPU backend. -* `llama.cpp` currently requires no hacks, but: - * would be *very* nice to retrieve a model's `tok_embd` through the API. - * would be ***very*** nice to only specify a slice of the output head through the API. Run `make`. + +### Required Modifications + +`encodec.cpp` requires updating its GGML copy to the latest version, which requires a few lines to get the CPU backend working. +`llama.cpp` *might* not require any modifications, but: +* `llm.build_vall_e` can mostly copy `llm.build_llama`, but with: + * `KQ_mask = build_inp_KQ_mask( lctx.cparams.causal_attn )` + * a unified output head (pain) + * OR adjusting the `model.output` to the correct classifier head + * OR slicing that tensor with the right range (`ggml_view_2d` confuses me) + * both require also require `*const_cast(&ctx->model.hparams.n_vocab) = output->ne[1];` because the logits are tied to `n_vocab` +* commenting out `GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());` because grabbing embeddings/classifiers require using `bid` to trick it thinking it's part of a layer +* some helper functions to retrieve the embeddings tensor from the model +* some helper functions to set the target classifier head + ## To-Do * [x] converted model to GGUF diff --git a/vall_e.cpp/vall_e.cpp b/vall_e.cpp/vall_e.cpp index 5716e78..0c8dae8 100644 --- a/vall_e.cpp/vall_e.cpp +++ b/vall_e.cpp/vall_e.cpp @@ -11,9 +11,30 @@ #include #include #include +#include #include -#include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd +#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions +#define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch + +#if !LLAMA_CPP_EXTENDED + #include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd +#endif + +std::vector read_2d_tensor( struct ggml_tensor* tensor ) { + size_t size = tensor->ne[0] * tensor->ne[1]; + std::vector res( size ); + + auto* qtype = ggml_get_type_traits(tensor->type); + // dequantize if needed + if ( ggml_is_quantized(tensor->type) ) { + qtype->to_float(tensor->data, res.data(), res.size()); + } else { + memcpy( res.data(), tensor->data, res.size() * sizeof(float) ); + } + + return res; +} // stores the raw inputs to be fed struct input_t { @@ -26,94 +47,159 @@ struct input_t { std::vector> prom = {}; std::vector> resp = {}; }; + +/* +[(0, 256), 'text_emb.weight', 'classifiers.proj.9.weight', None], +[(256, 264), 'rvq_l_emb.weight', None, '<|RVQ:{l}|>'], +[(264, 270), 'langs_emb.weight', None, '<|lang:{lang}|>'], +[(270, 279), 'tasks_emb.weight', None, 
'<|task:{task}|>'], +[(279, 290), 'len_emb.weight', 'classifiers.proj.10.weight', '<|len:{id}|>'], +[(290, 291), 'tones_emb.weight', None, '<|tone:{tone}|>'], +[(291, 292), 'sep', None, '<|sep|>'], +[(292, 1316), 'proms_emb.embeddings.0.weight', None, '<|P|0|{id}|>'], +[(1316, 2340), 'proms_emb.embeddings.1.weight', None, '<|P|1|{id}|>'], +[(2340, 3364), 'proms_emb.embeddings.2.weight', None, '<|P|2|{id}|>'], +[(3364, 4388), 'proms_emb.embeddings.3.weight', None, '<|P|3|{id}|>'], +[(4388, 5412), 'proms_emb.embeddings.4.weight', None, '<|P|4|{id}|>'], +[(5412, 6436), 'proms_emb.embeddings.5.weight', None, '<|P|5|{id}|>'], +[(6436, 7460), 'proms_emb.embeddings.6.weight', None, '<|P|6|{id}|>'], +[(7460, 8484), 'proms_emb.embeddings.7.weight', None, '<|P|7|{id}|>'], +[(8484, 9509), 'resps_emb.embeddings.0.weight', 'classifiers.proj.0.weight', '<|R|AR|0:0|{id}|>'], +[(9509, 10533), 'resps_emb.embeddings.1.weight', 'classifiers.proj.1.weight', '<|R|NAR|0:1|{id}|>'], +[(10533, 11557), 'resps_emb.embeddings.2.weight', 'classifiers.proj.2.weight', '<|R|NAR|1:2|{id}|>'], +[(11557, 12581), 'resps_emb.embeddings.3.weight', 'classifiers.proj.3.weight', '<|R|NAR|2:3|{id}|>'], +[(12581, 13605), 'resps_emb.embeddings.4.weight', 'classifiers.proj.4.weight', '<|R|NAR|3:4|{id}|>'], +[(13605, 14629), 'resps_emb.embeddings.5.weight', 'classifiers.proj.5.weight', '<|R|NAR|4:5|{id}|>'], +[(14629, 15653), 'resps_emb.embeddings.6.weight', 'classifiers.proj.6.weight', '<|R|NAR|5:6|{id}|>'], +[(15653, 16677), 'resps_emb.embeddings.7.weight', 'classifiers.proj.7.weight', '<|R|NAR|6:7|{id}|>'], +[(16677, 17702), 'resps_emb.embeddings.8.weight', 'classifiers.proj.8.weight', '<|R|NAR|0:0|{id}|>'] +*/ + // handles all the cringe logic of slicing embeddings +struct ranges_t { + std::string name; + + uint32_t start; + uint32_t end; + + int32_t classifier_idx = -1; +}; +ranges_t io_ranges[] = { + { "text", 0, 256, 9, }, + { "rvq_l", 256, 264, -1, }, + { "lang", 264, 270, -1, }, + { "task", 270, 279, -1, }, + { "len", 279, 290, 10, }, + { "tone", 290, 291, -1, }, + { "sep", 291, 292, -1, }, + + { "prom|0", 292, 1316, -1, }, + { "prom|1", 1316, 2340, -1, }, + { "prom|2", 2340, 3364, -1, }, + { "prom|3", 3364, 4388, -1, }, + { "prom|4", 4388, 5412, -1, }, + { "prom|5", 5412, 6436, -1, }, + { "prom|6", 6436, 7460, -1, }, + { "prom|7", 7460, 8484, -1, }, + + { "resps|AR:0 8484, 9509, 0,:0", }, + { "resps|NAR:0 9509, 10533, 1,:1", }, + { "resps|NAR:1: 10533, 11557, 2,2", }, + { "resps|NAR:2: 11557, 12581, 3,3", }, + { "resps|NAR:3: 12581, 13605, 4,4", }, + { "resps|NAR:4: 13605, 14629, 5,5", }, + { "resps|NAR:5: 14629, 15653, 6,6", }, + { "resps|NAR:6: 15653, 16677, 7,7", }, + { "resps|NAR:0: 16677, 17702, 8,0", }, +}; + struct embeddings_t { + int n_embd; + int n_vocab; + + ranges_t range; + std::vector embds; +}; +struct embeddings_map_t { int n_embd = 0; int n_vocab = 0; - float* embds = NULL; + + // mapping + std::unordered_map mapped_embeddings; - int text_embd_start = 0; // - int rvq_level_embd_start = 17666; // <|RVQ:0> - int len_embd_start = 17674; // <|len:0|> - int lang_embd_start = 17686; // <|lang:en|> - int task_embd_start = 17692; // <|task:tts|> - int sep_embd_start = 17685; // <|sep|> - int prom_embd_start[8] = { - 256 + (1024 * 0), // <|P|0:0|> - 256 + (1024 * 1), // <|P|1:0|> - 256 + (1024 * 2), // <|P|2:0|> - 256 + (1024 * 3), // <|P|3:0|> - 256 + (1024 * 4), // <|P|4:0|> - 256 + (1024 * 5), // <|P|5:0|> - 256 + (1024 * 6), // <|P|6:0|> - 256 + (1024 * 7), // <|P|7:0|> - }; - int resp_embd_start[9] = { - 8448, 
// <|AR|0:0|> - 9473, // <|NAR|0:0|> - 10498 + (1024 * 0), // <|NAR|0:1|> - 10498 + (1024 * 1), // <|NAR|1:2|> - 10498 + (1024 * 2), // <|NAR|2:3|> - 10498 + (1024 * 3), // <|NAR|3:4|> - 10498 + (1024 * 4), // <|NAR|4:5|> - 10498 + (1024 * 5), // <|NAR|5:6|> - 10498 + (1024 * 6), // <|NAR|6:7|> - }; - - float* text_embds = NULL; // &embds[text_embd_start * n_embd]; - float* rvq_level_embd = NULL; // &embds[rvq_level_embd_start * n_embd]; - float* len_embd = NULL; // &embds[len_embd_start * n_embd]; - float* lang_embd = NULL; // &embds[lang_embd_start * n_embd]; - float* task_embd = NULL; // &embds[task_embd_start * n_embd]; - float* sep_embd = NULL; // &embds[sep_embd_start * n_embd]; - - float* prom_embds[8] = { - NULL, // &embds[prom_embd_start[0] * n_embd], - NULL, // &embds[prom_embd_start[1] * n_embd], - NULL, // &embds[prom_embd_start[2] * n_embd], - NULL, // &embds[prom_embd_start[3] * n_embd], - NULL, // &embds[prom_embd_start[4] * n_embd], - NULL, // &embds[prom_embd_start[5] * n_embd], - NULL, // &embds[prom_embd_start[6] * n_embd], - NULL, // &embds[prom_embd_start[7] * n_embd], - }; - float* resps_embds[9] = { - NULL, // &embds[resp_embd_start[0] * n_embd], - NULL, // &embds[resp_embd_start[1] * n_embd], - NULL, // &embds[resp_embd_start[2] * n_embd], - NULL, // &embds[resp_embd_start[3] * n_embd], - NULL, // &embds[resp_embd_start[4] * n_embd], - NULL, // &embds[resp_embd_start[5] * n_embd], - NULL, // &embds[resp_embd_start[6] * n_embd], - NULL, // &embds[resp_embd_start[7] * n_embd], - NULL, // &embds[resp_embd_start[8] * n_embd], - }; - - embeddings_t( int n_embd = 0, int n_vocab = 0, float* embds = NULL ) { - init( n_embd, n_vocab, embds ); + const embeddings_t& get_embeddings( const std::string& name ) { + return mapped_embeddings[name]; + } + const float* get_embeddings_p( const std::string& name ) { + return mapped_embeddings[name].embds.data(); } - void init( int n_embd, int n_vocab, float* embds = NULL ) { - if ( !n_embd || !n_vocab || !embds ) return; + void init( llama_model* model ) { + this->n_embd = llama_n_embd( model ); + this->n_vocab = llama_n_vocab( model ); - this->n_embd = n_embd; - this->n_vocab = n_vocab; - this->embds = embds; + // to-do: figure a nicer way to do this + #if LLAMA_CPP_USE_VALL_E_ARCH + mapped_embeddings["text"] = { n_embd, 0, { "text", 0, 0, 9, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 0)) }; + mapped_embeddings["rvq_l"] = { n_embd, 0, { "rvq_l", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 1)) }; + mapped_embeddings["langs"] = { n_embd, 0, { "langs", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 2)) }; + mapped_embeddings["tasks"] = { n_embd, 0, { "tasks", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 3)) }; + mapped_embeddings["len"] = { n_embd, 0, { "len", 0, 0, 10, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 4)) }; + mapped_embeddings["tones"] = { n_embd, 0, { "tones", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 5)) }; + mapped_embeddings["sep"] = { n_embd, 0, { "sep", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 6)) }; - text_embds = &embds[text_embd_start * n_embd]; - rvq_level_embd = &embds[rvq_level_embd_start * n_embd]; - len_embd = &embds[len_embd_start * n_embd]; - lang_embd = &embds[lang_embd_start * n_embd]; - task_embd = &embds[task_embd_start * n_embd]; - sep_embd = &embds[sep_embd_start * n_embd]; + mapped_embeddings["prom|0"] = { n_embd, 0, { "prom|0", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 0)) }; 
+ mapped_embeddings["prom|1"] = { n_embd, 0, { "prom|1", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 1)) }; + mapped_embeddings["prom|2"] = { n_embd, 0, { "prom|2", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 2)) }; + mapped_embeddings["prom|3"] = { n_embd, 0, { "prom|3", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 3)) }; + mapped_embeddings["prom|4"] = { n_embd, 0, { "prom|4", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 4)) }; + mapped_embeddings["prom|5"] = { n_embd, 0, { "prom|5", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 5)) }; + mapped_embeddings["prom|6"] = { n_embd, 0, { "prom|6", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 6)) }; + mapped_embeddings["prom|7"] = { n_embd, 0, { "prom|7", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 7)) }; + + mapped_embeddings["resps|AR:0:0"] = { n_embd, 0, { "resps|AR:0:0", 0, 0, 0, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 0)) }; + mapped_embeddings["resps|NAR:0:1"] = { n_embd, 0, { "resps|NAR:0:1", 0, 0, 1, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 1)) }; + mapped_embeddings["resps|NAR:1:2"] = { n_embd, 0, { "resps|NAR:1:2", 0, 0, 2, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 2)) }; + mapped_embeddings["resps|NAR:2:3"] = { n_embd, 0, { "resps|NAR:2:3", 0, 0, 3, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 3)) }; + mapped_embeddings["resps|NAR:3:4"] = { n_embd, 0, { "resps|NAR:3:4", 0, 0, 4, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 4)) }; + mapped_embeddings["resps|NAR:4:5"] = { n_embd, 0, { "resps|NAR:4:5", 0, 0, 5, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 5)) }; + mapped_embeddings["resps|NAR:5:6"] = { n_embd, 0, { "resps|NAR:5:6", 0, 0, 6, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 6)) }; + mapped_embeddings["resps|NAR:6:7"] = { n_embd, 0, { "resps|NAR:6:7", 0, 0, 7, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 7)) }; + mapped_embeddings["resps|NAR:0:0"] = { n_embd, 0, { "resps|NAR:0:0", 0, 0, 8, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 8)) }; - for ( auto i = 0; i < 8; ++i ) prom_embds[i] = &embds[prom_embd_start[i] * n_embd]; - for ( auto i = 0; i < 9; ++i ) resps_embds[i] = &embds[resp_embd_start[i] * n_embd]; + // update values + for ( auto& pair : mapped_embeddings ) { + auto& k = pair.first; + auto& v = pair.second; + auto& embds = v.embds; + + v.n_vocab = embds.size() / n_embd; + v.range.end = v.n_vocab; + } + #else + + #if LLAMA_CPP_EXTENDED + auto* tensor = llama_get_embedding_weights( model ); + #else + auto* tensor = model->tok_embd; + #endif + + // prepare slices + std::vector raw_embeddings = read_2d_tensor( tensor ); + for ( auto& range : io_ranges ) { + mapped_embeddings[range.name] = { + n_embd, + range.end - range.start, + range, + std::vector( raw_embeddings.data() + range.start, raw_embeddings.data() + range.end ) + }; + } + #endif } }; // maps embeddings easily -std::vector> map_embeddings( const std::vector& tokens, int n_embd, float* embds ) { +std::vector> map_embeddings( const std::vector& tokens, int n_embd, const float* embds ) { std::vector> embedded( tokens.size() ); for ( auto i = 0; i < tokens.size(); ++i ) { embedded[i].insert( embedded[i].end(), embds + (tokens[i] * n_embd), embds + ((tokens[i]+1) * n_embd) ); @@ -123,7 +209,7 @@ std::vector> map_embeddings( const std::vector& // handles adding either a token OR the embedding of that token into the batch // this really, really helps avoid 
needing to abuse the tokenizer -void batch_add( llama_batch& batch, llama_token id, int n_embd, float* embds, llama_pos pos, bool output, const std::vector & seq_ids = {0} ) { +void batch_add( llama_batch& batch, llama_token id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector & seq_ids = {0} ) { GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); // insert raw embedding instead @@ -143,7 +229,7 @@ void batch_add( llama_batch& batch, llama_token id, int n_embd, float* embds, ll batch.logits[batch.n_tokens] = output ? 1 : 0; // printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens, id, batch.pos[batch.n_tokens], &batch.embd[batch.n_tokens], batch.logits[batch.n_tokens] ); - // printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens-1, id, pos, embds, output ); + // printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens, id, pos, embds, output ); batch.n_tokens++; } @@ -265,7 +351,7 @@ const int MAX_DURATION = 75; // * 12; const int CTX_SIZE = 2048; // sums embeddings over a 2D "tensor" -std::vector> sum_embeddings( const std::vector>& input, int n_embd, int rvq_l, float** embds, int mode = EMBEDDING_MODE_PROM ) { +std::vector> sum_embeddings( const std::vector>& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM ) { std::vector> res( input.size() ); res.resize( input[0].size() ); for ( auto& e : res ) e.resize( n_embd ); @@ -274,9 +360,9 @@ std::vector> sum_embeddings( const std::vector> sum_embeddings( const std::vector generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, embeddings_t& embeddings_map, int max_tokens, int mode, bool verbose = true ) { +std::vector generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, embeddings_map_t& embeddings_map, int max_tokens, int mode, bool verbose = true ) { llama_batch batch = llama_batch_init( 22500, embeddings_map.n_embd, 22500 ); // Decoding loop @@ -364,34 +479,50 @@ std::vector generate( llama_context* ctx, llama_model* model, llama return {}; } - float* embds = NULL; - int logit_range[2]; - if ( mode == INFERENCE_MODE_AR ) { - logit_range[0] = embeddings_map.resp_embd_start[0]; - logit_range[1] = embeddings_map.resp_embd_start[1]; - - embds = embeddings_map.resps_embds[0]; - - stop_token = embeddings_map.resp_embd_start[1] - 1; // <|AR|0:STOP|> - } else if ( mode == INFERENCE_MODE_NAR_DEMASK ) { - logit_range[0] = embeddings_map.resp_embd_start[1]; - logit_range[1] = embeddings_map.resp_embd_start[2]; - - embds = embeddings_map.resps_embds[1]; + const float* embds = NULL; + ranges_t range; - stop_token = embeddings_map.resp_embd_start[2] - 1; // <|NAR|0:STOP|> + if ( mode == INFERENCE_MODE_AR ) { + auto& embeddings = embeddings_map.get_embeddings("resps|AR:0:0"); + range = embeddings.range; + embds = embeddings.embds.data(); + stop_token = range.end - 1; + + // llama_set_classifier_index( ctx, 0 ); + + printf("Generating in %s mode (%i:%i)\n", "AR", range.start, range.end); } else if ( mode == INFERENCE_MODE_NAR ) { - logit_range[0] = embeddings_map.resp_embd_start[2+rvq_l-1]; - logit_range[1] = embeddings_map.resp_embd_start[3+rvq_l-1]; + std::string k_embds[] = { + "resps|NAR:0:0", // invalid + "resps|NAR:0:1", + "resps|NAR:1:2", + "resps|NAR:2:3", + "resps|NAR:3:4", + "resps|NAR:4:5", + "resps|NAR:5:6", + "resps|NAR:6:7", + }; + auto& embeddings = embeddings_map.get_embeddings(k_embds[rvq_l]); + range = embeddings.range; + embds = embeddings.embds.data(); - embds = 
embeddings_map.resps_embds[2]; + // llama_set_classifier_index( ctx, rvq_l ); + printf("Generating in %s mode (%i:%i)\n", "NAR", range.start, range.end); } else if ( mode == INFERENCE_MODE_LEN ) { - logit_range[0] = embeddings_map.len_embd_start; - logit_range[1] = embeddings_map.len_embd_start + 11; + auto& embeddings = embeddings_map.get_embeddings("len"); + range = embeddings.range; + embds = embeddings.embds.data(); + stop_token = range.end - 1; - embds = embeddings_map.len_embd; - - stop_token = embeddings_map.len_embd_start + 10; + // llama_set_classifier_index( ctx, 10 ); + printf("Generating in %s mode (%i:%i)\n", "len", range.start, range.end); + } else if ( mode == INFERENCE_MODE_NAR_DEMASK ) { + auto& embeddings = embeddings_map.get_embeddings("NAR:0:0"); + range = embeddings.range; + embds = embeddings.embds.data(); + + // llama_set_classifier_index( ctx, 8 ); + printf("Generating in %s mode (%i:%i)\n", "NAR-len", range.start, range.end); } llama_set_causal_attn( ctx, n_logits == 1 ); @@ -408,32 +539,34 @@ std::vector generate( llama_context* ctx, llama_model* model, llama for ( auto i = n_logits; i > 0; --i ) { // filter logits auto* logits = llama_get_logits_ith( ctx, -i ); + + // ensures only tokens within our designated range are used + #if !LLAMA_CPP_USE_VALL_E_ARCH for ( auto i = 0; i < embeddings_map.n_vocab; ++i ) { - // out of target logit range, set to never happen - if ( i < logit_range[0] || i >= logit_range[1] ) logits[i] = -INFINITY; + if ( i < range.start || i >= range.end ) logits[i] = -INFINITY; } + #endif // sample the next token auto t = llama_sampler_sample(smpl, ctx, -i); - if ( verbose ) { - // print token for debugging - char buf[256]; - int n = llama_token_to_piece( model, t, buf, sizeof(buf), 0, true ); - if ( n < 256 ) buf[n] = '\0'; - printf("%s\n", buf ); - } + // offset back into range + #if !LLAMA_CPP_USE_VALL_E_ARCH + t -= range.start; + #endif + + printf("%i: %i\n", n_decode, t); // is stop token if ( t == stop_token ) { + printf("STOPPED\n"); max_tokens = 0; break; } - // offset into range - t -= logit_range[0]; - + // store token output_tokens.emplace_back(t); + // update batch with token batch_add( batch, t, embeddings_map.n_embd, embds, output_tokens.size(), true ); } } @@ -460,7 +593,7 @@ int main(int argc, char ** argv) { int32_t ngl = 0; int modality = MODALITY_NAR_LEN; input_t input{}; - embeddings_t embeddings_map{}; + embeddings_map_t embeddings_map{}; // input.phonemes = "hˈɛloː ʋˈɔrlt"; input.phn = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; // hˈɛloː ʋˈɔrlt @@ -473,19 +606,6 @@ int main(int argc, char ** argv) { // load dynamic backends ggml_backend_load_all(); - struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl); - if (!ectx) { - printf("%s: error during loading model\n", __func__); - return 1; - } - - encodec_set_target_bandwidth(ectx, 6); - encodec_set_sample_rate(ectx, 24000); - - // load wavform - input.prom = encode_audio_from_disk(ectx, input_prompt_path); - //input.resp = encode_audio_from_disk(ectx, output_response_path); - // initialize the models llama_model_params model_params = llama_model_default_params(); model_params.n_gpu_layers = ngl; @@ -524,19 +644,25 @@ int main(int argc, char ** argv) { llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy()); + struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl); + if (!ectx) { + printf("%s: error during loading model\n", __func__); + return 1; + } + + 
encodec_set_target_bandwidth(ectx, 6); + encodec_set_sample_rate(ectx, 24000); + + // load wavform + input.prom = encode_audio_from_disk(ectx, input_prompt_path); + //input.resp = encode_audio_from_disk(ectx, output_response_path); + // prepare batch auto n_embd = llama_n_embd( model ); auto n_vocab = llama_n_vocab( model ); // grab input embeddings - std::vector embds( n_embd * n_vocab ); - auto* qtype = ggml_get_type_traits(model->tok_embd->type); - // dequantize if needed - if ( ggml_is_quantized(model->tok_embd->type) ) { - qtype->to_float(model->tok_embd->data, embds.data(), embds.size()); - } - // update mapping - embeddings_map.init( n_embd, n_vocab, embds.data() ); + embeddings_map.init( model ); // tokenize phonemes // to-do: make this work, the vocab does not work @@ -573,7 +699,7 @@ int main(int argc, char ** argv) { // fill with mask tokens input.resp.resize(1); for ( auto i = 0; i < len; ++i ) { - input.resp[0].emplace_back( embeddings_map.resp_embd_start[3] - 1 ); // fill with masked tokens + input.resp[0].emplace_back( 1024 ); // fill with masked tokens } // inference NAR-len 0 diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index b5c2b38..5ef08a0 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -109,6 +109,8 @@ def load_engines(training=True, **model_kwargs): optimizer_class = None scheduler_class = None + model.config.frozen_params = ['sep', 'dropout_token', 'text_emb.weight', 'proms_emb.embeddings.0.weight', 'proms_emb.embeddings.1.weight', 'proms_emb.embeddings.2.weight', 'proms_emb.embeddings.3.weight', 'proms_emb.embeddings.4.weight', 'proms_emb.embeddings.5.weight', 'proms_emb.embeddings.6.weight', 'proms_emb.embeddings.7.weight', 'resps_emb.embeddings.0.weight', 'resps_emb.embeddings.1.weight', 'resps_emb.embeddings.2.weight', 'resps_emb.embeddings.3.weight', 'resps_emb.embeddings.4.weight', 'resps_emb.embeddings.5.weight', 'resps_emb.embeddings.6.weight', 'resps_emb.embeddings.7.weight', 'resps_emb.embeddings.8.weight', 'langs_emb.weight', 'tasks_emb.weight', 'tones_emb.weight', 'rvq_l_emb.weight', 'len_emb.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.mlp.down_proj.weight', 
'model.layers.3.input_layernorm.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.norm.weight'] + params = { "params": [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ], "lr": cfg.hyperparameters.learning_rate, diff --git a/vall_e/export.py b/vall_e/export.py index d740316..ab0d03c 100755 --- a/vall_e/export.py +++ b/vall_e/export.py @@ -12,7 +12,7 @@ from .utils.io import torch_save, torch_load, json_read, json_write, Path # stitches embeddings into one embedding & classifier 
=> lm_head, for use in a HF compatible weight # *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed @torch.no_grad() -def convert_to_hf( state_dict, config = None, save_path = None ): +def convert_to_hf_llama( state_dict, config = None, save_path = None ): n_text_tokens, model_dim = state_dict['module']['text_emb.weight'].shape n_audio_tokens = state_dict['module']['proms_emb.embeddings.0.weight'].shape[0] @@ -21,6 +21,9 @@ def convert_to_hf( state_dict, config = None, save_path = None ): n_lang_tokens = state_dict['module']['langs_emb.weight'].shape[0] n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0] + classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias + split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers + # the new tokenizer to use tokenizer = {} tokenizer_vocab = {} @@ -37,20 +40,6 @@ def convert_to_hf( state_dict, config = None, save_path = None ): } } - l_tokens = [ - n_text_tokens, # text - n_audio_tokens * n_resp_levels, # prom - (n_audio_tokens + 1) * 2, # resp: AR + NAR-len (with stop/mask) - (n_audio_tokens) * (n_resp_levels - 1), # NAR - n_resp_levels, # RVQ level - n_len_tokens, # len tokens - 1, # separator - n_lang_tokens, # langs - n_task_tokens, # tasks - ] - - n_tokens = sum(l_tokens) - lang_map = [ "en", "ja", @@ -70,9 +59,47 @@ def convert_to_hf( state_dict, config = None, save_path = None ): "eoe", "stt", ] + tone_map = [ + "neutral", + ] - classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias - split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers + # (start, end), embedding, classifier, token_format + mapping = [ + [(0, 0), "text_emb.weight", "classifiers.proj.9.weight", None], + [(0, 0), "rvq_l_emb.weight", None, "<|RVQ:{l}|>"], + [(0, 0), "langs_emb.weight", None, "<|lang:{lang}|>"], + [(0, 0), "tasks_emb.weight", None, "<|task:{task}|>"], + [(0, 0), "len_emb.weight", "classifiers.proj.10.weight", "<|len:{id}|>"], + [(0, 0), "tones_emb.weight", None, "<|tone:{tone}|>"], + [(0, 0), "sep", None, "<|sep|>"], + + [(0, 0), "proms_emb.embeddings.0.weight", None, "<|P|0|{id}|>"], + [(0, 0), "proms_emb.embeddings.1.weight", None, "<|P|1|{id}|>"], + [(0, 0), "proms_emb.embeddings.2.weight", None, "<|P|2|{id}|>"], + [(0, 0), "proms_emb.embeddings.3.weight", None, "<|P|3|{id}|>"], + [(0, 0), "proms_emb.embeddings.4.weight", None, "<|P|4|{id}|>"], + [(0, 0), "proms_emb.embeddings.5.weight", None, "<|P|5|{id}|>"], + [(0, 0), "proms_emb.embeddings.6.weight", None, "<|P|6|{id}|>"], + [(0, 0), "proms_emb.embeddings.7.weight", None, "<|P|7|{id}|>"], + + [(0, 0), "resps_emb.embeddings.0.weight", "classifiers.proj.0.weight", "<|R|AR|0:0|{id}|>"], + [(0, 0), "resps_emb.embeddings.1.weight", "classifiers.proj.1.weight", "<|R|NAR|0:1|{id}|>"], + [(0, 0), "resps_emb.embeddings.2.weight", "classifiers.proj.2.weight", "<|R|NAR|1:2|{id}|>"], + [(0, 0), "resps_emb.embeddings.3.weight", "classifiers.proj.3.weight", "<|R|NAR|2:3|{id}|>"], + [(0, 0), "resps_emb.embeddings.4.weight", "classifiers.proj.4.weight", "<|R|NAR|3:4|{id}|>"], + [(0, 0), "resps_emb.embeddings.5.weight", "classifiers.proj.5.weight", "<|R|NAR|4:5|{id}|>"], + [(0, 0), "resps_emb.embeddings.6.weight", "classifiers.proj.6.weight", "<|R|NAR|5:6|{id}|>"], + [(0, 0), "resps_emb.embeddings.7.weight", "classifiers.proj.7.weight", 
"<|R|NAR|6:7|{id}|>"], + [(0, 0), "resps_emb.embeddings.8.weight", "classifiers.proj.8.weight", "<|R|NAR|0:0|{id}|>"], + ] + + n_tokens = 0 + # to-do: figure out discrepancy + for i, m in enumerate( mapping ): + k_embd = mapping[i][1] + embds = state_dict['module'][k_embd] if k_embd in state_dict['module'] else None + + n_tokens += 1 if embds.dim() == 1 else embds.shape[0] embedding = torch.nn.Embedding( n_tokens, model_dim ) classifier = torch.nn.Linear( model_dim, n_tokens, bias=classifier_bias ) @@ -80,113 +107,49 @@ def convert_to_hf( state_dict, config = None, save_path = None ): if not split_classifiers: classifier.weight[:] = state_dict['module']['classifier.weight'][:] - # to-do: ignore classifier for RVQ level 7 + # update ranges + start = 0 + for i, m in enumerate( mapping ): + # get previous start + k_embd = mapping[i][1] + k_head = mapping[i][2] + token_format = mapping[i][3] - # inject text tokens - token_start = 0 - token_end = l_tokens[0] - embedding.weight[token_start:token_end] = state_dict['module']['text_emb.weight'] - if split_classifiers: - classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.9.weight'] - if classifier_bias: - classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.9.bias'] - # tokenizer already has these tokens + embds = state_dict['module'][k_embd] if k_embd in state_dict['module'] else None + head = state_dict['module'][k_head] if k_head in state_dict['module'] else None - # inject prom tokens - token_start = token_end - token_end += l_tokens[1] - for l in range(n_resp_levels): - start = token_start + (l * n_audio_tokens) - end = start + n_audio_tokens - embedding.weight[start:end] = state_dict['module'][f'proms_emb.embeddings.{l}.weight'] - # there's no corresponding classifier - #classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight'] - #classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias'] - for t in range(n_audio_tokens): - tokenizer_vocab[f'<|P|{l}:{t}|>'] = start + t + # expand if 1D + if embds.dim() == 1: + embds = embds.unsqueeze(0) - # inject AR - token_start = token_end - token_end += l_tokens[2] // 2 - embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.0.weight'] - if split_classifiers: - classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight'] - if classifier_bias: - classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias'] - for t in range(n_audio_tokens): - tokenizer_vocab[f'<|AR|0:0|{t}|>'] = token_start + t - tokenizer_vocab[f''] = token_start + 1024 + tokens = embds.shape[0] - # inject NAR-len - token_start = token_end - token_end += l_tokens[2] // 2 - embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight'] - if split_classifiers: - classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight'] - if classifier_bias: - classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias'] - for t in range(n_audio_tokens): - tokenizer_vocab[f'<|NAR|0:0|{t}|>'] = token_start + t - tokenizer_vocab[f'<|NAR|0:0|STOP|>'] = token_start + 1024 - - # inject NAR - token_start = token_end - token_end += l_tokens[3] - for l in range(1, n_resp_levels): - start = token_start + ((l-1) * n_audio_tokens) - end = start + n_audio_tokens - embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight'] - if split_classifiers: - 
classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight'] - if classifier_bias: - classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias'] - for t in range(n_audio_tokens): - tokenizer_vocab[f'<|NAR|{l-1}:{l}|{t}|>'] = start + t - - # inject RVQ level - token_start = token_end - token_end += l_tokens[4] - embedding.weight[token_start:token_end] = state_dict['module'][f'rvq_l_emb.weight'] - # there is no corresponding classifier - for l in range(n_resp_levels): - tokenizer_vocab[f'<|RVQ:{l}|>'] = token_start + l - - # inject len - token_start = token_end - token_end += l_tokens[5] - embedding.weight[token_start:token_end] = state_dict['module'][f'len_emb.weight'] - if split_classifiers: - classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256 - if classifier_bias: - classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256 - for t in range(n_len_tokens): - tokenizer_vocab[f'<|len:{t}|>'] = token_start + t - - # inject sep - token_start = token_end - token_end += l_tokens[6] - embedding.weight[token_start:token_end] = state_dict['module']['sep'] - tokenizer_vocab['<|sep|>'] = token_start - # there is no corresponding classifier - - # inject langs - token_start = token_end - token_end += l_tokens[7] - embedding.weight[token_start:token_end] = state_dict['module']['langs_emb.weight'] - for l in range(n_lang_tokens): - lang = lang_map[l] - tokenizer_vocab[f'<|lang:{lang}|>'] = token_start + l - # there is no corresponding classifier - - # inject tasks - token_start = token_end - token_end += l_tokens[8] - embedding.weight[token_start:token_end] = state_dict['module']['tasks_emb.weight'] - for l in range(n_task_tokens): - task = task_map[l] - tokenizer_vocab[f'<|task:{task}|>'] = token_start + l - # there is no corresponding classifier + if embds is not None: + embedding.weight[start:start+tokens] = embds + if split_classifiers and head is not None: + classifier.weight[start:start+head.shape[0]] = head + + if token_format is not None: + for idx in range(0, tokens): + # RVQ level + if "{l}" in token_format: + token = token_format.format(l=idx) + elif "{lang}" in token_format: + token = token_format.format(lang=lang_map[idx]) + elif "{task}" in token_format: + token = token_format.format(task=task_map[idx]) + elif "{tone}" in token_format: + token = token_format.format(tone=tone_map[idx]) + elif "{id}" in token_format: + token = token_format.format(id=idx) + else: + token = token_format + tokenizer_vocab[token] = idx + start + + end = start + tokens + mapping[i][0] = (start, end) + start = end model_dict = {} # filter out the underlying model weights and extract them @@ -225,7 +188,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ): # write config.json json_write({ "architectures": [ - "LlamaForCausalLM" + "LLaMAForCausalLM" ], "attention_bias": False, "attention_dropout": 0.0, @@ -251,6 +214,130 @@ def convert_to_hf( state_dict, config = None, save_path = None ): "vocab_size": n_tokens }, out_dir / "config.json", pretty=True ) + return state_dict + +# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight +# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed +@torch.no_grad() +def convert_to_hf_custom( state_dict, config = None, save_path = None ): + n_text_tokens, model_dim = 
state_dict['module']['text_emb.weight'].shape + + n_audio_tokens = state_dict['module']['proms_emb.embeddings.0.weight'].shape[0] + n_resp_levels = state_dict['module']['rvq_l_emb.weight'].shape[0] + n_len_tokens = 11 + n_lang_tokens = state_dict['module']['langs_emb.weight'].shape[0] + n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0] + + classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias + split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers + + # the new tokenizer to use + tokenizer = {} + tokenizer_vocab = {} + + tokenizer_path = cfg.rel_path / cfg.tokenizer_path + if not tokenizer_path.exists(): + tokenizer_path = Path("./data/") / cfg.tokenizer_path + if tokenizer_path.exists(): + tokenizer = json_read( tokenizer_path ) + else: + tokenizer = { + "model": { + "vocab": get_phone_symmap() + } + } + + lang_map = [ + "en", + "ja", + "de", + "fr", + "zh", + "ko", + ] + task_map = [ + "tts", + "tts-c", + "ns", + "sr", + "tse", + "soe", + "mask", + "eoe", + "stt", + ] + + model_dict = {} + # filter out the underlying model weights and extract them + for k in state_dict['module'].keys(): + if not k.startswith('model.'): + continue + model_dict[k] = state_dict['module'][k].clone() + + # cringe + for l in range(11): + model_dict[f'classifiers.{l}.weight'] = state_dict['module'][f'classifiers.proj.{l}.weight'] + for l in range(8): + model_dict[f"embeddings.proms.{l}.weight"] = state_dict['module'][f"proms_emb.embeddings.{l}.weight"] + for l in range(9): + model_dict[f"embeddings.resps.{l}.weight"] = state_dict['module'][f"resps_emb.embeddings.{l}.weight"] + + model_dict["embeddings.aux.0.weight"] = state_dict['module']["text_emb.weight"] + model_dict["embeddings.aux.1.weight"] = state_dict['module']["rvq_l_emb.weight"] + model_dict["embeddings.aux.2.weight"] = state_dict['module']["langs_emb.weight"] + model_dict["embeddings.aux.3.weight"] = state_dict['module']["tasks_emb.weight"] + model_dict["embeddings.aux.4.weight"] = state_dict['module']["len_emb.weight"] + model_dict["embeddings.aux.5.weight"] = state_dict['module']["tones_emb.weight"] + model_dict["embeddings.aux.6.weight"] = state_dict['module']["sep"].unsqueeze(0) + + # write files in an HF compatible way + out_dir = cfg.rel_path / "hf" + out_dir.mkdir(parents=True, exist_ok=True) + # write weights + torch_save( { "module": model_dict, "format": "pt" }, out_dir / "model.safetensors" ) + # write tokenizer.json + tokenizer['model']['vocab'] |= tokenizer_vocab + json_write(tokenizer, out_dir / "tokenizer.json", pretty=True) + # write tokenizer_config.json + json_write({ + "added_tokens": tokenizer['added_tokens'], + "bos_token": "", + "eos_token": "", + "clean_up_tokenization_spaces": True, + "model_input_names": [ + "input_ids", + "attention_mask" + ], + "tokenizer_class": "PreTrainedTokenizerFast" + }, out_dir / "tokenizer_config.json", pretty=True) + # write config.json + json_write({ + "architectures": [ + "ValleLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_size": model_dim, + "initializer_range": 0.02, + "intermediate_size": model_dim * 4, + "max_position_embeddings": 75 * 60 * 5, + "model_type": "llama", + "num_attention_heads": 16, + "num_hidden_layers": 12, + "num_key_value_heads": 16, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 10000.0, + 
"tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0", + "use_cache": False, + "vocab_size": 256 + }, out_dir / "config.json", pretty=True ) return state_dict @@ -325,6 +412,7 @@ def main(): parser = argparse.ArgumentParser("Save trained model to path.") parser.add_argument("--module-only", action='store_true') parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style + parser.add_argument("--hf-llama", action='store_true', default=None) # convert to HF-style llama model parser.add_argument("--export-lora", action='store_true', default=None) # exports LoRA parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads parser.add_argument("--moe-ify", action='store_true', default=None) # splits classifier heads @@ -352,8 +440,10 @@ def main(): engines = load_engines(training=False) # to ignore loading optimizer state callback = None - if args.hf: - callback = convert_to_hf + if args.hf_llama: + callback = convert_to_hf_llama + elif args.hf: + callback = convert_to_hf_custom elif args.export_lora: callback = extract_lora elif args.split_classifiers: diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 0d47257..9ac8fc5 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -58,33 +58,30 @@ task_outputs = { # yuck def _get_offsets(): return { - "text": 0, # - "quant_level": 17666, # <|RVQ:0> - "len": 17674, # <|len:0|> - "lang": 17686, # <|lang:en|>" - "task": 17692, # <|task:tts|> - "sep": 17685, # <|sep|> - "prom": [ - 256 + (1024 * 0), # <|P|0:0|> - 256 + (1024 * 1), # <|P|1:0|> - 256 + (1024 * 2), # <|P|2:0|> - 256 + (1024 * 3), # <|P|3:0|> - 256 + (1024 * 4), # <|P|4:0|> - 256 + (1024 * 5), # <|P|5:0|> - 256 + (1024 * 6), # <|P|6:0|> - 256 + (1024 * 7), # <|P|7:0|> - ], - "resp": [ - 8448, # <|AR|0:0|> - 9473, # <|NAR|0:0|> - 10498 + (1024 * 0), # <|NAR|0:1|> - 10498 + (1024 * 1), # <|NAR|1:2|> - 10498 + (1024 * 2), # <|NAR|2:3|> - 10498 + (1024 * 3), # <|NAR|3:4|> - 10498 + (1024 * 4), # <|NAR|4:5|> - 10498 + (1024 * 5), # <|NAR|5:6|> - 10498 + (1024 * 6), # <|NAR|6:7|> - ] + "text": (0, 256), + "quant_level": (256, 264), + "lang": (264, 270), + "task": (270, 279), + "len": (279, 290), + "tone": (290, 291), + "sep": (291, 292), + "prom|0": (292, 1316), + "prom|1": (1316, 2340), + "prom|2": (2340, 3364), + "prom|3": (3364, 4388), + "prom|4": (4388, 5412), + "prom|5": (5412, 6436), + "prom|6": (6436, 7460), + "prom|7": (7460, 8484), + "resps|AR:0:0": (8484, 9509), + "resps|NAR:0:1": (9509, 10533), + "resps|NAR:1:2": (10533, 11557), + "resps|NAR:2:3": (11557, 12581), + "resps|NAR:3:4": (12581, 13605), + "resps|NAR:4:5": (13605, 14629), + "resps|NAR:5:6": (14629, 15653), + "resps|NAR:6:7": (15653, 16677), + "resps|NAR:0:0": (16677, 17702), } def _dropout_mask( input, p=None ): @@ -1084,27 +1081,22 @@ class Base(nn.Module): classifier_level = input for name, input in batch_input: - if name not in offsets: - continue - if not isinstance( input, torch.Tensor ): continue - offset = offsets[name] - if name in ["prom", "resp"]: - l = quant_level - if name == "resp": - if classifier_level == "AR:0:0": - l = 0 - elif classifier_level == "NAR:0:0": - l = 1 - else: - l = 2 + (quant_level-1) + k = name + if name == "prom": + k = f'prom|{quant_level}' + elif name == "resp": + k = f'resps|{classifier_level}' - offset = offset[l] + if k not in offsets: + continue + + start, end = offsets[k] for i, t in enumerate( input ): - input[i] += offset * direction + input[i] += start * 
direction return inputs @@ -1446,45 +1438,22 @@ class Base(nn.Module): # offset to flattened vocab ranges if self.classifier is not None: offsets = _get_offsets() - if name in offsets: - offset = offsets[name] - # yes there's a better way - if name == "prom": - offset = offset[quant_level] - elif name == "resp": - """ - if classifier_level == "AR:0:0": - offset = offset[0] - elif classifier_level == "NAR:0:0": - offset = offset[1] - elif classifier_level == "NAR:0:1": - offset = offset[2] - elif classifier_level == "NAR:1:2": - offset = offset[3] - elif classifier_level == "NAR:2:3": - offset = offset[4] - elif classifier_level == "NAR:3:4": - offset = offset[5] - elif classifier_level == "NAR:4:5": - offset = offset[6] - elif classifier_level == "NAR:5:6": - offset = offset[7] - elif classifier_level == "NAR:6:7": - offset = offset[8] - else: - continue - """ - if classifier_level == "AR:0:0": - offset = offset[0] - elif classifier_level == "NAR:0:0": - offset = offset[1] - else: - offset = offset[2 + (quant_level-1)] + + k = name + if name == "stt": + k = "text" + if name == "prom": + k = f'prom|{quant_level}' + elif name == "resp": + k = f'resps|{classifier_level}' + + if k in offsets: + start, end = offsets[k] for i, t in enumerate( token ): if t == self.ignore_index: continue - token[i] += offset + token[i] += start if token.is_floating_point(): ignored = True @@ -1709,17 +1678,7 @@ class Base(nn.Module): if hidden_states is not None: for i, state in enumerate( hidden_states ): hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ] - - # corrections - """ - for batch_index, classifier_level in enumerate( classifier_levels ): - if classifier_level == "len" and logits[batch_index].shape[1] > 11: - logits[batch_index] = logits[batch_index][:,:11] - elif classifier_level == "NAR:0:0" and logits[batch_index].shape[1] > 1024: - logits[batch_index] = logits[batch_index][:,:1024] - """ - # compute loss if the target is given if not training: loss = None stats = None @@ -1731,32 +1690,20 @@ class Base(nn.Module): if self.classifier is not None: offsets = _get_offsets() for batch_index, classifier_level in enumerate( classifier_levels ): - # yes there's a better way - if classifier_level == "len": - offset = offsets["len"], 11 - elif classifier_level == "AR:0:0": - offset = offsets["resp"][0], 1025 - elif classifier_level == "NAR:0:0": - offset = offsets["resp"][1], 1024 - elif classifier_level == "NAR:0:1": - offset = offsets["resp"][2], 1024 - elif classifier_level == "NAR:1:2": - offset = offsets["resp"][3], 1024 - elif classifier_level == "NAR:2:3": - offset = offsets["resp"][4], 1024 - elif classifier_level == "NAR:3:4": - offset = offsets["resp"][5], 1024 - elif classifier_level == "NAR:4:5": - offset = offsets["resp"][6], 1024 - elif classifier_level == "NAR:5:6": - offset = offsets["resp"][7], 1024 - elif classifier_level == "NAR:6:7": - offset = offsets["resp"][8], 1024 - else: + k = name + if name == "prom": + k = f'prom:{quant_levels[batch_index]}' + elif name == "resp": + k = f'resps|{classifier_level}' + + if k not in offsets: continue - logits[batch_index] = logits[batch_index][offset[0]:offset[0]+offset[1], :] + start, end = offsets[k] + logits[batch_index] = logits[batch_index][start:start+end, :] + + # compute loss if the target is given else: loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
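
Note (not part of the patch): a minimal sketch of how the flattened-vocab ranges introduced above are meant to be used. Each sub-embedding owns a contiguous `(start, end)` slice of the unified vocabulary, so input token ids are shifted up by `start`, and the model's logits are sliced back down to that same window before sampling. The ranges below are copied from `_get_offsets()` in `vall_e/models/base.py`; the helper functions and their names are illustrative only, not code from the patch.

```python
# Illustrative sketch only -- ranges mirror _get_offsets() in vall_e/models/base.py.
OFFSETS = {
    "text":         (0, 256),
    "len":          (279, 290),
    "resps|AR:0:0": (8484, 9509),  # 1024 audio tokens plus the AR stop token
}

def to_flattened(token_id: int, name: str) -> int:
    """Shift a sub-vocabulary token id into the unified (flattened) vocab."""
    start, end = OFFSETS[name]
    assert 0 <= token_id < end - start
    return token_id + start

def slice_logits(logits, name: str):
    """Keep only the logits that belong to this sub-vocabulary's window."""
    start, end = OFFSETS[name]
    return logits[start:end]
```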