sanity cleanup

This commit is contained in:
mrq 2024-12-22 15:05:45 -06:00
parent 353e478e68
commit 0d4329d2e3
5 changed files with 561 additions and 386 deletions

View File

@ -112,7 +112,7 @@
"continuing_subword_prefix": null,
"end_of_word_suffix": null,
"fuse_unk": false,
"byte_fallback": true,
"byte_fallback": false,
"ignore_merges": false,
"vocab": {
"<unk>": 0,

View File

@ -9,13 +9,24 @@ At the moment it's ***very*** barebones as I try and wrestle with `llama.cpp`'s
Populate `./include/` with the `llama.cpp` and `encodec.cpp` headers.
Populate `./libs/` with the compiled libraries of `llama.cpp` and `encodec.cpp`.
* `encodec.cpp` requires updating `ggml` to the latest version and doing a quick hack to make it work on the CPU backend.
* `llama.cpp` currently requires no hacks, but:
* would be *very* nice to retrieve a model's `tok_embd` through the API.
* would be ***very*** nice to only specify a slice of the output head through the API.
Run `make`.
### Required Modifications
`encodec.cpp` requires updating its GGML copy to the latest version, plus a few modified lines to get the CPU backend working.
`llama.cpp` *might* not require any modifications, but:
* `llm.build_vall_e` can mostly copy `llm.build_llama`, but with:
* `KQ_mask = build_inp_KQ_mask( lctx.cparams.causal_attn )`
* a unified output head (pain)
* OR adjusting the `model.output` to the correct classifier head
* OR slicing that tensor with the right range (`ggml_view_2d` confuses me)
* both also require `*const_cast<uint32_t*>(&ctx->model.hparams.n_vocab) = output->ne[1];` because the logits are tied to `n_vocab`
* commenting out `GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());` because grabbing embeddings/classifiers requires using `bid` to trick it into thinking the tensor belongs to a layer
* some helper functions to retrieve the embeddings tensor from the model
* some helper functions to set the target classifier head (rough sketches of these helpers are below)
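
For reference, here's a minimal sketch of what those helpers might look like when patched into `llama.cpp`. The names match what `vall_e.cpp` calls, but the `aux_embds` / `prom_embds` / `resp_embds` / `classifiers` members are assumed additions made alongside the VALL_E arch changes, so treat this as pseudocode against internal structs rather than an existing API:

```cpp
// sketch only: assumes hypothetical per-slice tensor arrays (aux_embds, prom_embds,
// resp_embds, classifiers) were added to llama_model as part of the VALL_E arch changes
struct ggml_tensor * llama_get_embedding_weights( struct llama_model * model ) {
    return model->tok_embd; // the full input embedding matrix
}
struct ggml_tensor * llama_get_vall_e_aux_embds( struct llama_model * model, int32_t idx ) {
    return model->aux_embds[idx]; // text / rvq_l / langs / tasks / len / tones / sep
}
struct ggml_tensor * llama_get_vall_e_prom_embds( struct llama_model * model, int32_t idx ) {
    return model->prom_embds[idx];
}
struct ggml_tensor * llama_get_vall_e_resp_embds( struct llama_model * model, int32_t idx ) {
    return model->resp_embds[idx];
}
// point the output head at one classifier and patch n_vocab so the logits buffer matches
void llama_set_classifier_index( struct llama_context * ctx, int32_t idx ) {
    struct ggml_tensor * output = ctx->model.classifiers[idx];
    const_cast<llama_model &>( ctx->model ).output = output;
    *const_cast<uint32_t *>( &ctx->model.hparams.n_vocab ) = output->ne[1];
}
```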
## To-Do
* [x] converted model to GGUF

View File

@ -11,9 +11,30 @@
#include <string>
#include <vector>
#include <array>
#include <unordered_map>
#include <iostream>
#include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions
#define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch
#if !LLAMA_CPP_EXTENDED
#include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
#endif
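// reads a (possibly quantized) 2D tensor's contents into a flat float vector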
std::vector<float> read_2d_tensor( struct ggml_tensor* tensor ) {
size_t size = tensor->ne[0] * tensor->ne[1];
std::vector<float> res( size );
auto* qtype = ggml_get_type_traits(tensor->type);
// dequantize if needed
if ( ggml_is_quantized(tensor->type) ) {
qtype->to_float(tensor->data, res.data(), res.size());
} else {
memcpy( res.data(), tensor->data, res.size() * sizeof(float) );
}
return res;
}
// stores the raw inputs to be fed
struct input_t {
@ -26,94 +47,159 @@ struct input_t {
std::vector<std::vector<llama_token>> prom = {};
std::vector<std::vector<llama_token>> resp = {};
};
/*
[(0, 256), 'text_emb.weight', 'classifiers.proj.9.weight', None],
[(256, 264), 'rvq_l_emb.weight', None, '<|RVQ:{l}|>'],
[(264, 270), 'langs_emb.weight', None, '<|lang:{lang}|>'],
[(270, 279), 'tasks_emb.weight', None, '<|task:{task}|>'],
[(279, 290), 'len_emb.weight', 'classifiers.proj.10.weight', '<|len:{id}|>'],
[(290, 291), 'tones_emb.weight', None, '<|tone:{tone}|>'],
[(291, 292), 'sep', None, '<|sep|>'],
[(292, 1316), 'proms_emb.embeddings.0.weight', None, '<|P|0|{id}|>'],
[(1316, 2340), 'proms_emb.embeddings.1.weight', None, '<|P|1|{id}|>'],
[(2340, 3364), 'proms_emb.embeddings.2.weight', None, '<|P|2|{id}|>'],
[(3364, 4388), 'proms_emb.embeddings.3.weight', None, '<|P|3|{id}|>'],
[(4388, 5412), 'proms_emb.embeddings.4.weight', None, '<|P|4|{id}|>'],
[(5412, 6436), 'proms_emb.embeddings.5.weight', None, '<|P|5|{id}|>'],
[(6436, 7460), 'proms_emb.embeddings.6.weight', None, '<|P|6|{id}|>'],
[(7460, 8484), 'proms_emb.embeddings.7.weight', None, '<|P|7|{id}|>'],
[(8484, 9509), 'resps_emb.embeddings.0.weight', 'classifiers.proj.0.weight', '<|R|AR|0:0|{id}|>'],
[(9509, 10533), 'resps_emb.embeddings.1.weight', 'classifiers.proj.1.weight', '<|R|NAR|0:1|{id}|>'],
[(10533, 11557), 'resps_emb.embeddings.2.weight', 'classifiers.proj.2.weight', '<|R|NAR|1:2|{id}|>'],
[(11557, 12581), 'resps_emb.embeddings.3.weight', 'classifiers.proj.3.weight', '<|R|NAR|2:3|{id}|>'],
[(12581, 13605), 'resps_emb.embeddings.4.weight', 'classifiers.proj.4.weight', '<|R|NAR|3:4|{id}|>'],
[(13605, 14629), 'resps_emb.embeddings.5.weight', 'classifiers.proj.5.weight', '<|R|NAR|4:5|{id}|>'],
[(14629, 15653), 'resps_emb.embeddings.6.weight', 'classifiers.proj.6.weight', '<|R|NAR|5:6|{id}|>'],
[(15653, 16677), 'resps_emb.embeddings.7.weight', 'classifiers.proj.7.weight', '<|R|NAR|6:7|{id}|>'],
[(16677, 17702), 'resps_emb.embeddings.8.weight', 'classifiers.proj.8.weight', '<|R|NAR|0:0|{id}|>']
*/
// handles all the cringe logic of slicing embeddings
struct ranges_t {
std::string name;
uint32_t start;
uint32_t end;
int32_t classifier_idx = -1;
};
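// name, start, end, classifier head index for each slice of the unified vocab (mirrors the table above)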
ranges_t io_ranges[] = {
{ "text", 0, 256, 9, },
{ "rvq_l", 256, 264, -1, },
{ "lang", 264, 270, -1, },
{ "task", 270, 279, -1, },
{ "len", 279, 290, 10, },
{ "tone", 290, 291, -1, },
{ "sep", 291, 292, -1, },
{ "prom|0", 292, 1316, -1, },
{ "prom|1", 1316, 2340, -1, },
{ "prom|2", 2340, 3364, -1, },
{ "prom|3", 3364, 4388, -1, },
{ "prom|4", 4388, 5412, -1, },
{ "prom|5", 5412, 6436, -1, },
{ "prom|6", 6436, 7460, -1, },
{ "prom|7", 7460, 8484, -1, },
{ "resps|AR:0 8484, 9509, 0,:0", },
{ "resps|NAR:0 9509, 10533, 1,:1", },
{ "resps|NAR:1: 10533, 11557, 2,2", },
{ "resps|NAR:2: 11557, 12581, 3,3", },
{ "resps|NAR:3: 12581, 13605, 4,4", },
{ "resps|NAR:4: 13605, 14629, 5,5", },
{ "resps|NAR:5: 14629, 15653, 6,6", },
{ "resps|NAR:6: 15653, 16677, 7,7", },
{ "resps|NAR:0: 16677, 17702, 8,0", },
};
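// one sliced embedding table (n_vocab x n_embd floats) plus its range/classifier metadata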
struct embeddings_t {
int n_embd;
int n_vocab;
ranges_t range;
std::vector<float> embds;
};
struct embeddings_map_t {
int n_embd = 0;
int n_vocab = 0;
float* embds = NULL;
// mapping
std::unordered_map<std::string, embeddings_t> mapped_embeddings;
int text_embd_start = 0; // <unk>
int rvq_level_embd_start = 17666; // <|RVQ:0>
int len_embd_start = 17674; // <|len:0|>
int lang_embd_start = 17686; // <|lang:en|>
int task_embd_start = 17692; // <|task:tts|>
int sep_embd_start = 17685; // <|sep|>
int prom_embd_start[8] = {
256 + (1024 * 0), // <|P|0:0|>
256 + (1024 * 1), // <|P|1:0|>
256 + (1024 * 2), // <|P|2:0|>
256 + (1024 * 3), // <|P|3:0|>
256 + (1024 * 4), // <|P|4:0|>
256 + (1024 * 5), // <|P|5:0|>
256 + (1024 * 6), // <|P|6:0|>
256 + (1024 * 7), // <|P|7:0|>
};
int resp_embd_start[9] = {
8448, // <|AR|0:0|>
9473, // <|NAR|0:0|>
10498 + (1024 * 0), // <|NAR|0:1|>
10498 + (1024 * 1), // <|NAR|1:2|>
10498 + (1024 * 2), // <|NAR|2:3|>
10498 + (1024 * 3), // <|NAR|3:4|>
10498 + (1024 * 4), // <|NAR|4:5|>
10498 + (1024 * 5), // <|NAR|5:6|>
10498 + (1024 * 6), // <|NAR|6:7|>
};
float* text_embds = NULL; // &embds[text_embd_start * n_embd];
float* rvq_level_embd = NULL; // &embds[rvq_level_embd_start * n_embd];
float* len_embd = NULL; // &embds[len_embd_start * n_embd];
float* lang_embd = NULL; // &embds[lang_embd_start * n_embd];
float* task_embd = NULL; // &embds[task_embd_start * n_embd];
float* sep_embd = NULL; // &embds[sep_embd_start * n_embd];
float* prom_embds[8] = {
NULL, // &embds[prom_embd_start[0] * n_embd],
NULL, // &embds[prom_embd_start[1] * n_embd],
NULL, // &embds[prom_embd_start[2] * n_embd],
NULL, // &embds[prom_embd_start[3] * n_embd],
NULL, // &embds[prom_embd_start[4] * n_embd],
NULL, // &embds[prom_embd_start[5] * n_embd],
NULL, // &embds[prom_embd_start[6] * n_embd],
NULL, // &embds[prom_embd_start[7] * n_embd],
};
float* resps_embds[9] = {
NULL, // &embds[resp_embd_start[0] * n_embd],
NULL, // &embds[resp_embd_start[1] * n_embd],
NULL, // &embds[resp_embd_start[2] * n_embd],
NULL, // &embds[resp_embd_start[3] * n_embd],
NULL, // &embds[resp_embd_start[4] * n_embd],
NULL, // &embds[resp_embd_start[5] * n_embd],
NULL, // &embds[resp_embd_start[6] * n_embd],
NULL, // &embds[resp_embd_start[7] * n_embd],
NULL, // &embds[resp_embd_start[8] * n_embd],
};
embeddings_t( int n_embd = 0, int n_vocab = 0, float* embds = NULL ) {
init( n_embd, n_vocab, embds );
const embeddings_t& get_embeddings( const std::string& name ) {
return mapped_embeddings[name];
}
const float* get_embeddings_p( const std::string& name ) {
return mapped_embeddings[name].embds.data();
}
void init( int n_embd, int n_vocab, float* embds = NULL ) {
if ( !n_embd || !n_vocab || !embds ) return;
void init( llama_model* model ) {
this->n_embd = llama_n_embd( model );
this->n_vocab = llama_n_vocab( model );
this->n_embd = n_embd;
this->n_vocab = n_vocab;
this->embds = embds;
// to-do: figure a nicer way to do this
#if LLAMA_CPP_USE_VALL_E_ARCH
mapped_embeddings["text"] = { n_embd, 0, { "text", 0, 0, 9, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 0)) };
mapped_embeddings["rvq_l"] = { n_embd, 0, { "rvq_l", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 1)) };
mapped_embeddings["langs"] = { n_embd, 0, { "langs", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 2)) };
mapped_embeddings["tasks"] = { n_embd, 0, { "tasks", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 3)) };
mapped_embeddings["len"] = { n_embd, 0, { "len", 0, 0, 10, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 4)) };
mapped_embeddings["tones"] = { n_embd, 0, { "tones", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 5)) };
mapped_embeddings["sep"] = { n_embd, 0, { "sep", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_aux_embds(model, 6)) };
text_embds = &embds[text_embd_start * n_embd];
rvq_level_embd = &embds[rvq_level_embd_start * n_embd];
len_embd = &embds[len_embd_start * n_embd];
lang_embd = &embds[lang_embd_start * n_embd];
task_embd = &embds[task_embd_start * n_embd];
sep_embd = &embds[sep_embd_start * n_embd];
mapped_embeddings["prom|0"] = { n_embd, 0, { "prom|0", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 0)) };
mapped_embeddings["prom|1"] = { n_embd, 0, { "prom|1", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 1)) };
mapped_embeddings["prom|2"] = { n_embd, 0, { "prom|2", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 2)) };
mapped_embeddings["prom|3"] = { n_embd, 0, { "prom|3", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 3)) };
mapped_embeddings["prom|4"] = { n_embd, 0, { "prom|4", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 4)) };
mapped_embeddings["prom|5"] = { n_embd, 0, { "prom|5", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 5)) };
mapped_embeddings["prom|6"] = { n_embd, 0, { "prom|6", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 6)) };
mapped_embeddings["prom|7"] = { n_embd, 0, { "prom|7", 0, 0, -1, }, read_2d_tensor(llama_get_vall_e_prom_embds(model, 7)) };
mapped_embeddings["resps|AR:0:0"] = { n_embd, 0, { "resps|AR:0:0", 0, 0, 0, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 0)) };
mapped_embeddings["resps|NAR:0:1"] = { n_embd, 0, { "resps|NAR:0:1", 0, 0, 1, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 1)) };
mapped_embeddings["resps|NAR:1:2"] = { n_embd, 0, { "resps|NAR:1:2", 0, 0, 2, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 2)) };
mapped_embeddings["resps|NAR:2:3"] = { n_embd, 0, { "resps|NAR:2:3", 0, 0, 3, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 3)) };
mapped_embeddings["resps|NAR:3:4"] = { n_embd, 0, { "resps|NAR:3:4", 0, 0, 4, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 4)) };
mapped_embeddings["resps|NAR:4:5"] = { n_embd, 0, { "resps|NAR:4:5", 0, 0, 5, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 5)) };
mapped_embeddings["resps|NAR:5:6"] = { n_embd, 0, { "resps|NAR:5:6", 0, 0, 6, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 6)) };
mapped_embeddings["resps|NAR:6:7"] = { n_embd, 0, { "resps|NAR:6:7", 0, 0, 7, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 7)) };
mapped_embeddings["resps|NAR:0:0"] = { n_embd, 0, { "resps|NAR:0:0", 0, 0, 8, }, read_2d_tensor(llama_get_vall_e_resp_embds(model, 8)) };
for ( auto i = 0; i < 8; ++i ) prom_embds[i] = &embds[prom_embd_start[i] * n_embd];
for ( auto i = 0; i < 9; ++i ) resps_embds[i] = &embds[resp_embd_start[i] * n_embd];
// update values
for ( auto& pair : mapped_embeddings ) {
auto& k = pair.first;
auto& v = pair.second;
auto& embds = v.embds;
v.n_vocab = embds.size() / n_embd;
v.range.end = v.n_vocab;
}
#else
#if LLAMA_CPP_EXTENDED
auto* tensor = llama_get_embedding_weights( model );
#else
auto* tensor = model->tok_embd;
#endif
// prepare slices
std::vector<float> raw_embeddings = read_2d_tensor( tensor );
for ( auto& range : io_ranges ) {
mapped_embeddings[range.name] = {
n_embd,
range.end - range.start,
range,
std::vector<float>( raw_embeddings.data() + range.start, raw_embeddings.data() + range.end )
};
}
#endif
}
};
// maps embeddings easily
std::vector<std::vector<float>> map_embeddings( const std::vector<llama_token>& tokens, int n_embd, float* embds ) {
std::vector<std::vector<float>> map_embeddings( const std::vector<llama_token>& tokens, int n_embd, const float* embds ) {
std::vector<std::vector<float>> embedded( tokens.size() );
for ( auto i = 0; i < tokens.size(); ++i ) {
embedded[i].insert( embedded[i].end(), embds + (tokens[i] * n_embd), embds + ((tokens[i]+1) * n_embd) );
@ -123,7 +209,7 @@ std::vector<std::vector<float>> map_embeddings( const std::vector<llama_token>&
// handles adding either a token OR the embedding of that token into the batch
// this really, really helps avoid needing to abuse the tokenizer
void batch_add( llama_batch& batch, llama_token id, int n_embd, float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} ) {
void batch_add( llama_batch& batch, llama_token id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} ) {
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
// insert raw embedding instead
@ -143,7 +229,7 @@ void batch_add( llama_batch& batch, llama_token id, int n_embd, float* embds, ll
batch.logits[batch.n_tokens] = output ? 1 : 0;
// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens, id, batch.pos[batch.n_tokens], &batch.embd[batch.n_tokens], batch.logits[batch.n_tokens] );
// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens-1, id, pos, embds, output );
// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens, id, pos, embds, output );
batch.n_tokens++;
}
@ -265,7 +351,7 @@ const int MAX_DURATION = 75; // * 12;
const int CTX_SIZE = 2048;
// sums embeddings over a 2D "tensor"
std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, float** embds, int mode = EMBEDDING_MODE_PROM ) {
std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM ) {
std::vector<std::vector<float>> res( input.size() );
res.resize( input[0].size() );
for ( auto& e : res ) e.resize( n_embd );
@ -274,9 +360,9 @@ std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<ll
int offset = 0;
// handles the cringe logic I have
if ( mode == EMBEDDING_MODE_RESP_AR_NAR ) {
offset = input.size() == 1 ? 0 : 2;
offset = input.size() == 1 ? 0 : 1;
} else if ( mode == EMBEDDING_MODE_RESP_NAR_LEN ) {
offset = input.size() == 1 ? 1 : 2;
offset = input.size() == 1 ? 8 : 1;
}
// get tokens
auto& tokens = input[l];
@ -293,40 +379,69 @@ std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<ll
return res;
}
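// embeds the whole input sequence into the batch: text, lang, RVQ level, summed proms, optional len token, then summed resps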
void fill_batch( llama_batch& batch, input_t& input, embeddings_t& embeddings_map, int mode ) {
void fill_batch( llama_batch& batch, input_t& input, embeddings_map_t& embeddings_map, int mode ) {
// keeps track of the position for each sequence
size_t pos = 0;
auto n_embd = embeddings_map.n_embd;
const float* text_embds = embeddings_map.get_embeddings_p("text");
const float* rvq_l_embds = embeddings_map.get_embeddings_p("rvq_l");
const float* lang_embds = embeddings_map.get_embeddings_p("lang");
const float* task_embds = embeddings_map.get_embeddings_p("task");
const float* len_embds = embeddings_map.get_embeddings_p("len");
const float* tone_embds = embeddings_map.get_embeddings_p("tone");
const float* sep_embds = embeddings_map.get_embeddings_p("sep");
const float* prom_embds[] = {
embeddings_map.get_embeddings_p("prom|0"),
embeddings_map.get_embeddings_p("prom|1"),
embeddings_map.get_embeddings_p("prom|2"),
embeddings_map.get_embeddings_p("prom|3"),
embeddings_map.get_embeddings_p("prom|4"),
embeddings_map.get_embeddings_p("prom|5"),
embeddings_map.get_embeddings_p("prom|6"),
embeddings_map.get_embeddings_p("prom|7"),
};
const float* resp_embds[] = {
embeddings_map.get_embeddings_p("resps|AR:0:0"),
embeddings_map.get_embeddings_p("resps|NAR:0:1"),
embeddings_map.get_embeddings_p("resps|NAR:1:2"),
embeddings_map.get_embeddings_p("resps|NAR:2:3"),
embeddings_map.get_embeddings_p("resps|NAR:3:4"),
embeddings_map.get_embeddings_p("resps|NAR:4:5"),
embeddings_map.get_embeddings_p("resps|NAR:5:6"),
embeddings_map.get_embeddings_p("resps|NAR:6:7"),
embeddings_map.get_embeddings_p("resps|NAR:0:0"),
};
// insert text tokens
for ( auto& id : input.phn ) batch_add( batch, id, n_embd, embeddings_map.text_embds, pos++, false );
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
for ( auto& id : input.phn ) batch_add( batch, id, n_embd, text_embds, pos++, false );
batch_add( batch, 0, n_embd, sep_embds, pos++, false );
pos = 0;
// insert lang token
batch_add( batch, input.lang, n_embd, embeddings_map.lang_embd, pos++, false );
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
batch_add( batch, input.lang, n_embd, lang_embds, pos++, false );
batch_add( batch, 0, n_embd, sep_embds, pos++, false );
pos = 0;
// insert rvq level token
batch_add( batch, input.rvq_l, n_embd, embeddings_map.rvq_level_embd, pos++, false );
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
batch_add( batch, input.rvq_l, n_embd, rvq_l_embds, pos++, false );
batch_add( batch, 0, n_embd, sep_embds, pos++, false );
pos = 0;
// insert prom tokens
auto summed_proms_embds = sum_embeddings( input.prom, n_embd, input.rvq_l, embeddings_map.prom_embds );
auto summed_proms_embds = sum_embeddings( input.prom, n_embd, input.rvq_l, prom_embds );
for ( auto i = 0; i < summed_proms_embds.size(); ++i ) {
batch_add( batch, -1, n_embd, &summed_proms_embds[i][0], pos++, false );
}
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, mode == INFERENCE_MODE_AR ); // set as the last logit if AR
batch_add( batch, 0, n_embd, sep_embds, pos++, mode == INFERENCE_MODE_AR ); // set as the last logit if AR
pos = 0;
// input starting len token
if ( input.task == "len" ) {
batch_add( batch, 0, n_embd, embeddings_map.len_embd, pos++, true );
batch_add( batch, 0, n_embd, len_embds, pos++, true );
pos = 0;
}
// insert resp tokens
if ( !input.resp.empty() ) {
auto summed_resps_embds = sum_embeddings( input.resp, n_embd, input.rvq_l, embeddings_map.resps_embds, mode == INFERENCE_MODE_AR ? EMBEDDING_MODE_RESP_AR_NAR : EMBEDDING_MODE_RESP_NAR_LEN );
auto summed_resps_embds = sum_embeddings( input.resp, n_embd, input.rvq_l, resp_embds, mode == INFERENCE_MODE_AR ? EMBEDDING_MODE_RESP_AR_NAR : EMBEDDING_MODE_RESP_NAR_LEN );
for ( auto i = 0; i < summed_resps_embds.size(); ++i ) {
batch_add( batch, -1, n_embd, &summed_resps_embds[i][0], pos++, true );
}
@ -335,7 +450,7 @@ void fill_batch( llama_batch& batch, input_t& input, embeddings_t& embeddings_ma
}
// generation code, should handle all modalities easily
std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, embeddings_t& embeddings_map, int max_tokens, int mode, bool verbose = true ) {
std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, embeddings_map_t& embeddings_map, int max_tokens, int mode, bool verbose = true ) {
llama_batch batch = llama_batch_init( 22500, embeddings_map.n_embd, 22500 );
// Decoding loop
@ -364,34 +479,50 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
return {};
}
float* embds = NULL;
int logit_range[2];
if ( mode == INFERENCE_MODE_AR ) {
logit_range[0] = embeddings_map.resp_embd_start[0];
logit_range[1] = embeddings_map.resp_embd_start[1];
embds = embeddings_map.resps_embds[0];
stop_token = embeddings_map.resp_embd_start[1] - 1; // <|AR|0:STOP|>
} else if ( mode == INFERENCE_MODE_NAR_DEMASK ) {
logit_range[0] = embeddings_map.resp_embd_start[1];
logit_range[1] = embeddings_map.resp_embd_start[2];
embds = embeddings_map.resps_embds[1];
const float* embds = NULL;
ranges_t range;
stop_token = embeddings_map.resp_embd_start[2] - 1; // <|NAR|0:STOP|>
if ( mode == INFERENCE_MODE_AR ) {
auto& embeddings = embeddings_map.get_embeddings("resps|AR:0:0");
range = embeddings.range;
embds = embeddings.embds.data();
stop_token = range.end - 1;
// llama_set_classifier_index( ctx, 0 );
printf("Generating in %s mode (%i:%i)\n", "AR", range.start, range.end);
} else if ( mode == INFERENCE_MODE_NAR ) {
logit_range[0] = embeddings_map.resp_embd_start[2+rvq_l-1];
logit_range[1] = embeddings_map.resp_embd_start[3+rvq_l-1];
std::string k_embds[] = {
"resps|NAR:0:0", // invalid
"resps|NAR:0:1",
"resps|NAR:1:2",
"resps|NAR:2:3",
"resps|NAR:3:4",
"resps|NAR:4:5",
"resps|NAR:5:6",
"resps|NAR:6:7",
};
auto& embeddings = embeddings_map.get_embeddings(k_embds[rvq_l]);
range = embeddings.range;
embds = embeddings.embds.data();
embds = embeddings_map.resps_embds[2];
// llama_set_classifier_index( ctx, rvq_l );
printf("Generating in %s mode (%i:%i)\n", "NAR", range.start, range.end);
} else if ( mode == INFERENCE_MODE_LEN ) {
logit_range[0] = embeddings_map.len_embd_start;
logit_range[1] = embeddings_map.len_embd_start + 11;
auto& embeddings = embeddings_map.get_embeddings("len");
range = embeddings.range;
embds = embeddings.embds.data();
stop_token = range.end - 1;
embds = embeddings_map.len_embd;
stop_token = embeddings_map.len_embd_start + 10;
// llama_set_classifier_index( ctx, 10 );
printf("Generating in %s mode (%i:%i)\n", "len", range.start, range.end);
} else if ( mode == INFERENCE_MODE_NAR_DEMASK ) {
auto& embeddings = embeddings_map.get_embeddings("NAR:0:0");
range = embeddings.range;
embds = embeddings.embds.data();
// llama_set_classifier_index( ctx, 8 );
printf("Generating in %s mode (%i:%i)\n", "NAR-len", range.start, range.end);
}
llama_set_causal_attn( ctx, n_logits == 1 );
@ -408,32 +539,34 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
for ( auto i = n_logits; i > 0; --i ) {
// filter logits
auto* logits = llama_get_logits_ith( ctx, -i );
// ensures only tokens within our designated range are used
#if !LLAMA_CPP_USE_VALL_E_ARCH
for ( auto i = 0; i < embeddings_map.n_vocab; ++i ) {
// out of target logit range, set to never happen
if ( i < logit_range[0] || i >= logit_range[1] ) logits[i] = -INFINITY;
if ( i < range.start || i >= range.end ) logits[i] = -INFINITY;
}
#endif
// sample the next token
auto t = llama_sampler_sample(smpl, ctx, -i);
if ( verbose ) {
// print token for debugging
char buf[256];
int n = llama_token_to_piece( model, t, buf, sizeof(buf), 0, true );
if ( n < 256 ) buf[n] = '\0';
printf("%s\n", buf );
}
// offset back into range
#if !LLAMA_CPP_USE_VALL_E_ARCH
t -= range.start;
#endif
printf("%i: %i\n", n_decode, t);
// is stop token
if ( t == stop_token ) {
printf("STOPPED\n");
max_tokens = 0;
break;
}
// offset into range
t -= logit_range[0];
// store token
output_tokens.emplace_back(t);
// update batch with token
batch_add( batch, t, embeddings_map.n_embd, embds, output_tokens.size(), true );
}
}
@ -460,7 +593,7 @@ int main(int argc, char ** argv) {
int32_t ngl = 0;
int modality = MODALITY_NAR_LEN;
input_t input{};
embeddings_t embeddings_map{};
embeddings_map_t embeddings_map{};
// input.phonemes = "hˈɛloː ʋˈɔrlt";
input.phn = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; // <bos>hˈɛloː ʋˈɔrlt</eos>
@ -473,19 +606,6 @@ int main(int argc, char ** argv) {
// load dynamic backends
ggml_backend_load_all();
struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl);
if (!ectx) {
printf("%s: error during loading model\n", __func__);
return 1;
}
encodec_set_target_bandwidth(ectx, 6);
encodec_set_sample_rate(ectx, 24000);
// load wavform
input.prom = encode_audio_from_disk(ectx, input_prompt_path);
//input.resp = encode_audio_from_disk(ectx, output_response_path);
// initialize the models
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = ngl;
@ -524,19 +644,25 @@ int main(int argc, char ** argv) {
llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy());
struct encodec_context* ectx = encodec_load_model(encodec_model_path.c_str(), 0, ngl);
if (!ectx) {
printf("%s: error during loading model\n", __func__);
return 1;
}
encodec_set_target_bandwidth(ectx, 6);
encodec_set_sample_rate(ectx, 24000);
// load wavform
input.prom = encode_audio_from_disk(ectx, input_prompt_path);
//input.resp = encode_audio_from_disk(ectx, output_response_path);
// prepare batch
auto n_embd = llama_n_embd( model );
auto n_vocab = llama_n_vocab( model );
// grab input embeddings
std::vector<float> embds( n_embd * n_vocab );
auto* qtype = ggml_get_type_traits(model->tok_embd->type);
// dequantize if needed
if ( ggml_is_quantized(model->tok_embd->type) ) {
qtype->to_float(model->tok_embd->data, embds.data(), embds.size());
}
// update mapping
embeddings_map.init( n_embd, n_vocab, embds.data() );
embeddings_map.init( model );
// tokenize phonemes
// to-do: make this work, the vocab does not work
@ -573,7 +699,7 @@ int main(int argc, char ** argv) {
// fill with mask tokens
input.resp.resize(1);
for ( auto i = 0; i < len; ++i ) {
input.resp[0].emplace_back( embeddings_map.resp_embd_start[3] - 1 ); // fill with masked tokens
input.resp[0].emplace_back( 1024 ); // fill with masked tokens
}
// inference NAR-len 0

View File

@ -12,7 +12,7 @@ from .utils.io import torch_save, torch_load, json_read, json_write, Path
# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
@torch.no_grad()
def convert_to_hf( state_dict, config = None, save_path = None ):
def convert_to_hf_llama( state_dict, config = None, save_path = None ):
n_text_tokens, model_dim = state_dict['module']['text_emb.weight'].shape
n_audio_tokens = state_dict['module']['proms_emb.embeddings.0.weight'].shape[0]
@ -21,6 +21,9 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
n_lang_tokens = state_dict['module']['langs_emb.weight'].shape[0]
n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]
classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias
split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers
# the new tokenizer to use
tokenizer = {}
tokenizer_vocab = {}
@ -37,20 +40,6 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
}
}
l_tokens = [
n_text_tokens, # text
n_audio_tokens * n_resp_levels, # prom
(n_audio_tokens + 1) * 2, # resp: AR + NAR-len (with stop/mask)
(n_audio_tokens) * (n_resp_levels - 1), # NAR
n_resp_levels, # RVQ level
n_len_tokens, # len tokens
1, # separator
n_lang_tokens, # langs
n_task_tokens, # tasks
]
n_tokens = sum(l_tokens)
lang_map = [
"en",
"ja",
@ -70,9 +59,47 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
"eoe",
"stt",
]
tone_map = [
"neutral",
]
classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias
split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers
# (start, end), embedding, classifier, token_format
mapping = [
[(0, 0), "text_emb.weight", "classifiers.proj.9.weight", None],
[(0, 0), "rvq_l_emb.weight", None, "<|RVQ:{l}|>"],
[(0, 0), "langs_emb.weight", None, "<|lang:{lang}|>"],
[(0, 0), "tasks_emb.weight", None, "<|task:{task}|>"],
[(0, 0), "len_emb.weight", "classifiers.proj.10.weight", "<|len:{id}|>"],
[(0, 0), "tones_emb.weight", None, "<|tone:{tone}|>"],
[(0, 0), "sep", None, "<|sep|>"],
[(0, 0), "proms_emb.embeddings.0.weight", None, "<|P|0|{id}|>"],
[(0, 0), "proms_emb.embeddings.1.weight", None, "<|P|1|{id}|>"],
[(0, 0), "proms_emb.embeddings.2.weight", None, "<|P|2|{id}|>"],
[(0, 0), "proms_emb.embeddings.3.weight", None, "<|P|3|{id}|>"],
[(0, 0), "proms_emb.embeddings.4.weight", None, "<|P|4|{id}|>"],
[(0, 0), "proms_emb.embeddings.5.weight", None, "<|P|5|{id}|>"],
[(0, 0), "proms_emb.embeddings.6.weight", None, "<|P|6|{id}|>"],
[(0, 0), "proms_emb.embeddings.7.weight", None, "<|P|7|{id}|>"],
[(0, 0), "resps_emb.embeddings.0.weight", "classifiers.proj.0.weight", "<|R|AR|0:0|{id}|>"],
[(0, 0), "resps_emb.embeddings.1.weight", "classifiers.proj.1.weight", "<|R|NAR|0:1|{id}|>"],
[(0, 0), "resps_emb.embeddings.2.weight", "classifiers.proj.2.weight", "<|R|NAR|1:2|{id}|>"],
[(0, 0), "resps_emb.embeddings.3.weight", "classifiers.proj.3.weight", "<|R|NAR|2:3|{id}|>"],
[(0, 0), "resps_emb.embeddings.4.weight", "classifiers.proj.4.weight", "<|R|NAR|3:4|{id}|>"],
[(0, 0), "resps_emb.embeddings.5.weight", "classifiers.proj.5.weight", "<|R|NAR|4:5|{id}|>"],
[(0, 0), "resps_emb.embeddings.6.weight", "classifiers.proj.6.weight", "<|R|NAR|5:6|{id}|>"],
[(0, 0), "resps_emb.embeddings.7.weight", "classifiers.proj.7.weight", "<|R|NAR|6:7|{id}|>"],
[(0, 0), "resps_emb.embeddings.8.weight", "classifiers.proj.8.weight", "<|R|NAR|0:0|{id}|>"],
]
n_tokens = 0
# to-do: figure out discrepancy
for i, m in enumerate( mapping ):
k_embd = mapping[i][1]
embds = state_dict['module'][k_embd] if k_embd in state_dict['module'] else None
n_tokens += 1 if embds.dim() == 1 else embds.shape[0]
embedding = torch.nn.Embedding( n_tokens, model_dim )
classifier = torch.nn.Linear( model_dim, n_tokens, bias=classifier_bias )
@ -80,113 +107,49 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
if not split_classifiers:
classifier.weight[:] = state_dict['module']['classifier.weight'][:]
# to-do: ignore classifier for RVQ level 7
# update ranges
start = 0
for i, m in enumerate( mapping ):
# get previous start
k_embd = mapping[i][1]
k_head = mapping[i][2]
token_format = mapping[i][3]
# inject text tokens
token_start = 0
token_end = l_tokens[0]
embedding.weight[token_start:token_end] = state_dict['module']['text_emb.weight']
if split_classifiers:
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.9.weight']
if classifier_bias:
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.9.bias']
# tokenizer already has these tokens
embds = state_dict['module'][k_embd] if k_embd in state_dict['module'] else None
head = state_dict['module'][k_head] if k_head in state_dict['module'] else None
# inject prom tokens
token_start = token_end
token_end += l_tokens[1]
for l in range(n_resp_levels):
start = token_start + (l * n_audio_tokens)
end = start + n_audio_tokens
embedding.weight[start:end] = state_dict['module'][f'proms_emb.embeddings.{l}.weight']
# there's no corresponding classifier
#classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
#classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
for t in range(n_audio_tokens):
tokenizer_vocab[f'<|P|{l}:{t}|>'] = start + t
# expand if 1D
if embds.dim() == 1:
embds = embds.unsqueeze(0)
# inject AR
token_start = token_end
token_end += l_tokens[2] // 2
embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.0.weight']
if split_classifiers:
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
if classifier_bias:
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
for t in range(n_audio_tokens):
tokenizer_vocab[f'<|AR|0:0|{t}|>'] = token_start + t
tokenizer_vocab[f'<AR|0:0|STOP|>'] = token_start + 1024
tokens = embds.shape[0]
# inject NAR-len
token_start = token_end
token_end += l_tokens[2] // 2
embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight']
if split_classifiers:
classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
if classifier_bias:
classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
for t in range(n_audio_tokens):
tokenizer_vocab[f'<|NAR|0:0|{t}|>'] = token_start + t
tokenizer_vocab[f'<|NAR|0:0|STOP|>'] = token_start + 1024
# inject NAR
token_start = token_end
token_end += l_tokens[3]
for l in range(1, n_resp_levels):
start = token_start + ((l-1) * n_audio_tokens)
end = start + n_audio_tokens
embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight']
if split_classifiers:
classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
if classifier_bias:
classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
for t in range(n_audio_tokens):
tokenizer_vocab[f'<|NAR|{l-1}:{l}|{t}|>'] = start + t
# inject RVQ level
token_start = token_end
token_end += l_tokens[4]
embedding.weight[token_start:token_end] = state_dict['module'][f'rvq_l_emb.weight']
# there is no corresponding classifier
for l in range(n_resp_levels):
tokenizer_vocab[f'<|RVQ:{l}|>'] = token_start + l
# inject len
token_start = token_end
token_end += l_tokens[5]
embedding.weight[token_start:token_end] = state_dict['module'][f'len_emb.weight']
if split_classifiers:
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
if classifier_bias:
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
for t in range(n_len_tokens):
tokenizer_vocab[f'<|len:{t}|>'] = token_start + t
# inject sep
token_start = token_end
token_end += l_tokens[6]
embedding.weight[token_start:token_end] = state_dict['module']['sep']
tokenizer_vocab['<|sep|>'] = token_start
# there is no corresponding classifier
# inject langs
token_start = token_end
token_end += l_tokens[7]
embedding.weight[token_start:token_end] = state_dict['module']['langs_emb.weight']
for l in range(n_lang_tokens):
lang = lang_map[l]
tokenizer_vocab[f'<|lang:{lang}|>'] = token_start + l
# there is no corresponding classifier
# inject tasks
token_start = token_end
token_end += l_tokens[8]
embedding.weight[token_start:token_end] = state_dict['module']['tasks_emb.weight']
for l in range(n_task_tokens):
task = task_map[l]
tokenizer_vocab[f'<|task:{task}|>'] = token_start + l
# there is no corresponding classifier
if embds is not None:
embedding.weight[start:start+tokens] = embds
if split_classifiers and head is not None:
classifier.weight[start:start+head.shape[0]] = head
if token_format is not None:
for idx in range(0, tokens):
# RVQ level
if "{l}" in token_format:
token = token_format.format(l=idx)
elif "{lang}" in token_format:
token = token_format.format(lang=lang_map[idx])
elif "{task}" in token_format:
token = token_format.format(task=task_map[idx])
elif "{tone}" in token_format:
token = token_format.format(tone=tone_map[idx])
elif "{id}" in token_format:
token = token_format.format(id=idx)
else:
token = token_format
tokenizer_vocab[token] = idx + start
end = start + tokens
mapping[i][0] = (start, end)
start = end
model_dict = {}
# filter out the underlying model weights and extract them
@ -225,7 +188,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
# write config.json
json_write({
"architectures": [
"LlamaForCausalLM"
"LLaMAForCausalLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
@ -251,6 +214,130 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
"vocab_size": n_tokens
}, out_dir / "config.json", pretty=True )
return state_dict
# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
@torch.no_grad()
def convert_to_hf_custom( state_dict, config = None, save_path = None ):
n_text_tokens, model_dim = state_dict['module']['text_emb.weight'].shape
n_audio_tokens = state_dict['module']['proms_emb.embeddings.0.weight'].shape[0]
n_resp_levels = state_dict['module']['rvq_l_emb.weight'].shape[0]
n_len_tokens = 11
n_lang_tokens = state_dict['module']['langs_emb.weight'].shape[0]
n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]
classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias
split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers
# the new tokenizer to use
tokenizer = {}
tokenizer_vocab = {}
tokenizer_path = cfg.rel_path / cfg.tokenizer_path
if not tokenizer_path.exists():
tokenizer_path = Path("./data/") / cfg.tokenizer_path
if tokenizer_path.exists():
tokenizer = json_read( tokenizer_path )
else:
tokenizer = {
"model": {
"vocab": get_phone_symmap()
}
}
lang_map = [
"en",
"ja",
"de",
"fr",
"zh",
"ko",
]
task_map = [
"tts",
"tts-c",
"ns",
"sr",
"tse",
"soe",
"mask",
"eoe",
"stt",
]
model_dict = {}
# filter out the underlying model weights and extract them
for k in state_dict['module'].keys():
if not k.startswith('model.'):
continue
model_dict[k] = state_dict['module'][k].clone()
# cringe
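# rename the split embeddings / classifier heads into the layout the custom VALL_E arch export expects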
for l in range(11):
model_dict[f'classifiers.{l}.weight'] = state_dict['module'][f'classifiers.proj.{l}.weight']
for l in range(8):
model_dict[f"embeddings.proms.{l}.weight"] = state_dict['module'][f"proms_emb.embeddings.{l}.weight"]
for l in range(9):
model_dict[f"embeddings.resps.{l}.weight"] = state_dict['module'][f"resps_emb.embeddings.{l}.weight"]
model_dict["embeddings.aux.0.weight"] = state_dict['module']["text_emb.weight"]
model_dict["embeddings.aux.1.weight"] = state_dict['module']["rvq_l_emb.weight"]
model_dict["embeddings.aux.2.weight"] = state_dict['module']["langs_emb.weight"]
model_dict["embeddings.aux.3.weight"] = state_dict['module']["tasks_emb.weight"]
model_dict["embeddings.aux.4.weight"] = state_dict['module']["len_emb.weight"]
model_dict["embeddings.aux.5.weight"] = state_dict['module']["tones_emb.weight"]
model_dict["embeddings.aux.6.weight"] = state_dict['module']["sep"].unsqueeze(0)
# write files in an HF compatible way
out_dir = cfg.rel_path / "hf"
out_dir.mkdir(parents=True, exist_ok=True)
# write weights
torch_save( { "module": model_dict, "format": "pt" }, out_dir / "model.safetensors" )
# write tokenizer.json
tokenizer['model']['vocab'] |= tokenizer_vocab
json_write(tokenizer, out_dir / "tokenizer.json", pretty=True)
# write tokenizer_config.json
json_write({
"added_tokens": tokenizer['added_tokens'],
"bos_token": "<bos>",
"eos_token": "</eos>",
"clean_up_tokenization_spaces": True,
"model_input_names": [
"input_ids",
"attention_mask"
],
"tokenizer_class": "PreTrainedTokenizerFast"
}, out_dir / "tokenizer_config.json", pretty=True)
# write config.json
json_write({
"architectures": [
"ValleLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": model_dim,
"initializer_range": 0.02,
"intermediate_size": model_dim * 4,
"max_position_embeddings": 75 * 60 * 5,
"model_type": "llama",
"num_attention_heads": 16,
"num_hidden_layers": 12,
"num_key_value_heads": 16,
"pretraining_tp": 1,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 10000.0,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.40.0",
"use_cache": False,
"vocab_size": 256
}, out_dir / "config.json", pretty=True )
return state_dict
@ -325,6 +412,7 @@ def main():
parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--module-only", action='store_true')
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
parser.add_argument("--hf-llama", action='store_true', default=None) # convert to HF-style llama model
parser.add_argument("--export-lora", action='store_true', default=None) # exports LoRA
parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
parser.add_argument("--moe-ify", action='store_true', default=None) # splits classifier heads
@ -352,8 +440,10 @@ def main():
engines = load_engines(training=False) # to ignore loading optimizer state
callback = None
if args.hf:
callback = convert_to_hf
if args.hf_llama:
callback = convert_to_hf_llama
elif args.hf:
callback = convert_to_hf_custom
elif args.export_lora:
callback = extract_lora
elif args.split_classifiers:

View File

@ -58,33 +58,30 @@ task_outputs = {
# yuck
def _get_offsets():
return {
"text": 0, # <unk>
"quant_level": 17666, # <|RVQ:0>
"len": 17674, # <|len:0|>
"lang": 17686, # <|lang:en|>"
"task": 17692, # <|task:tts|>
"sep": 17685, # <|sep|>
"prom": [
256 + (1024 * 0), # <|P|0:0|>
256 + (1024 * 1), # <|P|1:0|>
256 + (1024 * 2), # <|P|2:0|>
256 + (1024 * 3), # <|P|3:0|>
256 + (1024 * 4), # <|P|4:0|>
256 + (1024 * 5), # <|P|5:0|>
256 + (1024 * 6), # <|P|6:0|>
256 + (1024 * 7), # <|P|7:0|>
],
"resp": [
8448, # <|AR|0:0|>
9473, # <|NAR|0:0|>
10498 + (1024 * 0), # <|NAR|0:1|>
10498 + (1024 * 1), # <|NAR|1:2|>
10498 + (1024 * 2), # <|NAR|2:3|>
10498 + (1024 * 3), # <|NAR|3:4|>
10498 + (1024 * 4), # <|NAR|4:5|>
10498 + (1024 * 5), # <|NAR|5:6|>
10498 + (1024 * 6), # <|NAR|6:7|>
]
"text": (0, 256),
"quant_level": (256, 264),
"lang": (264, 270),
"task": (270, 279),
"len": (279, 290),
"tone": (290, 291),
"sep": (291, 292),
"prom|0": (292, 1316),
"prom|1": (1316, 2340),
"prom|2": (2340, 3364),
"prom|3": (3364, 4388),
"prom|4": (4388, 5412),
"prom|5": (5412, 6436),
"prom|6": (6436, 7460),
"prom|7": (7460, 8484),
"resps|AR:0:0": (8484, 9509),
"resps|NAR:0:1": (9509, 10533),
"resps|NAR:1:2": (10533, 11557),
"resps|NAR:2:3": (11557, 12581),
"resps|NAR:3:4": (12581, 13605),
"resps|NAR:4:5": (13605, 14629),
"resps|NAR:5:6": (14629, 15653),
"resps|NAR:6:7": (15653, 16677),
"resps|NAR:0:0": (16677, 17702),
}
def _dropout_mask( input, p=None ):
@ -1084,27 +1081,22 @@ class Base(nn.Module):
classifier_level = input
for name, input in batch_input:
if name not in offsets:
continue
if not isinstance( input, torch.Tensor ):
continue
offset = offsets[name]
if name in ["prom", "resp"]:
l = quant_level
if name == "resp":
if classifier_level == "AR:0:0":
l = 0
elif classifier_level == "NAR:0:0":
l = 1
else:
l = 2 + (quant_level-1)
k = name
if name == "prom":
k = f'prom|{quant_level}'
elif name == "resp":
k = f'resps|{classifier_level}'
offset = offset[l]
if k not in offsets:
continue
start, end = offsets[k]
for i, t in enumerate( input ):
input[i] += offset * direction
input[i] += start * direction
return inputs
@ -1446,45 +1438,22 @@ class Base(nn.Module):
# offset to flattened vocab ranges
if self.classifier is not None:
offsets = _get_offsets()
if name in offsets:
offset = offsets[name]
# yes there's a better way
if name == "prom":
offset = offset[quant_level]
elif name == "resp":
"""
if classifier_level == "AR:0:0":
offset = offset[0]
elif classifier_level == "NAR:0:0":
offset = offset[1]
elif classifier_level == "NAR:0:1":
offset = offset[2]
elif classifier_level == "NAR:1:2":
offset = offset[3]
elif classifier_level == "NAR:2:3":
offset = offset[4]
elif classifier_level == "NAR:3:4":
offset = offset[5]
elif classifier_level == "NAR:4:5":
offset = offset[6]
elif classifier_level == "NAR:5:6":
offset = offset[7]
elif classifier_level == "NAR:6:7":
offset = offset[8]
else:
continue
"""
if classifier_level == "AR:0:0":
offset = offset[0]
elif classifier_level == "NAR:0:0":
offset = offset[1]
else:
offset = offset[2 + (quant_level-1)]
k = name
if name == "stt":
k = "text"
if name == "prom":
k = f'prom|{quant_level}'
elif name == "resp":
k = f'resps|{classifier_level}'
if k in offsets:
start, end = offsets[k]
for i, t in enumerate( token ):
if t == self.ignore_index:
continue
token[i] += offset
token[i] += start
if token.is_floating_point():
ignored = True
@ -1709,17 +1678,7 @@ class Base(nn.Module):
if hidden_states is not None:
for i, state in enumerate( hidden_states ):
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
# corrections
"""
for batch_index, classifier_level in enumerate( classifier_levels ):
if classifier_level == "len" and logits[batch_index].shape[1] > 11:
logits[batch_index] = logits[batch_index][:,:11]
elif classifier_level == "NAR:0:0" and logits[batch_index].shape[1] > 1024:
logits[batch_index] = logits[batch_index][:,:1024]
"""
# compute loss if the target is given
if not training:
loss = None
stats = None
@ -1731,32 +1690,21 @@ class Base(nn.Module):
if self.classifier is not None:
offsets = _get_offsets()
for batch_index, classifier_level in enumerate( classifier_levels ):
# yes there's a better way
if classifier_level == "len":
offset = offsets["len"], 11
elif classifier_level == "AR:0:0":
offset = offsets["resp"][0], 1025
elif classifier_level == "NAR:0:0":
offset = offsets["resp"][1], 1024
elif classifier_level == "NAR:0:1":
offset = offsets["resp"][2], 1024
elif classifier_level == "NAR:1:2":
offset = offsets["resp"][3], 1024
elif classifier_level == "NAR:2:3":
offset = offsets["resp"][4], 1024
elif classifier_level == "NAR:3:4":
offset = offsets["resp"][5], 1024
elif classifier_level == "NAR:4:5":
offset = offsets["resp"][6], 1024
elif classifier_level == "NAR:5:6":
offset = offsets["resp"][7], 1024
elif classifier_level == "NAR:6:7":
offset = offsets["resp"][8], 1024
if classifier_level == "stt":
k = "text"
elif classifier_level == "len":
k = "len"
else:
k = f'resps|{classifier_level}'
if k not in offsets:
continue
logits[batch_index] = logits[batch_index][offset[0]:offset[0]+offset[1], :]
start, end = offsets[k]
logits[batch_index] = logits[batch_index][:, start:end]
# compute loss if the target is given
else:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )