sanity cleanup
commit 0d4329d2e3 (parent 353e478e68)
@ -112,7 +112,7 @@
|
|||
"continuing_subword_prefix": null,
|
||||
"end_of_word_suffix": null,
|
||||
"fuse_unk": false,
|
||||
"byte_fallback": true,
|
||||
"byte_fallback": false,
|
||||
"ignore_merges": false,
|
||||
"vocab": {
|
||||
"<unk>": 0,
|
||||
|
|
|
@ -9,13 +9,24 @@ At the moment it's ***very*** barebones as I try and wrestle with `llama.cpp`'s
|
|||
Populate `./include/` with the `llama.cpp` and `encodec.cpp` headers.
|
||||
|
||||
Populate `./libs/` with the compiled libraries of `llama.cpp` and `encodec.cpp`.
|
||||
* `encodec.cpp` requires updating `ggml` to the latest version and doing a quick hack to make it work on the CPU backend.
|
||||
* `llama.cpp` currently requires no hacks, but:
|
||||
* would be *very* nice to retrieve a model's `tok_embd` through the API.
|
||||
* would be ***very*** nice to only specify a slice of the output head through the API.
|
||||
|
||||
Run `make`.
|
||||
|
||||
|
||||
### Required Modifications
|
||||
|
||||
`encodec.cpp` requires updating its GGML copy to the latest version, which needs a few extra lines to get the CPU backend working.
|
||||
`llama.cpp` *might* not require any modifications, but:
|
||||
* `llm.build_vall_e` can mostly copy `llm.build_llama`, but with:
|
||||
* `KQ_mask = build_inp_KQ_mask( lctx.cparams.causal_attn )`
|
||||
* a unified output head (pain)
|
||||
* OR adjusting the `model.output` to the correct classifier head
|
||||
* OR slicing that tensor with the right range (`ggml_view_2d` confuses me)
|
||||
* both also require `*const_cast<uint32_t*>(&ctx->model.hparams.n_vocab) = output->ne[1];` because the logits are tied to `n_vocab`
|
||||
* commenting out `GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());` because grabbing embeddings/classifiers requires using `bid` to trick it into thinking it's part of a layer
|
||||
* some helper functions to retrieve the embeddings tensor from the model
|
||||
* some helper functions to set the target classifier head (a rough sketch of these follows below)
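
A rough sketch of what those helpers could look like, assuming they are compiled as part of `llama.cpp` itself (the `tok_embd`, `output`, and `hparams.n_vocab` members are the ones referenced above; the function names and signatures are guesses, not the actual API):

```cpp
// sketch only: this lives inside llama.cpp, where llama_model internals are visible

// expose the token embedding matrix so the frontend can slice per-codebook embeddings out of it
struct ggml_tensor * llama_get_embedding_weights( struct llama_model * model ) {
    return model->tok_embd;
}

// hypothetical helper: a 2D view of the output head covering only one classifier's
// rows, i.e. entries [start, end) of the flattened vocab dimension
struct ggml_tensor * llama_get_classifier_head( struct ggml_context * ctx, struct llama_model * model, int64_t start, int64_t end ) {
    struct ggml_tensor * output = model->output; // [n_embd, n_vocab]
    return ggml_view_2d( ctx, output,
        output->ne[0],            // keep n_embd as-is
        end - start,              // only this classifier's entries
        output->nb[1],            // original row stride
        start * output->nb[1] );  // byte offset of the first row
}

// the logits buffer is sized from n_vocab, so it has to be shrunk to match the slice:
// *const_cast<uint32_t*>(&ctx->model.hparams.n_vocab) = end - start;
```

In `llm.build_vall_e`, the narrowed view would then stand in for `model.output` when the final logits are computed.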
|
||||
|
||||
## To-Do
|
||||
|
||||
* [x] converted model to GGUF
|
||||
|
|
|
@ -11,9 +11,30 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include <unordered_map>
|
||||
#include <iostream>
|
||||
|
||||
#include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
|
||||
#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions
|
||||
#define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch
|
||||
|
||||
#if !LLAMA_CPP_EXTENDED
|
||||
#include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
|
||||
#endif
|
||||
|
||||
std::vector<float> read_2d_tensor( struct ggml_tensor* tensor ) {
|
||||
size_t size = tensor->ne[0] * tensor->ne[1];
|
||||
std::vector<float> res( size );
|
||||
|
||||
auto* qtype = ggml_get_type_traits(tensor->type);
|
||||
// dequantize if needed
|
||||
if ( ggml_is_quantized(tensor->type) ) {
|
||||
qtype->to_float(tensor->data, res.data(), res.size());
|
||||
} else {
|
||||
memcpy( res.data(), tensor->data, res.size() * sizeof(float) );
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// stores the raw inputs to be fed
|
||||
struct input_t {
|
||||
|
@ -26,94 +47,159 @@ struct input_t {
|
|||
std::vector<std::vector<llama_token>> prom = {};
|
||||
std::vector<std::vector<llama_token>> resp = {};
|
||||
};
|
||||
|
||||
/*
|
||||
[(0, 256), 'text_emb.weight', 'classifiers.proj.9.weight', None],
|
||||
[(256, 264), 'rvq_l_emb.weight', None, '<|RVQ:{l}|>'],
|
||||
[(264, 270), 'langs_emb.weight', None, '<|lang:{lang}|>'],
|
||||
[(270, 279), 'tasks_emb.weight', None, '<|task:{task}|>'],
|
||||
[(279, 290), 'len_emb.weight', 'classifiers.proj.10.weight', '<|len:{id}|>'],
|
||||
[(290, 291), 'tones_emb.weight', None, '<|tone:{tone}|>'],
|
||||
[(291, 292), 'sep', None, '<|sep|>'],
|
||||
[(292, 1316), 'proms_emb.embeddings.0.weight', None, '<|P|0|{id}|>'],
|
||||
[(1316, 2340), 'proms_emb.embeddings.1.weight', None, '<|P|1|{id}|>'],
|
||||
[(2340, 3364), 'proms_emb.embeddings.2.weight', None, '<|P|2|{id}|>'],
|
||||
[(3364, 4388), 'proms_emb.embeddings.3.weight', None, '<|P|3|{id}|>'],
|
||||
[(4388, 5412), 'proms_emb.embeddings.4.weight', None, '<|P|4|{id}|>'],
|
||||
[(5412, 6436), 'proms_emb.embeddings.5.weight', None, '<|P|5|{id}|>'],
|
||||
[(6436, 7460), 'proms_emb.embeddings.6.weight', None, '<|P|6|{id}|>'],
|
||||
[(7460, 8484), 'proms_emb.embeddings.7.weight', None, '<|P|7|{id}|>'],
|
||||
[(8484, 9509), 'resps_emb.embeddings.0.weight', 'classifiers.proj.0.weight', '<|R|AR|0:0|{id}|>'],
|
||||
[(9509, 10533), 'resps_emb.embeddings.1.weight', 'classifiers.proj.1.weight', '<|R|NAR|0:1|{id}|>'],
|
||||
[(10533, 11557), 'resps_emb.embeddings.2.weight', 'classifiers.proj.2.weight', '<|R|NAR|1:2|{id}|>'],
|
||||
[(11557, 12581), 'resps_emb.embeddings.3.weight', 'classifiers.proj.3.weight', '<|R|NAR|2:3|{id}|>'],
|
||||
[(12581, 13605), 'resps_emb.embeddings.4.weight', 'classifiers.proj.4.weight', '<|R|NAR|3:4|{id}|>'],
|
||||
[(13605, 14629), 'resps_emb.embeddings.5.weight', 'classifiers.proj.5.weight', '<|R|NAR|4:5|{id}|>'],
|
||||
[(14629, 15653), 'resps_emb.embeddings.6.weight', 'classifiers.proj.6.weight', '<|R|NAR|5:6|{id}|>'],
|
||||
[(15653, 16677), 'resps_emb.embeddings.7.weight', 'classifiers.proj.7.weight', '<|R|NAR|6:7|{id}|>'],
|
||||
[(16677, 17702), 'resps_emb.embeddings.8.weight', 'classifiers.proj.8.weight', '<|R|NAR|0:0|{id}|>']
|
||||
*/
|
||||
|
||||
// handles all the cringe logic of slicing embeddings
|
||||
struct ranges_t {
|
||||
std::string name;
|
||||
|
||||
uint32_t start;
|
||||
uint32_t end;
|
||||
|
||||
int32_t classifier_idx = -1;
|
||||
};
|
||||
ranges_t io_ranges[] = {
|
||||
{ "text", 0, 256, 9, },
|
||||
{ "rvq_l", 256, 264, -1, },
|
||||
{ "lang", 264, 270, -1, },
|
||||
{ "task", 270, 279, -1, },
|
||||
{ "len", 279, 290, 10, },
|
||||
{ "tone", 290, 291, -1, },
|
||||
{ "sep", 291, 292, -1, },
|
||||
|
||||
{ "prom|0", 292, 1316, -1, },
|
||||
{ "prom|1", 1316, 2340, -1, },
|
||||
{ "prom|2", 2340, 3364, -1, },
|
||||
{ "prom|3", 3364, 4388, -1, },
|
||||
{ "prom|4", 4388, 5412, -1, },
|
||||
{ "prom|5", 5412, 6436, -1, },
|
||||
{ "prom|6", 6436, 7460, -1, },
|
||||
{ "prom|7", 7460, 8484, -1, },
|
||||
|
||||
{ "resps|AR:0 8484, 9509, 0,:0", },
|
||||
{ "resps|NAR:0 9509, 10533, 1,:1", },
|
||||
{ "resps|NAR:1: 10533, 11557, 2,2", },
|
||||
{ "resps|NAR:2: 11557, 12581, 3,3", },
|
||||
{ "resps|NAR:3: 12581, 13605, 4,4", },
|
||||
{ "resps|NAR:4: 13605, 14629, 5,5", },
|
||||
{ "resps|NAR:5: 14629, 15653, 6,6", },
|
||||
{ "resps|NAR:6: 15653, 16677, 7,7", },
|
||||
{ "resps|NAR:0: 16677, 17702, 8,0", },
|
||||
};
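
// Illustration only (not part of the original source): each range above is one
// contiguous slice of the flattened vocab, so converting between a codebook-local
// token id and its flattened id is just an offset by the slice's start, and the
// slice's last slot is that codebook's stop token.
static llama_token token_to_flattened( const ranges_t& range, llama_token local ) {
    return (llama_token)( range.start + local );     // e.g. AR:0:0 local 42 -> 8484 + 42 = 8526
}
static llama_token token_to_local( const ranges_t& range, llama_token flattened ) {
    return (llama_token)( flattened - range.start ); // e.g. the AR stop token 9509 - 1 -> local 1024
}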
|
||||
|
||||
struct embeddings_t {
|
||||
int n_embd;
|
||||
int n_vocab;
|
||||
|
||||
ranges_t range;
|
||||
std::vector<float> embds;
|
||||
};
|
||||
struct embeddings_map_t {
|
||||
int n_embd = 0;
|
||||
int n_vocab = 0;
|
||||
float* embds = NULL;
|
||||
|
||||
// mapping
|
||||
std::unordered_map<std::string, embeddings_t> mapped_embeddings;
|
||||
|
||||
int text_embd_start = 0; // <unk>
|
||||
int rvq_level_embd_start = 17666; // <|RVQ:0>
|
||||
int len_embd_start = 17674; // <|len:0|>
|
||||
int lang_embd_start = 17686; // <|lang:en|>
|
||||
int task_embd_start = 17692; // <|task:tts|>
|
||||
int sep_embd_start = 17685; // <|sep|>
|
||||
int prom_embd_start[8] = {
|
||||
256 + (1024 * 0), // <|P|0:0|>
|
||||
256 + (1024 * 1), // <|P|1:0|>
|
||||
256 + (1024 * 2), // <|P|2:0|>
|
||||
256 + (1024 * 3), // <|P|3:0|>
|
||||
256 + (1024 * 4), // <|P|4:0|>
|
||||
256 + (1024 * 5), // <|P|5:0|>
|
||||
256 + (1024 * 6), // <|P|6:0|>
|
||||
256 + (1024 * 7), // <|P|7:0|>
|
||||
};
|
||||
int resp_embd_start[9] = {
|
||||
8448, // <|AR|0:0|>
|
||||
9473, // <|NAR|0:0|>
|
||||
10498 + (1024 * 0), // <|NAR|0:1|>
|
||||
10498 + (1024 * 1), // <|NAR|1:2|>
|
||||
10498 + (1024 * 2), // <|NAR|2:3|>
|
||||
10498 + (1024 * 3), // <|NAR|3:4|>
|
||||
10498 + (1024 * 4), // <|NAR|4:5|>
|
||||
10498 + (1024 * 5), // <|NAR|5:6|>
|
||||
10498 + (1024 * 6), // <|NAR|6:7|>
|
||||
};
|
||||
|
||||
float* text_embds = NULL; // &embds[text_embd_start * n_embd];
|
||||
float* rvq_level_embd = NULL; // &embds[rvq_level_embd_start * n_embd];
|
||||
float* len_embd = NULL; // &embds[len_embd_start * n_embd];
|
||||
float* lang_embd = NULL; // &embds[lang_embd_start * n_embd];
|
||||
float* task_embd = NULL; // &embds[task_embd_start * n_embd];
|
||||
float* sep_embd = NULL; // &embds[sep_embd_start * n_embd];
|
||||
|
||||
float* prom_embds[8] = {
|
||||
NULL, // &embds[prom_embd_start[0] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[1] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[2] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[3] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[4] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[5] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[6] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[7] * n_embd],
|
||||
};
|
||||
float* resps_embds[9] = {
|
||||
NULL, // &embds[resp_embd_start[0] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[1] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[2] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[3] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[4] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[5] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[6] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[7] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[8] * n_embd],
|
||||
};
|
||||
|
||||
embeddings_t( int n_embd = 0, int n_vocab = 0, float* embds = NULL ) {
|
||||
init( n_embd, n_vocab, embds );
|
||||
const embeddings_t& get_embeddings( const std::string& name ) {
|
||||
return mapped_embeddings[name];
|
||||
}
|
||||
const float* get_embeddings_p( const std::string& name ) {
|
||||
return mapped_embeddings[name].embds.data();
|
||||
}
|
||||
|
||||
void init( int n_embd, int n_vocab, float* embds = NULL ) {
|
||||
if ( !n_embd || !n_vocab || !embds ) return;
|
||||
void init( llama_model* model ) {
|
||||
this->n_embd = llama_n_embd( model );
|
||||
this->n_vocab = llama_n_vocab( model );
|
||||
|
||||
this->n_embd = n_embd;
|
||||
this->n_vocab = n_vocab;
|
||||
this->embds = embds;
|
||||
// to-do: figure out a nicer way to do this
|
||||
#if LLAMA_CPP_USE_VALL_E_ARCH
|
||||
mapped_embeddings["text"] = { n_embd, 0, { "text", 0, 0, 9, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 0)) };
|
||||
mapped_embeddings["rvq_l"] = { n_embd, 0, { "rvq_l", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 1)) };
|
||||
mapped_embeddings["langs"] = { n_embd, 0, { "langs", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 2)) };
|
||||
mapped_embeddings["tasks"] = { n_embd, 0, { "tasks", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 3)) };
|
||||
mapped_embeddings["len"] = { n_embd, 0, { "len", 0, 0, 10, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 4)) };
|
||||
mapped_embeddings["tones"] = { n_embd, 0, { "tones", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 5)) };
|
||||
mapped_embeddings["sep"] = { n_embd, 0, { "sep", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 6)) };
|
||||
|
||||
text_embds = &embds[text_embd_start * n_embd];
|
||||
rvq_level_embd = &embds[rvq_level_embd_start * n_embd];
|
||||
len_embd = &embds[len_embd_start * n_embd];
|
||||
lang_embd = &embds[lang_embd_start * n_embd];
|
||||
task_embd = &embds[task_embd_start * n_embd];
|
||||
sep_embd = &embds[sep_embd_start * n_embd];
|
||||
mapped_embeddings["prom|0"] = { n_embd, 0, { "prom|0", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 0)) };
|
||||
mapped_embeddings["prom|1"] = { n_embd, 0, { "prom|1", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 1)) };
|
||||
mapped_embeddings["prom|2"] = { n_embd, 0, { "prom|2", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 2)) };
|
||||
mapped_embeddings["prom|3"] = { n_embd, 0, { "prom|3", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 3)) };
|
||||
mapped_embeddings["prom|4"] = { n_embd, 0, { "prom|4", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 4)) };
|
||||
mapped_embeddings["prom|5"] = { n_embd, 0, { "prom|5", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 5)) };
|
||||
mapped_embeddings["prom|6"] = { n_embd, 0, { "prom|6", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 6)) };
|
||||
mapped_embeddings["prom|7"] = { n_embd, 0, { "prom|7", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 7)) };
|
||||
|
||||
mapped_embeddings["resps|AR:0:0"] = { n_embd, 0, { "resps|AR:0:0", 0, 0, 0, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 0)) };
|
||||
mapped_embeddings["resps|NAR:0:1"] = { n_embd, 0, { "resps|NAR:0:1", 0, 0, 1, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 1)) };
|
||||
mapped_embeddings["resps|NAR:1:2"] = { n_embd, 0, { "resps|NAR:1:2", 0, 0, 2, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 2)) };
|
||||
mapped_embeddings["resps|NAR:2:3"] = { n_embd, 0, { "resps|NAR:2:3", 0, 0, 3, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 3)) };
|
||||
mapped_embeddings["resps|NAR:3:4"] = { n_embd, 0, { "resps|NAR:3:4", 0, 0, 4, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 4)) };
|
||||
mapped_embeddings["resps|NAR:4:5"] = { n_embd, 0, { "resps|NAR:4:5", 0, 0, 5, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 5)) };
|
||||
mapped_embeddings["resps|NAR:5:6"] = { n_embd, 0, { "resps|NAR:5:6", 0, 0, 6, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 6)) };
|
||||
mapped_embeddings["resps|NAR:6:7"] = { n_embd, 0, { "resps|NAR:6:7", 0, 0, 7, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 7)) };
|
||||
mapped_embeddings["resps|NAR:0:0"] = { n_embd, 0, { "resps|NAR:0:0", 0, 0, 8, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 8)) };
|
||||
|
||||
for ( auto i = 0; i < 8; ++i ) prom_embds[i] = &embds[prom_embd_start[i] * n_embd];
|
||||
for ( auto i = 0; i < 9; ++i ) resps_embds[i] = &embds[resp_embd_start[i] * n_embd];
|
||||
// update values
|
||||
for ( auto& pair : mapped_embeddings ) {
|
||||
auto& k = pair.first;
|
||||
auto& v = pair.second;
|
||||
auto& embds = v.embds;
|
||||
|
||||
v.n_vocab = embds.size() / n_embd;
|
||||
v.range.end = v.n_vocab;
|
||||
}
|
||||
#else
|
||||
|
||||
#if LLAMA_CPP_EXTENDED
|
||||
auto* tensor = llama_get_embedding_weights( model );
|
||||
#else
|
||||
auto* tensor = model->tok_embd;
|
||||
#endif
|
||||
|
||||
// prepare slices
|
||||
std::vector<float> raw_embeddings = read_2d_tensor( tensor );
|
||||
for ( auto& range : io_ranges ) {
|
||||
mapped_embeddings[range.name] = {
|
||||
n_embd,
|
||||
range.end - range.start,
|
||||
range,
|
||||
std::vector<float>( raw_embeddings.data() + range.start * n_embd, raw_embeddings.data() + range.end * n_embd )
|
||||
};
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// maps embeddings easily
|
||||
std::vector<std::vector<float>> map_embeddings( const std::vector<llama_token>& tokens, int n_embd, float* embds ) {
|
||||
std::vector<std::vector<float>> map_embeddings( const std::vector<llama_token>& tokens, int n_embd, const float* embds ) {
|
||||
std::vector<std::vector<float>> embedded( tokens.size() );
|
||||
for ( auto i = 0; i < tokens.size(); ++i ) {
|
||||
embedded[i].insert( embedded[i].end(), embds + (tokens[i] * n_embd), embds + ((tokens[i]+1) * n_embd) );
|
||||
|
@ -123,7 +209,7 @@ std::vector<std::vector<float>> map_embeddings( const std::vector<llama_token>&
|
|||
|
||||
// handles adding either a token OR the embedding of that token into the batch
|
||||
// this really, really helps avoid needing to abuse the tokenizer
|
||||
void batch_add( llama_batch& batch, llama_token id, int n_embd, float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} ) {
|
||||
void batch_add( llama_batch& batch, llama_token id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} ) {
|
||||
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
|
||||
|
||||
// insert raw embedding instead
|
||||
|
@ -143,7 +229,7 @@ void batch_add( llama_batch& batch, llama_token id, int n_embd, float* embds, ll
|
|||
batch.logits[batch.n_tokens] = output ? 1 : 0;
|
||||
|
||||
// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens, id, batch.pos[batch.n_tokens], &batch.embd[batch.n_tokens], batch.logits[batch.n_tokens] );
|
||||
// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens-1, id, pos, embds, output );
|
||||
// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens, id, pos, embds, output );
|
||||
|
||||
batch.n_tokens++;
|
||||
}
|
||||
|
@ -265,7 +351,7 @@ const int MAX_DURATION = 75; // * 12;
|
|||
const int CTX_SIZE = 2048;
|
||||
|
||||
// sums embeddings over a 2D "tensor"
|
||||
std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, float** embds, int mode = EMBEDDING_MODE_PROM ) {
|
||||
std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM ) {
|
||||
std::vector<std::vector<float>> res( input.size() );
|
||||
res.resize( input[0].size() );
|
||||
for ( auto& e : res ) e.resize( n_embd );
|
||||
|
@ -274,9 +360,9 @@ std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<ll
|
|||
int offset = 0;
|
||||
// handles the cringe logic I have
|
||||
if ( mode == EMBEDDING_MODE_RESP_AR_NAR ) {
|
||||
offset = input.size() == 1 ? 0 : 2;
|
||||
offset = input.size() == 1 ? 0 : 1;
|
||||
} else if ( mode == EMBEDDING_MODE_RESP_NAR_LEN ) {
|
||||
offset = input.size() == 1 ? 1 : 2;
|
||||
offset = input.size() == 1 ? 8 : 1;
|
||||
}
|
||||
// get tokens
|
||||
auto& tokens = input[l];
|
||||
|
@ -293,40 +379,69 @@ std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<ll
|
|||
return res;
|
||||
}
|
||||
|
||||
void fill_batch( llama_batch& batch, input_t& input, embeddings_t& embeddings_map, int mode ) {
|
||||
void fill_batch( llama_batch& batch, input_t& input, embeddings_map_t& embeddings_map, int mode ) {
|
||||
// keeps track of the position for each sequence
|
||||
size_t pos = 0;
|
||||
auto n_embd = embeddings_map.n_embd;
|
||||
|
||||
const float* text_embds = embeddings_map.get_embeddings_p("text");
|
||||
const float* rvq_l_embds = embeddings_map.get_embeddings_p("rvq_l");
|
||||
const float* lang_embds = embeddings_map.get_embeddings_p("lang");
|
||||
const float* task_embds = embeddings_map.get_embeddings_p("task");
|
||||
const float* len_embds = embeddings_map.get_embeddings_p("len");
|
||||
const float* tone_embds = embeddings_map.get_embeddings_p("tone");
|
||||
const float* sep_embds = embeddings_map.get_embeddings_p("sep");
|
||||
const float* prom_embds[] = {
|
||||
embeddings_map.get_embeddings_p("prom|0"),
|
||||
embeddings_map.get_embeddings_p("prom|1"),
|
||||
embeddings_map.get_embeddings_p("prom|2"),
|
||||
embeddings_map.get_embeddings_p("prom|3"),
|
||||
embeddings_map.get_embeddings_p("prom|4"),
|
||||
embeddings_map.get_embeddings_p("prom|5"),
|
||||
embeddings_map.get_embeddings_p("prom|6"),
|
||||
embeddings_map.get_embeddings_p("prom|7"),
|
||||
};
|
||||
const float* resp_embds[] = {
|
||||
embeddings_map.get_embeddings_p("resps|AR:0:0"),
|
||||
embeddings_map.get_embeddings_p("resps|NAR:0:1"),
|
||||
embeddings_map.get_embeddings_p("resps|NAR:1:2"),
|
||||
embeddings_map.get_embeddings_p("resps|NAR:2:3"),
|
||||
embeddings_map.get_embeddings_p("resps|NAR:3:4"),
|
||||
embeddings_map.get_embeddings_p("resps|NAR:4:5"),
|
||||
embeddings_map.get_embeddings_p("resps|NAR:5:6"),
|
||||
embeddings_map.get_embeddings_p("resps|NAR:6:7"),
|
||||
embeddings_map.get_embeddings_p("resps|NAR:0:0"),
|
||||
};
|
||||
|
||||
// insert text tokens
|
||||
for ( auto& id : input.phn ) batch_add( batch, id, n_embd, embeddings_map.text_embds, pos++, false );
|
||||
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
|
||||
for ( auto& id : input.phn ) batch_add( batch, id, n_embd, text_embds, pos++, false );
|
||||
batch_add( batch, 0, n_embd, sep_embds, pos++, false );
|
||||
pos = 0;
|
||||
// insert lang token
|
||||
batch_add( batch, input.lang, n_embd, embeddings_map.lang_embd, pos++, false );
|
||||
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
|
||||
batch_add( batch, input.lang, n_embd, lang_embds, pos++, false );
|
||||
batch_add( batch, 0, n_embd, sep_embds, pos++, false );
|
||||
pos = 0;
|
||||
// insert rvq level token
|
||||
batch_add( batch, input.rvq_l, n_embd, embeddings_map.rvq_level_embd, pos++, false );
|
||||
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
|
||||
batch_add( batch, input.rvq_l, n_embd, rvq_l_embds, pos++, false );
|
||||
batch_add( batch, 0, n_embd, sep_embds, pos++, false );
|
||||
pos = 0;
|
||||
// insert prom tokens
|
||||
auto summed_proms_embds = sum_embeddings( input.prom, n_embd, input.rvq_l, embeddings_map.prom_embds );
|
||||
auto summed_proms_embds = sum_embeddings( input.prom, n_embd, input.rvq_l, prom_embds );
|
||||
for ( auto i = 0; i < summed_proms_embds.size(); ++i ) {
|
||||
batch_add( batch, -1, n_embd, &summed_proms_embds[i][0], pos++, false );
|
||||
}
|
||||
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, mode == INFERENCE_MODE_AR ); // set as the last logit if AR
|
||||
batch_add( batch, 0, n_embd, sep_embds, pos++, mode == INFERENCE_MODE_AR ); // set as the last logit if AR
|
||||
pos = 0;
|
||||
|
||||
// input starting len token
|
||||
if ( input.task == "len" ) {
|
||||
batch_add( batch, 0, n_embd, embeddings_map.len_embd, pos++, true );
|
||||
batch_add( batch, 0, n_embd, len_embds, pos++, true );
|
||||
pos = 0;
|
||||
}
|
||||
|
||||
// insert resp tokens
|
||||
if ( !input.resp.empty() ) {
|
||||
auto summed_resps_embds = sum_embeddings( input.resp, n_embd, input.rvq_l, embeddings_map.resps_embds, mode == INFERENCE_MODE_AR ? EMBEDDING_MODE_RESP_AR_NAR : EMBEDDING_MODE_RESP_NAR_LEN );
|
||||
auto summed_resps_embds = sum_embeddings( input.resp, n_embd, input.rvq_l, resp_embds, mode == INFERENCE_MODE_AR ? EMBEDDING_MODE_RESP_AR_NAR : EMBEDDING_MODE_RESP_NAR_LEN );
|
||||
for ( auto i = 0; i < summed_resps_embds.size(); ++i ) {
|
||||
batch_add( batch, -1, n_embd, &summed_resps_embds[i][0], pos++, true );
|
||||
}
|
||||
|
@ -335,7 +450,7 @@ void fill_batch( llama_batch& batch, input_t& input, embeddings_t& embeddings_ma
|
|||
}
|
||||
|
||||
// generation code, should handle all modalities easily
|
||||
std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, embeddings_t& embeddings_map, int max_tokens, int mode, bool verbose = true ) {
|
||||
std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, embeddings_map_t& embeddings_map, int max_tokens, int mode, bool verbose = true ) {
|
||||
llama_batch batch = llama_batch_init( 22500, embeddings_map.n_embd, 22500 );
|
||||
|
||||
// Decoding loop
|
||||
|
@ -364,34 +479,50 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
|
|||
return {};
|
||||
}
|
||||
|
||||
float* embds = NULL;
|
||||
int logit_range[2];
|
||||
if ( mode == INFERENCE_MODE_AR ) {
|
||||
logit_range[0] = embeddings_map.resp_embd_start[0];
|
||||
logit_range[1] = embeddings_map.resp_embd_start[1];
|
||||
|
||||
embds = embeddings_map.resps_embds[0];
|
||||
|
||||
stop_token = embeddings_map.resp_embd_start[1] - 1; // <|AR|0:STOP|>
|
||||
} else if ( mode == INFERENCE_MODE_NAR_DEMASK ) {
|
||||
logit_range[0] = embeddings_map.resp_embd_start[1];
|
||||
logit_range[1] = embeddings_map.resp_embd_start[2];
|
||||
|
||||
embds = embeddings_map.resps_embds[1];
|
||||
const float* embds = NULL;
|
||||
ranges_t range;
|
||||
|
||||
stop_token = embeddings_map.resp_embd_start[2] - 1; // <|NAR|0:STOP|>
|
||||
if ( mode == INFERENCE_MODE_AR ) {
|
||||
auto& embeddings = embeddings_map.get_embeddings("resps|AR:0:0");
|
||||
range = embeddings.range;
|
||||
embds = embeddings.embds.data();
|
||||
stop_token = range.end - 1;
|
||||
|
||||
// llama_set_classifier_index( ctx, 0 );
|
||||
|
||||
printf("Generating in %s mode (%i:%i)\n", "AR", range.start, range.end);
|
||||
} else if ( mode == INFERENCE_MODE_NAR ) {
|
||||
logit_range[0] = embeddings_map.resp_embd_start[2+rvq_l-1];
|
||||
logit_range[1] = embeddings_map.resp_embd_start[3+rvq_l-1];
|
||||
std::string k_embds[] = {
|
||||
"resps|NAR:0:0", // invalid
|
||||
"resps|NAR:0:1",
|
||||
"resps|NAR:1:2",
|
||||
"resps|NAR:2:3",
|
||||
"resps|NAR:3:4",
|
||||
"resps|NAR:4:5",
|
||||
"resps|NAR:5:6",
|
||||
"resps|NAR:6:7",
|
||||
};
|
||||
auto& embeddings = embeddings_map.get_embeddings(k_embds[rvq_l]);
|
||||
range = embeddings.range;
|
||||
embds = embeddings.embds.data();
|
||||
|
||||
embds = embeddings_map.resps_embds[2];
|
||||
// llama_set_classifier_index( ctx, rvq_l );
|
||||
printf("Generating in %s mode (%i:%i)\n", "NAR", range.start, range.end);
|
||||
} else if ( mode == INFERENCE_MODE_LEN ) {
|
||||
logit_range[0] = embeddings_map.len_embd_start;
|
||||
logit_range[1] = embeddings_map.len_embd_start + 11;
|
||||
auto& embeddings = embeddings_map.get_embeddings("len");
|
||||
range = embeddings.range;
|
||||
embds = embeddings.embds.data();
|
||||
stop_token = range.end - 1;
|
||||
|
||||
embds = embeddings_map.len_embd;
|
||||
|
||||
stop_token = embeddings_map.len_embd_start + 10;
|
||||
// llama_set_classifier_index( ctx, 10 );
|
||||
printf("Generating in %s mode (%i:%i)\n", "len", range.start, range.end);
|
||||
} else if ( mode == INFERENCE_MODE_NAR_DEMASK ) {
|
||||
auto& embeddings = embeddings_map.get_embeddings("NAR:0:0");
|
||||
range = embeddings.range;
|
||||
embds = embeddings.embds.data();
|
||||
|
||||
// llama_set_classifier_index( ctx, 8 );
|
||||
printf("Generating in %s mode (%i:%i)\n", "NAR-len", range.start, range.end);
|
||||
}
|
||||
|
||||
llama_set_causal_attn( ctx, n_logits == 1 );
|
||||
|
@ -408,32 +539,34 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
|
|||
for ( auto i = n_logits; i > 0; --i ) {
|
||||
// filter logits
|
||||
auto* logits = llama_get_logits_ith( ctx, -i );
|
||||
|
||||
// ensures only tokens within our designated range are used
|
||||
#if !LLAMA_CPP_USE_VALL_E_ARCH
|
||||
for ( auto i = 0; i < embeddings_map.n_vocab; ++i ) {
|
||||
// out of target logit range, set to never happen
|
||||
if ( i < logit_range[0] || i >= logit_range[1] ) logits[i] = -INFINITY;
|
||||
if ( i < range.start || i >= range.end ) logits[i] = -INFINITY;
|
||||
}
|
||||
#endif
|
||||
|
||||
// sample the next token
|
||||
auto t = llama_sampler_sample(smpl, ctx, -i);
|
||||
|
||||
if ( verbose ) {
|
||||
// print token for debugging
|
||||
char buf[256];
|
||||
int n = llama_token_to_piece( model, t, buf, sizeof(buf), 0, true );
|
||||
if ( n < 256 ) buf[n] = '\0';
|
||||
printf("%s\n", buf );
|
||||
}
|
||||
// offset back into range
|
||||
#if !LLAMA_CPP_USE_VALL_E_ARCH
|
||||
t -= range.start;
|
||||
#endif
|
||||
|
||||
printf("%i: %i\n", n_decode, t);
|
||||
|
||||
// is stop token
|
||||
if ( t == stop_token ) {
|
||||
printf("STOPPED\n");
|
||||
max_tokens = 0;
|
||||
break;
|
||||
}
|
||||
|
||||
// offset into range
|
||||
t -= logit_range[0];
|
||||
|
||||
// store token
|
||||
output_tokens.emplace_back(t);
|
||||
// update batch with token
|
||||
batch_add( batch, t, embeddings_map.n_embd, embds, output_tokens.size(), true );
|
||||
}
|
||||
}
|
||||
|
@ -460,7 +593,7 @@ int main(int argc, char ** argv) {
|
|||
int32_t ngl = 0;
|
||||
int modality = MODALITY_NAR_LEN;
|
||||
input_t input{};
|
||||
embeddings_t embeddings_map{};
|
||||
embeddings_map_t embeddings_map{};
|
||||
|
||||
// input.phonemes = "hˈɛloː ʋˈɔrlt";
|
||||
input.phn = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; // <bos>hˈɛloː ʋˈɔrlt</eos>
|
||||
|
@ -473,19 +606,6 @@ int main(int argc, char ** argv) {
|
|||
// load dynamic backends
|
||||
ggml_backend_load_all();
|
||||
|
||||
struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl);
|
||||
if (!ectx) {
|
||||
printf("%s: error during loading model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
encodec_set_target_bandwidth(ectx, 6);
|
||||
encodec_set_sample_rate(ectx, 24000);
|
||||
|
||||
// load waveform
|
||||
input.prom = encode_audio_from_disk(ectx, input_prompt_path);
|
||||
//input.resp = encode_audio_from_disk(ectx, output_response_path);
|
||||
|
||||
// initialize the models
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
model_params.n_gpu_layers = ngl;
|
||||
|
@ -524,19 +644,25 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy());
|
||||
|
||||
struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl);
|
||||
if (!ectx) {
|
||||
printf("%s: error during loading model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
encodec_set_target_bandwidth(ectx, 6);
|
||||
encodec_set_sample_rate(ectx, 24000);
|
||||
|
||||
// load waveform
|
||||
input.prom = encode_audio_from_disk(ectx, input_prompt_path);
|
||||
//input.resp = encode_audio_from_disk(ectx, output_response_path);
|
||||
|
||||
// prepare batch
|
||||
auto n_embd = llama_n_embd( model );
|
||||
auto n_vocab = llama_n_vocab( model );
|
||||
|
||||
// grab input embeddings
|
||||
std::vector<float> embds( n_embd * n_vocab );
|
||||
auto* qtype = ggml_get_type_traits(model->tok_embd->type);
|
||||
// dequantize if needed
|
||||
if ( ggml_is_quantized(model->tok_embd->type) ) {
|
||||
qtype->to_float(model->tok_embd->data, embds.data(), embds.size());
|
||||
}
|
||||
// update mapping
|
||||
embeddings_map.init( n_embd, n_vocab, embds.data() );
|
||||
embeddings_map.init( model );
|
||||
|
||||
// tokenize phonemes
|
||||
// to-do: make this work, the vocab does not work
|
||||
|
@ -573,7 +699,7 @@ int main(int argc, char ** argv) {
|
|||
// fill with mask tokens
|
||||
input.resp.resize(1);
|
||||
for ( auto i = 0; i < len; ++i ) {
|
||||
input.resp[0].emplace_back( embeddings_map.resp_embd_start[3] - 1 ); // fill with masked tokens
|
||||
input.resp[0].emplace_back( 1024 ); // fill with masked tokens
|
||||
}
|
||||
|
||||
// inference NAR-len 0
|
||||
|
|
vall_e/export.py
|
@ -12,7 +12,7 @@ from .utils.io import torch_save, torch_load, json_read, json_write, Path
|
|||
# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
|
||||
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
|
||||
@torch.no_grad()
|
||||
def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||
def convert_to_hf_llama( state_dict, config = None, save_path = None ):
|
||||
n_text_tokens, model_dim = state_dict['module']['text_emb.weight'].shape
|
||||
|
||||
n_audio_tokens = state_dict['module']['proms_emb.embeddings.0.weight'].shape[0]
|
||||
|
@ -21,6 +21,9 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
|||
n_lang_tokens = state_dict['module']['langs_emb.weight'].shape[0]
|
||||
n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]
|
||||
|
||||
classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias
|
||||
split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers
|
||||
|
||||
# the new tokenizer to use
|
||||
tokenizer = {}
|
||||
tokenizer_vocab = {}
|
||||
|
@ -37,20 +40,6 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
|||
}
|
||||
}
|
||||
|
||||
l_tokens = [
|
||||
n_text_tokens, # text
|
||||
n_audio_tokens * n_resp_levels, # prom
|
||||
(n_audio_tokens + 1) * 2, # resp: AR + NAR-len (with stop/mask)
|
||||
(n_audio_tokens) * (n_resp_levels - 1), # NAR
|
||||
n_resp_levels, # RVQ level
|
||||
n_len_tokens, # len tokens
|
||||
1, # separator
|
||||
n_lang_tokens, # langs
|
||||
n_task_tokens, # tasks
|
||||
]
|
||||
|
||||
n_tokens = sum(l_tokens)
|
||||
|
||||
lang_map = [
|
||||
"en",
|
||||
"ja",
|
||||
|
@ -70,9 +59,47 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
|||
"eoe",
|
||||
"stt",
|
||||
]
|
||||
tone_map = [
|
||||
"neutral",
|
||||
]
|
||||
|
||||
classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias
|
||||
split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers
|
||||
# (start, end), embedding, classifier, token_format
|
||||
mapping = [
|
||||
[(0, 0), "text_emb.weight", "classifiers.proj.9.weight", None],
|
||||
[(0, 0), "rvq_l_emb.weight", None, "<|RVQ:{l}|>"],
|
||||
[(0, 0), "langs_emb.weight", None, "<|lang:{lang}|>"],
|
||||
[(0, 0), "tasks_emb.weight", None, "<|task:{task}|>"],
|
||||
[(0, 0), "len_emb.weight", "classifiers.proj.10.weight", "<|len:{id}|>"],
|
||||
[(0, 0), "tones_emb.weight", None, "<|tone:{tone}|>"],
|
||||
[(0, 0), "sep", None, "<|sep|>"],
|
||||
|
||||
[(0, 0), "proms_emb.embeddings.0.weight", None, "<|P|0|{id}|>"],
|
||||
[(0, 0), "proms_emb.embeddings.1.weight", None, "<|P|1|{id}|>"],
|
||||
[(0, 0), "proms_emb.embeddings.2.weight", None, "<|P|2|{id}|>"],
|
||||
[(0, 0), "proms_emb.embeddings.3.weight", None, "<|P|3|{id}|>"],
|
||||
[(0, 0), "proms_emb.embeddings.4.weight", None, "<|P|4|{id}|>"],
|
||||
[(0, 0), "proms_emb.embeddings.5.weight", None, "<|P|5|{id}|>"],
|
||||
[(0, 0), "proms_emb.embeddings.6.weight", None, "<|P|6|{id}|>"],
|
||||
[(0, 0), "proms_emb.embeddings.7.weight", None, "<|P|7|{id}|>"],
|
||||
|
||||
[(0, 0), "resps_emb.embeddings.0.weight", "classifiers.proj.0.weight", "<|R|AR|0:0|{id}|>"],
|
||||
[(0, 0), "resps_emb.embeddings.1.weight", "classifiers.proj.1.weight", "<|R|NAR|0:1|{id}|>"],
|
||||
[(0, 0), "resps_emb.embeddings.2.weight", "classifiers.proj.2.weight", "<|R|NAR|1:2|{id}|>"],
|
||||
[(0, 0), "resps_emb.embeddings.3.weight", "classifiers.proj.3.weight", "<|R|NAR|2:3|{id}|>"],
|
||||
[(0, 0), "resps_emb.embeddings.4.weight", "classifiers.proj.4.weight", "<|R|NAR|3:4|{id}|>"],
|
||||
[(0, 0), "resps_emb.embeddings.5.weight", "classifiers.proj.5.weight", "<|R|NAR|4:5|{id}|>"],
|
||||
[(0, 0), "resps_emb.embeddings.6.weight", "classifiers.proj.6.weight", "<|R|NAR|5:6|{id}|>"],
|
||||
[(0, 0), "resps_emb.embeddings.7.weight", "classifiers.proj.7.weight", "<|R|NAR|6:7|{id}|>"],
|
||||
[(0, 0), "resps_emb.embeddings.8.weight", "classifiers.proj.8.weight", "<|R|NAR|0:0|{id}|>"],
|
||||
]
|
||||
|
||||
n_tokens = 0
|
||||
# to-do: figure out discrepancy
|
||||
for i, m in enumerate( mapping ):
|
||||
k_embd = mapping[i][1]
|
||||
embds = state_dict['module'][k_embd] if k_embd in state_dict['module'] else None
|
||||
|
||||
n_tokens += 1 if embds.dim() == 1 else embds.shape[0]
|
||||
|
||||
embedding = torch.nn.Embedding( n_tokens, model_dim )
|
||||
classifier = torch.nn.Linear( model_dim, n_tokens, bias=classifier_bias )
|
||||
|
@ -80,113 +107,49 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
|||
if not split_classifiers:
|
||||
classifier.weight[:] = state_dict['module']['classifier.weight'][:]
|
||||
|
||||
# to-do: ignore classifier for RVQ level 7
|
||||
# update ranges
|
||||
start = 0
|
||||
for i, m in enumerate( mapping ):
|
||||
# get previous start
|
||||
k_embd = mapping[i][1]
|
||||
k_head = mapping[i][2]
|
||||
token_format = mapping[i][3]
|
||||
|
||||
# inject text tokens
|
||||
token_start = 0
|
||||
token_end = l_tokens[0]
|
||||
embedding.weight[token_start:token_end] = state_dict['module']['text_emb.weight']
|
||||
if split_classifiers:
|
||||
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.9.weight']
|
||||
if classifier_bias:
|
||||
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.9.bias']
|
||||
# tokenizer already has these tokens
|
||||
embds = state_dict['module'][k_embd] if k_embd in state_dict['module'] else None
|
||||
head = state_dict['module'][k_head] if k_head in state_dict['module'] else None
|
||||
|
||||
# inject prom tokens
|
||||
token_start = token_end
|
||||
token_end += l_tokens[1]
|
||||
for l in range(n_resp_levels):
|
||||
start = token_start + (l * n_audio_tokens)
|
||||
end = start + n_audio_tokens
|
||||
embedding.weight[start:end] = state_dict['module'][f'proms_emb.embeddings.{l}.weight']
|
||||
# there's no corresponding classifier
|
||||
#classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
|
||||
#classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
|
||||
for t in range(n_audio_tokens):
|
||||
tokenizer_vocab[f'<|P|{l}:{t}|>'] = start + t
|
||||
# expand if 1D
|
||||
if embds.dim() == 1:
|
||||
embds = embds.unsqueeze(0)
|
||||
|
||||
# inject AR
|
||||
token_start = token_end
|
||||
token_end += l_tokens[2] // 2
|
||||
embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.0.weight']
|
||||
if split_classifiers:
|
||||
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
|
||||
if classifier_bias:
|
||||
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
|
||||
for t in range(n_audio_tokens):
|
||||
tokenizer_vocab[f'<|AR|0:0|{t}|>'] = token_start + t
|
||||
tokenizer_vocab[f'<AR|0:0|STOP|>'] = token_start + 1024
|
||||
tokens = embds.shape[0]
|
||||
|
||||
# inject NAR-len
|
||||
token_start = token_end
|
||||
token_end += l_tokens[2] // 2
|
||||
embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight']
|
||||
if split_classifiers:
|
||||
classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
|
||||
if classifier_bias:
|
||||
classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
|
||||
for t in range(n_audio_tokens):
|
||||
tokenizer_vocab[f'<|NAR|0:0|{t}|>'] = token_start + t
|
||||
tokenizer_vocab[f'<|NAR|0:0|STOP|>'] = token_start + 1024
|
||||
|
||||
# inject NAR
|
||||
token_start = token_end
|
||||
token_end += l_tokens[3]
|
||||
for l in range(1, n_resp_levels):
|
||||
start = token_start + ((l-1) * n_audio_tokens)
|
||||
end = start + n_audio_tokens
|
||||
embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight']
|
||||
if split_classifiers:
|
||||
classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
|
||||
if classifier_bias:
|
||||
classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
|
||||
for t in range(n_audio_tokens):
|
||||
tokenizer_vocab[f'<|NAR|{l-1}:{l}|{t}|>'] = start + t
|
||||
|
||||
# inject RVQ level
|
||||
token_start = token_end
|
||||
token_end += l_tokens[4]
|
||||
embedding.weight[token_start:token_end] = state_dict['module'][f'rvq_l_emb.weight']
|
||||
# there is no corresponding classifier
|
||||
for l in range(n_resp_levels):
|
||||
tokenizer_vocab[f'<|RVQ:{l}|>'] = token_start + l
|
||||
|
||||
# inject len
|
||||
token_start = token_end
|
||||
token_end += l_tokens[5]
|
||||
embedding.weight[token_start:token_end] = state_dict['module'][f'len_emb.weight']
|
||||
if split_classifiers:
|
||||
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
|
||||
if classifier_bias:
|
||||
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
|
||||
for t in range(n_len_tokens):
|
||||
tokenizer_vocab[f'<|len:{t}|>'] = token_start + t
|
||||
|
||||
# inject sep
|
||||
token_start = token_end
|
||||
token_end += l_tokens[6]
|
||||
embedding.weight[token_start:token_end] = state_dict['module']['sep']
|
||||
tokenizer_vocab['<|sep|>'] = token_start
|
||||
# there is no corresponding classifier
|
||||
|
||||
# inject langs
|
||||
token_start = token_end
|
||||
token_end += l_tokens[7]
|
||||
embedding.weight[token_start:token_end] = state_dict['module']['langs_emb.weight']
|
||||
for l in range(n_lang_tokens):
|
||||
lang = lang_map[l]
|
||||
tokenizer_vocab[f'<|lang:{lang}|>'] = token_start + l
|
||||
# there is no corresponding classifier
|
||||
|
||||
# inject tasks
|
||||
token_start = token_end
|
||||
token_end += l_tokens[8]
|
||||
embedding.weight[token_start:token_end] = state_dict['module']['tasks_emb.weight']
|
||||
for l in range(n_task_tokens):
|
||||
task = task_map[l]
|
||||
tokenizer_vocab[f'<|task:{task}|>'] = token_start + l
|
||||
# there is no corresponding classifier
|
||||
if embds is not None:
|
||||
embedding.weight[start:start+tokens] = embds
|
||||
|
||||
if split_classifiers and head is not None:
|
||||
classifier.weight[start:start+head.shape[0]] = head
|
||||
|
||||
if token_format is not None:
|
||||
for idx in range(0, tokens):
|
||||
# RVQ level
|
||||
if "{l}" in token_format:
|
||||
token = token_format.format(l=idx)
|
||||
elif "{lang}" in token_format:
|
||||
token = token_format.format(lang=lang_map[idx])
|
||||
elif "{task}" in token_format:
|
||||
token = token_format.format(task=task_map[idx])
|
||||
elif "{tone}" in token_format:
|
||||
token = token_format.format(tone=tone_map[idx])
|
||||
elif "{id}" in token_format:
|
||||
token = token_format.format(id=idx)
|
||||
else:
|
||||
token = token_format
|
||||
tokenizer_vocab[token] = idx + start
|
||||
|
||||
end = start + tokens
|
||||
mapping[i][0] = (start, end)
|
||||
start = end
|
||||
|
||||
model_dict = {}
|
||||
# filter out the underlying model weights and extract them
|
||||
|
@ -225,7 +188,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
|||
# write config.json
|
||||
json_write({
|
||||
"architectures": [
|
||||
"LlamaForCausalLM"
|
||||
"LLaMAForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
|
@ -251,6 +214,130 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
|||
"vocab_size": n_tokens
|
||||
}, out_dir / "config.json", pretty=True )
|
||||
|
||||
return state_dict
|
||||
|
||||
# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
|
||||
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
|
||||
@torch.no_grad()
|
||||
def convert_to_hf_custom( state_dict, config = None, save_path = None ):
|
||||
n_text_tokens, model_dim = state_dict['module']['text_emb.weight'].shape
|
||||
|
||||
n_audio_tokens = state_dict['module']['proms_emb.embeddings.0.weight'].shape[0]
|
||||
n_resp_levels = state_dict['module']['rvq_l_emb.weight'].shape[0]
|
||||
n_len_tokens = 11
|
||||
n_lang_tokens = state_dict['module']['langs_emb.weight'].shape[0]
|
||||
n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]
|
||||
|
||||
classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias
|
||||
split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers
|
||||
|
||||
# the new tokenizer to use
|
||||
tokenizer = {}
|
||||
tokenizer_vocab = {}
|
||||
|
||||
tokenizer_path = cfg.rel_path / cfg.tokenizer_path
|
||||
if not tokenizer_path.exists():
|
||||
tokenizer_path = Path("./data/") / cfg.tokenizer_path
|
||||
if tokenizer_path.exists():
|
||||
tokenizer = json_read( tokenizer_path )
|
||||
else:
|
||||
tokenizer = {
|
||||
"model": {
|
||||
"vocab": get_phone_symmap()
|
||||
}
|
||||
}
|
||||
|
||||
lang_map = [
|
||||
"en",
|
||||
"ja",
|
||||
"de",
|
||||
"fr",
|
||||
"zh",
|
||||
"ko",
|
||||
]
|
||||
task_map = [
|
||||
"tts",
|
||||
"tts-c",
|
||||
"ns",
|
||||
"sr",
|
||||
"tse",
|
||||
"soe",
|
||||
"mask",
|
||||
"eoe",
|
||||
"stt",
|
||||
]
|
||||
|
||||
model_dict = {}
|
||||
# filter out the underlying model weights and extract them
|
||||
for k in state_dict['module'].keys():
|
||||
if not k.startswith('model.'):
|
||||
continue
|
||||
model_dict[k] = state_dict['module'][k].clone()
|
||||
|
||||
# cringe
|
||||
for l in range(11):
|
||||
model_dict[f'classifiers.{l}.weight'] = state_dict['module'][f'classifiers.proj.{l}.weight']
|
||||
for l in range(8):
|
||||
model_dict[f"embeddings.proms.{l}.weight"] = state_dict['module'][f"proms_emb.embeddings.{l}.weight"]
|
||||
for l in range(9):
|
||||
model_dict[f"embeddings.resps.{l}.weight"] = state_dict['module'][f"resps_emb.embeddings.{l}.weight"]
|
||||
|
||||
model_dict["embeddings.aux.0.weight"] = state_dict['module']["text_emb.weight"]
|
||||
model_dict["embeddings.aux.1.weight"] = state_dict['module']["rvq_l_emb.weight"]
|
||||
model_dict["embeddings.aux.2.weight"] = state_dict['module']["langs_emb.weight"]
|
||||
model_dict["embeddings.aux.3.weight"] = state_dict['module']["tasks_emb.weight"]
|
||||
model_dict["embeddings.aux.4.weight"] = state_dict['module']["len_emb.weight"]
|
||||
model_dict["embeddings.aux.5.weight"] = state_dict['module']["tones_emb.weight"]
|
||||
model_dict["embeddings.aux.6.weight"] = state_dict['module']["sep"].unsqueeze(0)
|
||||
|
||||
# write files in an HF compatible way
|
||||
out_dir = cfg.rel_path / "hf"
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
# write weights
|
||||
torch_save( { "module": model_dict, "format": "pt" }, out_dir / "model.safetensors" )
|
||||
# write tokenizer.json
|
||||
tokenizer['model']['vocab'] |= tokenizer_vocab
|
||||
json_write(tokenizer, out_dir / "tokenizer.json", pretty=True)
|
||||
# write tokenizer_config.json
|
||||
json_write({
|
||||
"added_tokens": tokenizer['added_tokens'],
|
||||
"bos_token": "<bos>",
|
||||
"eos_token": "</eos>",
|
||||
"clean_up_tokenization_spaces": True,
|
||||
"model_input_names": [
|
||||
"input_ids",
|
||||
"attention_mask"
|
||||
],
|
||||
"tokenizer_class": "PreTrainedTokenizerFast"
|
||||
}, out_dir / "tokenizer_config.json", pretty=True)
|
||||
# write config.json
|
||||
json_write({
|
||||
"architectures": [
|
||||
"ValleLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": model_dim,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": model_dim * 4,
|
||||
"max_position_embeddings": 75 * 60 * 5,
|
||||
"model_type": "llama",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 12,
|
||||
"num_key_value_heads": 16,
|
||||
"pretraining_tp": 1,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 10000.0,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.40.0",
|
||||
"use_cache": False,
|
||||
"vocab_size": 256
|
||||
}, out_dir / "config.json", pretty=True )
|
||||
|
||||
return state_dict
|
||||
|
||||
|
@ -325,6 +412,7 @@ def main():
|
|||
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||
parser.add_argument("--module-only", action='store_true')
|
||||
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
|
||||
parser.add_argument("--hf-llama", action='store_true', default=None) # convert to HF-style llama model
|
||||
parser.add_argument("--export-lora", action='store_true', default=None) # exports LoRA
|
||||
parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
|
||||
parser.add_argument("--moe-ify", action='store_true', default=None) # splits classifier heads
|
||||
|
@ -352,8 +440,10 @@ def main():
|
|||
engines = load_engines(training=False) # to ignore loading optimizer state
|
||||
|
||||
callback = None
|
||||
if args.hf:
|
||||
callback = convert_to_hf
|
||||
if args.hf_llama:
|
||||
callback = convert_to_hf_llama
|
||||
elif args.hf:
|
||||
callback = convert_to_hf_custom
|
||||
elif args.export_lora:
|
||||
callback = extract_lora
|
||||
elif args.split_classifiers:
|
||||
|
|
|
@ -58,33 +58,30 @@ task_outputs = {
|
|||
# yuck
|
||||
def _get_offsets():
|
||||
return {
|
||||
"text": 0, # <unk>
|
||||
"quant_level": 17666, # <|RVQ:0>
|
||||
"len": 17674, # <|len:0|>
|
||||
"lang": 17686, # <|lang:en|>"
|
||||
"task": 17692, # <|task:tts|>
|
||||
"sep": 17685, # <|sep|>
|
||||
"prom": [
|
||||
256 + (1024 * 0), # <|P|0:0|>
|
||||
256 + (1024 * 1), # <|P|1:0|>
|
||||
256 + (1024 * 2), # <|P|2:0|>
|
||||
256 + (1024 * 3), # <|P|3:0|>
|
||||
256 + (1024 * 4), # <|P|4:0|>
|
||||
256 + (1024 * 5), # <|P|5:0|>
|
||||
256 + (1024 * 6), # <|P|6:0|>
|
||||
256 + (1024 * 7), # <|P|7:0|>
|
||||
],
|
||||
"resp": [
|
||||
8448, # <|AR|0:0|>
|
||||
9473, # <|NAR|0:0|>
|
||||
10498 + (1024 * 0), # <|NAR|0:1|>
|
||||
10498 + (1024 * 1), # <|NAR|1:2|>
|
||||
10498 + (1024 * 2), # <|NAR|2:3|>
|
||||
10498 + (1024 * 3), # <|NAR|3:4|>
|
||||
10498 + (1024 * 4), # <|NAR|4:5|>
|
||||
10498 + (1024 * 5), # <|NAR|5:6|>
|
||||
10498 + (1024 * 6), # <|NAR|6:7|>
|
||||
]
|
||||
"text": (0, 256),
|
||||
"quant_level": (256, 264),
|
||||
"lang": (264, 270),
|
||||
"task": (270, 279),
|
||||
"len": (279, 290),
|
||||
"tone": (290, 291),
|
||||
"sep": (291, 292),
|
||||
"prom|0": (292, 1316),
|
||||
"prom|1": (1316, 2340),
|
||||
"prom|2": (2340, 3364),
|
||||
"prom|3": (3364, 4388),
|
||||
"prom|4": (4388, 5412),
|
||||
"prom|5": (5412, 6436),
|
||||
"prom|6": (6436, 7460),
|
||||
"prom|7": (7460, 8484),
|
||||
"resps|AR:0:0": (8484, 9509),
|
||||
"resps|NAR:0:1": (9509, 10533),
|
||||
"resps|NAR:1:2": (10533, 11557),
|
||||
"resps|NAR:2:3": (11557, 12581),
|
||||
"resps|NAR:3:4": (12581, 13605),
|
||||
"resps|NAR:4:5": (13605, 14629),
|
||||
"resps|NAR:5:6": (14629, 15653),
|
||||
"resps|NAR:6:7": (15653, 16677),
|
||||
"resps|NAR:0:0": (16677, 17702),
|
||||
}
|
||||
|
||||
def _dropout_mask( input, p=None ):
|
||||
|
@ -1084,27 +1081,22 @@ class Base(nn.Module):
|
|||
classifier_level = input
|
||||
|
||||
for name, input in batch_input:
|
||||
if name not in offsets:
|
||||
continue
|
||||
|
||||
if not isinstance( input, torch.Tensor ):
|
||||
continue
|
||||
|
||||
offset = offsets[name]
|
||||
if name in ["prom", "resp"]:
|
||||
l = quant_level
|
||||
if name == "resp":
|
||||
if classifier_level == "AR:0:0":
|
||||
l = 0
|
||||
elif classifier_level == "NAR:0:0":
|
||||
l = 1
|
||||
else:
|
||||
l = 2 + (quant_level-1)
|
||||
k = name
|
||||
if name == "prom":
|
||||
k = f'prom|{quant_level}'
|
||||
elif name == "resp":
|
||||
k = f'resps|{classifier_level}'
|
||||
|
||||
offset = offset[l]
|
||||
if k not in offsets:
|
||||
continue
|
||||
|
||||
start, end = offsets[k]
|
||||
|
||||
for i, t in enumerate( input ):
|
||||
input[i] += offset * direction
|
||||
input[i] += start * direction
|
||||
|
||||
return inputs
|
||||
|
||||
|
@ -1446,45 +1438,22 @@ class Base(nn.Module):
|
|||
# offset to flattened vocab ranges
|
||||
if self.classifier is not None:
|
||||
offsets = _get_offsets()
|
||||
if name in offsets:
|
||||
offset = offsets[name]
|
||||
# yes there's a better way
|
||||
if name == "prom":
|
||||
offset = offset[quant_level]
|
||||
elif name == "resp":
|
||||
"""
|
||||
if classifier_level == "AR:0:0":
|
||||
offset = offset[0]
|
||||
elif classifier_level == "NAR:0:0":
|
||||
offset = offset[1]
|
||||
elif classifier_level == "NAR:0:1":
|
||||
offset = offset[2]
|
||||
elif classifier_level == "NAR:1:2":
|
||||
offset = offset[3]
|
||||
elif classifier_level == "NAR:2:3":
|
||||
offset = offset[4]
|
||||
elif classifier_level == "NAR:3:4":
|
||||
offset = offset[5]
|
||||
elif classifier_level == "NAR:4:5":
|
||||
offset = offset[6]
|
||||
elif classifier_level == "NAR:5:6":
|
||||
offset = offset[7]
|
||||
elif classifier_level == "NAR:6:7":
|
||||
offset = offset[8]
|
||||
else:
|
||||
continue
|
||||
"""
|
||||
if classifier_level == "AR:0:0":
|
||||
offset = offset[0]
|
||||
elif classifier_level == "NAR:0:0":
|
||||
offset = offset[1]
|
||||
else:
|
||||
offset = offset[2 + (quant_level-1)]
|
||||
|
||||
k = name
|
||||
if name == "stt":
|
||||
k = "text"
|
||||
if name == "prom":
|
||||
k = f'prom|{quant_level}'
|
||||
elif name == "resp":
|
||||
k = f'resps|{classifier_level}'
|
||||
|
||||
if k in offsets:
|
||||
start, end = offsets[k]
|
||||
|
||||
for i, t in enumerate( token ):
|
||||
if t == self.ignore_index:
|
||||
continue
|
||||
token[i] += offset
|
||||
token[i] += start
|
||||
|
||||
if token.is_floating_point():
|
||||
ignored = True
|
||||
|
@ -1709,17 +1678,7 @@ class Base(nn.Module):
|
|||
if hidden_states is not None:
|
||||
for i, state in enumerate( hidden_states ):
|
||||
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
|
||||
|
||||
# corrections
|
||||
"""
|
||||
for batch_index, classifier_level in enumerate( classifier_levels ):
|
||||
if classifier_level == "len" and logits[batch_index].shape[1] > 11:
|
||||
logits[batch_index] = logits[batch_index][:,:11]
|
||||
elif classifier_level == "NAR:0:0" and logits[batch_index].shape[1] > 1024:
|
||||
logits[batch_index] = logits[batch_index][:,:1024]
|
||||
"""
|
||||
|
||||
# compute loss if the target is given
|
||||
if not training:
|
||||
loss = None
|
||||
stats = None
|
||||
|
@ -1731,32 +1690,21 @@ class Base(nn.Module):
|
|||
if self.classifier is not None:
|
||||
offsets = _get_offsets()
|
||||
for batch_index, classifier_level in enumerate( classifier_levels ):
|
||||
# yes there's a better way
|
||||
if classifier_level == "len":
|
||||
offset = offsets["len"], 11
|
||||
elif classifier_level == "AR:0:0":
|
||||
offset = offsets["resp"][0], 1025
|
||||
elif classifier_level == "NAR:0:0":
|
||||
offset = offsets["resp"][1], 1024
|
||||
elif classifier_level == "NAR:0:1":
|
||||
offset = offsets["resp"][2], 1024
|
||||
elif classifier_level == "NAR:1:2":
|
||||
offset = offsets["resp"][3], 1024
|
||||
elif classifier_level == "NAR:2:3":
|
||||
offset = offsets["resp"][4], 1024
|
||||
elif classifier_level == "NAR:3:4":
|
||||
offset = offsets["resp"][5], 1024
|
||||
elif classifier_level == "NAR:4:5":
|
||||
offset = offsets["resp"][6], 1024
|
||||
elif classifier_level == "NAR:5:6":
|
||||
offset = offsets["resp"][7], 1024
|
||||
elif classifier_level == "NAR:6:7":
|
||||
offset = offsets["resp"][8], 1024
|
||||
if classifier_level == "stt":
|
||||
k = "text"
|
||||
elif classifier_level == "len":
|
||||
k = "len"
|
||||
else:
|
||||
k = f'resps|{classifier_level}'
|
||||
|
||||
if k not in offsets:
|
||||
continue
|
||||
|
||||
logits[batch_index] = logits[batch_index][offset[0]:offset[0]+offset[1], :]
|
||||
start, end = offsets[k]
|
||||
|
||||
logits[batch_index] = logits[batch_index][:, start:start+end]
|
||||
|
||||
# compute loss if the target is given
|
||||
else:
|
||||
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
|
||||
|
||||
|
|