more updates to vall_e.cpp
This commit is contained in:
parent
503124d0d3
commit
1b4a69ce29
|
@ -24,11 +24,13 @@ Run `make`.
|
|||
* [x] juggle the output head / classifier properly
|
||||
* [ ] phonemize text
|
||||
* [ ] tokenize phonemes
|
||||
* [ ] load audio from disk
|
||||
* [ ] encode audio
|
||||
* [ ] sum embeddings for the `prom` and prior `resp`s
|
||||
* [x] load audio from disk
|
||||
* [x] encode audio
|
||||
* [x] sum embeddings for the `prom` and prior `resp`s
|
||||
* [x] `AR` sampling
|
||||
* [ ] `NAR-len` demasking sampling
|
||||
* [ ] `NAR` sampling
|
||||
* [ ] decode audio to disk
|
||||
* [ ] a functional CLI
|
||||
* [ ] actually make it work
|
||||
* it seems naively stitching the model together isn't good enough since the output is wrong
|
|
@ -13,151 +13,140 @@
|
|||
#include <array>
|
||||
#include <iostream>
|
||||
|
||||
/* Begin cringe so I can access the model's tok_embd */
|
||||
// it needs to be copied so the struct layout is exactly as it is under llama.cpp
|
||||
#define LLAMA_MAX_LAYERS 512
|
||||
#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
|
||||
#include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
|
||||
|
||||
enum e_model {
|
||||
MODEL_UNKNOWN,
|
||||
// stores the raw inputs to be fed
|
||||
struct input_t {
|
||||
std::string task = "tts";
|
||||
|
||||
std::vector<llama_token> phonemes = {};
|
||||
llama_token lang = 0;
|
||||
llama_token rvq_l = 0;
|
||||
std::vector<std::vector<llama_token>> prom = {};
|
||||
std::vector<std::vector<llama_token>> resp = {};
|
||||
};
|
||||
// handles all the cringe logic of slicing embeddings
|
||||
struct embeddings_t {
|
||||
int n_embd = 0;
|
||||
int n_vocab = 0;
|
||||
float* embds = NULL;
|
||||
|
||||
int text_embd_start = 0; // <unk>
|
||||
int rvq_level_embd_start = 17666; // <|RVQ:0>
|
||||
int len_embd_start = 17674; // <|len:0|>
|
||||
int lang_embd_start = 17686; // <|lang:en|>
|
||||
int task_embd_start = 17692; // <|task:tts|>
|
||||
int sep_embd_start = 17685; // <|sep|>
|
||||
int prom_embd_start[8] = {
|
||||
256 + (1024 * 0), // <|P|0:0|>
|
||||
256 + (1024 * 1), // <|P|1:0|>
|
||||
256 + (1024 * 2), // <|P|2:0|>
|
||||
256 + (1024 * 3), // <|P|3:0|>
|
||||
256 + (1024 * 4), // <|P|4:0|>
|
||||
256 + (1024 * 5), // <|P|5:0|>
|
||||
256 + (1024 * 6), // <|P|6:0|>
|
||||
256 + (1024 * 7), // <|P|7:0|>
|
||||
};
|
||||
int resp_embd_start[9] = {
|
||||
8448, // <|AR|0:0|>
|
||||
9473, // <|NAR|0:0|>
|
||||
10498 + (1024 * 0), // <|NAR|0:1|>
|
||||
10498 + (1024 * 1), // <|NAR|1:2|>
|
||||
10498 + (1024 * 2), // <|NAR|2:3|>
|
||||
10498 + (1024 * 3), // <|NAR|3:4|>
|
||||
10498 + (1024 * 4), // <|NAR|4:5|>
|
||||
10498 + (1024 * 5), // <|NAR|5:6|>
|
||||
10498 + (1024 * 6), // <|NAR|6:7|>
|
||||
};
|
||||
|
||||
float* text_embds = NULL; // &embds[text_embd_start * n_embd];
|
||||
float* rvq_level_embd = NULL; // &embds[rvq_level_embd_start * n_embd];
|
||||
float* len_embd = NULL; // &embds[len_embd_start * n_embd];
|
||||
float* lang_embd = NULL; // &embds[lang_embd_start * n_embd];
|
||||
float* task_embd = NULL; // &embds[task_embd_start * n_embd];
|
||||
float* sep_embd = NULL; // &embds[sep_embd_start * n_embd];
|
||||
|
||||
float* prom_embds[8] = {
|
||||
NULL, // &embds[prom_embd_start[0] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[1] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[2] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[3] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[4] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[5] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[6] * n_embd],
|
||||
NULL, // &embds[prom_embd_start[7] * n_embd],
|
||||
};
|
||||
float* resps_embds[9] = {
|
||||
NULL, // &embds[resp_embd_start[0] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[1] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[2] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[3] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[4] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[5] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[6] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[7] * n_embd],
|
||||
NULL, // &embds[resp_embd_start[8] * n_embd],
|
||||
};
|
||||
|
||||
embeddings_t( int n_embd = 0, int n_vocab = 0, float* embds = NULL ) {
|
||||
init( n_embd, n_vocab, embds );
|
||||
}
|
||||
|
||||
void init( int n_embd, int n_vocab, float* embds = NULL ) {
|
||||
if ( !n_embd || !n_vocab || !embds ) return;
|
||||
|
||||
this->n_embd = n_embd;
|
||||
this->n_vocab = n_vocab;
|
||||
this->embds = embds;
|
||||
|
||||
text_embds = &embds[text_embd_start * n_embd];
|
||||
rvq_level_embd = &embds[rvq_level_embd_start * n_embd];
|
||||
len_embd = &embds[len_embd_start * n_embd];
|
||||
lang_embd = &embds[lang_embd_start * n_embd];
|
||||
task_embd = &embds[task_embd_start * n_embd];
|
||||
sep_embd = &embds[sep_embd_start * n_embd];
|
||||
|
||||
for ( auto i = 0; i < 8; ++i ) prom_embds[i] = &embds[prom_embd_start[i] * n_embd];
|
||||
for ( auto i = 0; i < 9; ++i ) resps_embds[i] = &embds[resp_embd_start[i] * n_embd];
|
||||
}
|
||||
};
|
||||
|
||||
enum llm_arch {
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
struct llama_hparams_posnet {
|
||||
uint32_t n_embd;
|
||||
uint32_t n_layer;
|
||||
};
|
||||
|
||||
struct llama_hparams_convnext {
|
||||
uint32_t n_embd;
|
||||
uint32_t n_layer;
|
||||
};
|
||||
|
||||
struct llama_hparams {
|
||||
bool vocab_only;
|
||||
bool rope_finetuned;
|
||||
bool use_par_res;
|
||||
bool swin_norm;
|
||||
|
||||
uint32_t n_vocab = 0;
|
||||
uint32_t n_ctx_train; // context size the model was trained on
|
||||
uint32_t n_embd;
|
||||
uint32_t n_embd_features = 0;
|
||||
uint32_t n_layer;
|
||||
uint32_t n_rot;
|
||||
uint32_t n_swa = 0; // sliding window attention (SWA)
|
||||
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
||||
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
||||
uint32_t n_expert = 0;
|
||||
uint32_t n_expert_used = 0;
|
||||
uint32_t n_vocab_type = 0; // for BERT-style token types
|
||||
uint32_t n_rel_attn_bkts = 0;
|
||||
|
||||
// for WavTokenizer
|
||||
struct llama_hparams_posnet posnet;
|
||||
struct llama_hparams_convnext convnext;
|
||||
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
|
||||
|
||||
uint32_t n_layer_dense_lead = 0;
|
||||
uint32_t n_lora_q = 0;
|
||||
uint32_t n_lora_kv = 0;
|
||||
uint32_t n_ff_exp = 0;
|
||||
uint32_t n_ff_shexp = 0;
|
||||
uint32_t n_expert_shared = 0;
|
||||
float expert_weights_scale = 0.0;
|
||||
|
||||
float f_norm_eps;
|
||||
float f_norm_rms_eps;
|
||||
float f_norm_group_eps;
|
||||
|
||||
uint32_t n_norm_groups;
|
||||
|
||||
float f_attn_logit_softcapping = 50.0f;
|
||||
float f_final_logit_softcapping = 30.0f;
|
||||
|
||||
// for RWKV
|
||||
uint32_t rescale_every_n_layers = 0;
|
||||
uint32_t time_mix_extra_dim = 0;
|
||||
uint32_t time_decay_extra_dim = 0;
|
||||
uint32_t wkv_head_size = 0;
|
||||
|
||||
float rope_attn_factor = 1.0f;
|
||||
float rope_freq_base_train;
|
||||
float rope_freq_scale_train;
|
||||
uint32_t n_ctx_orig_yarn;
|
||||
float rope_yarn_log_mul;
|
||||
int rope_sections[4];
|
||||
|
||||
// for State Space Models
|
||||
uint32_t ssm_d_conv = 0;
|
||||
uint32_t ssm_d_inner = 0;
|
||||
uint32_t ssm_d_state = 0;
|
||||
uint32_t ssm_dt_rank = 0;
|
||||
bool ssm_dt_b_c_rms = false;
|
||||
|
||||
float f_clamp_kqv = 0.0f;
|
||||
float f_max_alibi_bias = 0.0f;
|
||||
float f_logit_scale = 0.0f;
|
||||
|
||||
// Additional scale factors (Granite/Granite MoE)
|
||||
float f_residual_scale = 0.0f;
|
||||
float f_embedding_scale = 0.0f;
|
||||
float f_attention_scale = 0.0f;
|
||||
|
||||
bool causal_attn = true;
|
||||
bool use_alibi = false;
|
||||
bool attn_soft_cap = false;
|
||||
|
||||
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
|
||||
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
|
||||
|
||||
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
||||
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
|
||||
};
|
||||
|
||||
struct llama_model {
|
||||
e_model type = MODEL_UNKNOWN;
|
||||
llm_arch arch = LLM_ARCH_UNKNOWN;
|
||||
llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
|
||||
|
||||
std::string name = "n/a";
|
||||
|
||||
llama_hparams hparams = {};
|
||||
llama_vocab vocab;
|
||||
|
||||
struct ggml_tensor * tok_embd = nullptr;
|
||||
};
|
||||
/* End cringe code */
|
||||
// maps embeddings easily
|
||||
std::vector<std::vector<float>> map_embeddings( const std::vector<llama_token>& tokens, int n_embd, float* embds ) {
|
||||
std::vector<std::vector<float>> embedded( tokens.size() );
|
||||
for ( auto i = 0; i < tokens.size(); ++i ) {
|
||||
embedded[i].insert( embedded[i].end(), embds + (tokens[i] * n_embd), embds + ((tokens[i]+1) * n_embd) );
|
||||
}
|
||||
return embedded;
|
||||
}
|
||||
|
||||
// handles adding either a token OR the embedding of that token into the batch
|
||||
// this really, really helps avoid needing to abuse the tokenizer
|
||||
// to-do: handle summing
|
||||
void batch_add( struct llama_batch& batch, llama_token id, int n_embd, float* embds, llama_pos pos, bool logits = true, const std::vector<llama_seq_id> & seq_ids = {0} ) {
|
||||
void batch_add( llama_batch& batch, llama_token id, int n_embd, float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} ) {
|
||||
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
|
||||
|
||||
// insert raw embedding instead
|
||||
if ( embds ) {
|
||||
for ( auto i = 0; i < n_embd; ++i ) batch.embd[batch.n_tokens + i] = embds[id * n_embd + i];
|
||||
// signals to not map the embedding from the array
|
||||
if ( id < 0 ) for ( auto i = 0; i < n_embd; ++i ) batch.embd[batch.n_tokens + i] = embds[i];
|
||||
else for ( auto i = 0; i < n_embd; ++i ) batch.embd[batch.n_tokens + i] = embds[id * n_embd + i];
|
||||
// insert token (never gets used here)
|
||||
} else {
|
||||
batch.token[batch.n_tokens] = id;
|
||||
}
|
||||
|
||||
batch.pos[batch.n_tokens] = pos;
|
||||
|
||||
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
|
||||
for (size_t i = 0; i < seq_ids.size(); ++i) {
|
||||
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
|
||||
}
|
||||
batch.logits[batch.n_tokens] = logits;
|
||||
for (size_t i = 0; i < seq_ids.size(); ++i) batch.seq_id[batch.n_tokens][i] = seq_ids[i];
|
||||
batch.logits[batch.n_tokens] = output ? 1 : 0;
|
||||
|
||||
// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens, id, batch.pos[batch.n_tokens], &batch.embd[batch.n_tokens], batch.logits[batch.n_tokens] );
|
||||
// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens-1, id, pos, embds, output );
|
||||
|
||||
batch.n_tokens++;
|
||||
}
|
||||
|
||||
// reads a waveform from disk
|
||||
bool read_wav_from_disk(std::string in_path, std::vector<float> & audio_arr) {
|
||||
uint32_t channels;
|
||||
uint32_t sample_rate;
|
||||
|
@ -171,6 +160,11 @@ bool read_wav_from_disk(std::string in_path, std::vector<float> & audio_arr) {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (sample_rate != 24000) {
|
||||
fprintf(stderr, "%s: wav file is wrong sample rate\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n%s: Number of frames read = %lld.\n", __func__, total_frame_count);
|
||||
|
||||
audio_arr.resize(total_frame_count);
|
||||
|
@ -180,7 +174,7 @@ bool read_wav_from_disk(std::string in_path, std::vector<float> & audio_arr) {
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
// writes a waveform to disk
|
||||
void write_wav_on_disk(std::vector<float> & audio_arr, std::string dest_path) {
|
||||
drwav_data_format format;
|
||||
format.bitsPerSample = 32;
|
||||
|
@ -196,8 +190,8 @@ void write_wav_on_disk(std::vector<float> & audio_arr, std::string dest_path) {
|
|||
|
||||
fprintf(stderr, "%s: Number of frames written = %lld.\n", __func__, frames);
|
||||
}
|
||||
|
||||
std::vector<int32_t> encode_audio( struct encodec_context* ectx, const std::string& path ) {
|
||||
// reads a waveform from disk then encodes it
|
||||
std::vector<std::vector<int32_t>> encode_audio_from_disk( struct encodec_context* ectx, const std::string& path ) {
|
||||
// read audio from disk
|
||||
std::vector<float> wavform;
|
||||
|
||||
|
@ -214,10 +208,34 @@ std::vector<int32_t> encode_audio( struct encodec_context* ectx, const std::stri
|
|||
|
||||
int32_t* codes_data = encodec_get_codes( ectx );
|
||||
int n_codes = encodec_get_codes_size( ectx );
|
||||
int n_codebooks = 8;
|
||||
int n_frames = n_codes / n_codebooks;
|
||||
|
||||
return std::vector<int32_t>(codes_data, codes_data + n_codes);
|
||||
std::vector<int32_t> flattened_codes(codes_data, codes_data + n_codes);
|
||||
std::vector<std::vector<int32_t>> codes_2ds(8);
|
||||
|
||||
for ( auto l = 0; l < n_codebooks; ++l ) {
|
||||
codes_2ds[l].resize( n_frames );
|
||||
for ( auto i = 0; i < n_frames; ++i ) {
|
||||
codes_2ds[l][i] = flattened_codes[i + l * n_codebooks];
|
||||
}
|
||||
}
|
||||
|
||||
return codes_2ds;
|
||||
}
|
||||
std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector<int32_t>& codes ) {
|
||||
// decodes a 2D codebook into a waveform
|
||||
std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector<std::vector<int32_t>>& codes_2d ) {
|
||||
int n_codebooks = codes_2d.size();
|
||||
int n_frames = codes_2d[0].size();
|
||||
|
||||
std::vector<int32_t> codes( n_frames * n_codebooks );
|
||||
|
||||
for ( auto l = 0; l < n_codebooks; ++l ) {
|
||||
for ( auto i = 0; i < n_frames; ++i ) {
|
||||
codes[i + l * n_codebooks] = codes_2d[l][i];
|
||||
}
|
||||
}
|
||||
|
||||
// decompress audio
|
||||
if (!encodec_decompress_audio(ectx, codes.data(), codes.size(), 1)) {
|
||||
printf("%s: error during decompression\n", __func__);
|
||||
|
@ -229,22 +247,224 @@ std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector
|
|||
const int audio_size = encodec_get_audio_size(ectx);
|
||||
return std::vector<float>(audio_data, audio_data + audio_size);
|
||||
}
|
||||
|
||||
const int EMBEDDING_MODE_PROM = 0;
|
||||
const int EMBEDDING_MODE_RESP_AR_NAR = 0;
|
||||
const int EMBEDDING_MODE_RESP_NAR_LEN = 0;
|
||||
|
||||
const int INFERENCE_MODE_LEN = 0;
|
||||
const int INFERENCE_MODE_AR = 1;
|
||||
const int INFERENCE_MODE_NAR_DEMASK = 2;
|
||||
const int INFERENCE_MODE_NAR = 4;
|
||||
|
||||
const int MODALITY_AR_NAR = 0;
|
||||
const int MODALITY_NAR_LEN = 0;
|
||||
|
||||
const int MAX_DURATION = 75 * 12;
|
||||
|
||||
// sums embeddings over a 2D "tensor"
|
||||
std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, float** embds, int mode = EMBEDDING_MODE_PROM ) {
|
||||
std::vector<std::vector<float>> res( input.size() );
|
||||
res.resize( input[0].size() );
|
||||
for ( auto& e : res ) e.resize( n_embd );
|
||||
// iterate through rvq levels (only up to inclusive the target rvq level)
|
||||
for ( auto l = 0; l < input.size() && l <= rvq_l; ++l ) {
|
||||
int offset = 0;
|
||||
// handles the cringe logic I have
|
||||
if ( mode == EMBEDDING_MODE_RESP_AR_NAR ) {
|
||||
offset = input.size() == 1 ? 0 : 2;
|
||||
} else if ( mode == EMBEDDING_MODE_RESP_NAR_LEN ) {
|
||||
offset = input.size() == 1 ? 1 : 2;
|
||||
}
|
||||
// get tokens
|
||||
auto& tokens = input[l];
|
||||
// get output buffer
|
||||
auto& summed = res[l];
|
||||
// embed the current level's tokens
|
||||
auto embedded = map_embeddings( input[l], n_embd, embds[l + offset] );
|
||||
// iterate through embedded tokens
|
||||
for ( auto i = 0; i < tokens.size(); ++i ) {
|
||||
// sum with buffer
|
||||
for ( auto j = 0; j < n_embd; ++j ) summed[j] += embedded[i][j];
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
void fill_batch( llama_batch& batch, input_t& input, embeddings_t& embeddings_map, int mode ) {
|
||||
// keeps track of the position for each sequence
|
||||
size_t pos = 0;
|
||||
auto n_embd = embeddings_map.n_embd;
|
||||
|
||||
// insert text tokens
|
||||
for ( auto& id : input.phonemes ) batch_add( batch, id, n_embd, embeddings_map.text_embds, pos++, false );
|
||||
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
|
||||
pos = 0;
|
||||
// insert lang token
|
||||
batch_add( batch, input.lang, n_embd, embeddings_map.lang_embd, pos++, false );
|
||||
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
|
||||
pos = 0;
|
||||
// insert rvq level token
|
||||
batch_add( batch, input.rvq_l, n_embd, embeddings_map.rvq_level_embd, pos++, false );
|
||||
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
|
||||
pos = 0;
|
||||
// insert prom tokens
|
||||
auto summed_proms_embds = sum_embeddings( input.prom, n_embd, input.rvq_l, embeddings_map.prom_embds );
|
||||
for ( auto i = 0; i < summed_proms_embds.size(); ++i ) {
|
||||
batch_add( batch, -1, n_embd, &summed_proms_embds[i][0], pos++, false );
|
||||
}
|
||||
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, mode == INFERENCE_MODE_AR ); // set as the last logit if AR
|
||||
pos = 0;
|
||||
|
||||
// input starting len token
|
||||
if ( input.task == "len" ) {
|
||||
batch_add( batch, 0, n_embd, embeddings_map.len_embd, pos++, true );
|
||||
pos = 0;
|
||||
}
|
||||
|
||||
// insert resp tokens
|
||||
if ( !input.resp.empty() ) {
|
||||
auto summed_resps_embds = sum_embeddings( input.resp, n_embd, input.rvq_l, embeddings_map.resps_embds, mode == INFERENCE_MODE_AR ? EMBEDDING_MODE_RESP_AR_NAR : EMBEDDING_MODE_RESP_NAR_LEN );
|
||||
for ( auto i = 0; i < summed_resps_embds.size(); ++i ) {
|
||||
batch_add( batch, -1, n_embd, &summed_resps_embds[i][0], pos++, true );
|
||||
}
|
||||
pos = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// generation code, should handle all modalities easily
|
||||
std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, embeddings_t& embeddings_map, int max_tokens, int mode, bool verbose = true ) {
|
||||
llama_batch batch = llama_batch_init( 22500, embeddings_map.n_embd, 22500 );
|
||||
|
||||
// Decoding loop
|
||||
const auto t_main_start = ggml_time_us();
|
||||
int n_decode = 0;
|
||||
int rvq_l = input.rvq_l;
|
||||
llama_token stop_token = -1;
|
||||
|
||||
fill_batch( batch, input, embeddings_map, mode );
|
||||
|
||||
// determine how many logits we need
|
||||
int n_logits = 0;
|
||||
for ( auto i = 0; i < batch.n_tokens; ++i ) {
|
||||
if ( batch.logits[i] ) ++n_logits;
|
||||
}
|
||||
|
||||
if ( verbose ) printf("Prompt size: %i | Logits: %i\n", batch.n_tokens, n_logits);
|
||||
|
||||
// NAR mode, cap at one step
|
||||
if ( n_logits > 1 ) {
|
||||
max_tokens = n_logits;
|
||||
}
|
||||
|
||||
if ( n_logits == 0 ) {
|
||||
fprintf(stderr, "%s : no tokens to decode\n", __func__);
|
||||
return {};
|
||||
}
|
||||
|
||||
float* embds = NULL;
|
||||
int logit_range[2];
|
||||
if ( mode == INFERENCE_MODE_AR ) {
|
||||
logit_range[0] = embeddings_map.resp_embd_start[0];
|
||||
logit_range[1] = embeddings_map.resp_embd_start[1];
|
||||
|
||||
embds = embeddings_map.resps_embds[0];
|
||||
|
||||
stop_token = embeddings_map.resp_embd_start[1] - 1; // <|AR|0:STOP|>
|
||||
} else if ( mode == INFERENCE_MODE_NAR_DEMASK ) {
|
||||
logit_range[0] = embeddings_map.resp_embd_start[1];
|
||||
logit_range[1] = embeddings_map.resp_embd_start[2];
|
||||
|
||||
embds = embeddings_map.resps_embds[1];
|
||||
|
||||
stop_token = embeddings_map.resp_embd_start[2] - 1; // <|NAR|0:STOP|>
|
||||
} else if ( mode == INFERENCE_MODE_NAR ) {
|
||||
logit_range[0] = embeddings_map.resp_embd_start[2+rvq_l];
|
||||
logit_range[1] = embeddings_map.resp_embd_start[3+rvq_l];
|
||||
|
||||
embds = embeddings_map.resps_embds[2];
|
||||
} else if ( mode == INFERENCE_MODE_LEN ) {
|
||||
logit_range[0] = embeddings_map.len_embd_start;
|
||||
logit_range[1] = embeddings_map.len_embd_start + 11;
|
||||
|
||||
embds = embeddings_map.len_embd;
|
||||
|
||||
stop_token = embeddings_map.len_embd_start + 10;
|
||||
}
|
||||
|
||||
llama_set_causal_attn( ctx, n_logits == 1 );
|
||||
|
||||
std::vector<llama_token> output_tokens;
|
||||
while ( output_tokens.size() < max_tokens ) {
|
||||
if (llama_decode(ctx, batch)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return output_tokens;
|
||||
}
|
||||
n_decode += 1;
|
||||
|
||||
// backwards iterate to start from beginning of sequence
|
||||
for ( auto i = n_logits; i > 0; --i ) {
|
||||
// filter logits
|
||||
auto* logits = llama_get_logits_ith( ctx, -i );
|
||||
for ( auto i = 0; i < embeddings_map.n_vocab; ++i ) {
|
||||
// out of target logit range, set to never happen
|
||||
if ( i < logit_range[0] || i >= logit_range[1] ) logits[i] = -INFINITY;
|
||||
}
|
||||
|
||||
// sample the next token
|
||||
auto t = llama_sampler_sample(smpl, ctx, -i);
|
||||
|
||||
if ( verbose ) {
|
||||
// print token for debugging
|
||||
char buf[256];
|
||||
llama_token_to_piece( model, t, buf, sizeof(buf), 0, true );
|
||||
printf("%s\n", buf );
|
||||
}
|
||||
|
||||
// is stop token
|
||||
if ( t == stop_token ) {
|
||||
max_tokens = 0;
|
||||
break;
|
||||
}
|
||||
|
||||
// offset into range
|
||||
t -= logit_range[0];
|
||||
|
||||
output_tokens.emplace_back(t);
|
||||
batch_add( batch, t, embeddings_map.n_embd, embds, output_tokens.size(), true );
|
||||
}
|
||||
}
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
if ( verbose ) {
|
||||
printf("\n");
|
||||
fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
llama_perf_sampler_print(smpl);
|
||||
llama_perf_context_print(ctx);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
return output_tokens;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
bool is_ar = true;
|
||||
// to-do: replace all of this with proper loading code
|
||||
std::vector<llama_token> phoneme_tokens = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2};
|
||||
llama_token lang_token = 0;
|
||||
llama_token rvq_level_token = 0;
|
||||
std::vector<std::vector<llama_token>> prompt_tokens = {
|
||||
{780,835,835,835,339,395,798,537,537,537,537,222,76,989,548,65,705,375,261,375,297,503,529,571,707,346,464,862,148,496,574,115,115,438,934,339,865,876,63,40,779,461,602,794,10,220,398,869,639,705,869,917,705,893,215,705,869,938,439,175,139,506,375,529,297,705,651,238,962,461,195,441,377,581,473,795,644,626,459,981,767,670,696,73,779,257,408,1017,1019,133,133,1017,835,604,699,626,67,92,707,92,179,179,772,869,441,799,917,238,745,904,904,904,106,133,1019,1017,1017,395,883,87,519,594,1002,682,996,540,186,1019,430,202,347,889,61,92,542,297,67,669,571,707,346,67,359,571,707,669,604,25,1008,810,35,621,67,600,333,123,284,568,817,243,778,464,638,610,359,538,464,975,321,700,377,484,179,284,284,621,538,464,745,171,171,159,744,159,287,461,69,15,529,67,92,669,464,515,605,24,822,865,293,62,172,638,359,562,138,839,846,775,556,688,1006,917,297,312,148,331,496,646,67,314,15,705,131,855,662,287,172,85,538,519,762,450,391,609,643,778,80,287,794,794,115,785,794,461,699,519,932,522,652,262,508,902,932,932,391,769,18,507,90,442,762,610,610,669,605,310,855,56,989,863,195,464,604,257,904,632,786,951,461,239,195,878,771,146,481,146,481,434,643,917,280,67,464,115,744,744,115,115,115,819,709,63,368,359,519,996,616,464,996,616,519,762,917,841,772,568,954,600,422,893,592,464,626,86,143,615,171,744,744,196,115,821,415,521,799,654,839,644,473,592,953,523,855,738,855,876,876,1017,63,329},
|
||||
};
|
||||
std::vector<std::vector<llama_token>> response_tokens = {
|
||||
{922,395,869,869,354,989,762,762,762,610,975,626,626,866,609,442,762,762,762,610,610,610,610,212,869,869,51,336,352,352,352,570,148,893,76,535,568,568,270,568,568,560,597,86,744,744,744,203,738,408,1019,700,707,92,707,464,744,171,171,159,196,192,697,261,261,568,638,605,904,904,779,832,570,519,223,459,459,459,459,90,90,570,700,53,372,621,610,869,473,869,917,654,473,917,893,654,644,384,558,911,864,521,1,19,665},
|
||||
};
|
||||
int32_t ngl = 0;
|
||||
int modality = MODALITY_AR_NAR;
|
||||
input_t input{};
|
||||
embeddings_t embeddings_map{};
|
||||
|
||||
input.phonemes = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2};
|
||||
|
||||
std::string vall_e_model_path = "./data/vall_e-q8_0.gguf";
|
||||
std::string encodec_model_path = "./data/encodec.bin";
|
||||
int32_t ngl = 0;
|
||||
|
||||
std::string input_prompt_path = "./data/prom.wav";
|
||||
std::string output_response_path = "./data/resp.wav";
|
||||
|
||||
// load dynamic backends
|
||||
ggml_backend_load_all();
|
||||
|
@ -255,7 +475,12 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
encodec_set_target_bandwidth(ectx, 24);
|
||||
encodec_set_target_bandwidth(ectx, 6);
|
||||
encodec_set_sample_rate(ectx, 24000);
|
||||
|
||||
// load wavform
|
||||
input.prom = encode_audio_from_disk(ectx, input_prompt_path);
|
||||
//input.resp = encode_audio_from_disk(ectx, output_response_path);
|
||||
|
||||
// initialize the models
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
|
@ -272,9 +497,9 @@ int main(int argc, char ** argv) {
|
|||
ctx_params.n_ctx = 22500;
|
||||
ctx_params.n_batch = 22500;
|
||||
ctx_params.no_perf = false;
|
||||
ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL;
|
||||
|
||||
ctx_params.attention_type = is_ar ? LLAMA_ATTENTION_TYPE_CAUSAL : LLAMA_ATTENTION_TYPE_NON_CAUSAL;
|
||||
|
||||
// create two contexts, one's that causally, the other that isn't, because pain
|
||||
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||
|
@ -284,14 +509,19 @@ int main(int argc, char ** argv) {
|
|||
// initialize the sampler
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
sparams.no_perf = false;
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
llama_sampler * smpl_ar = llama_sampler_chain_init(sparams);
|
||||
llama_sampler * smpl_nar = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
|
||||
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_k(20));
|
||||
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_p(0.9, 20));
|
||||
llama_sampler_chain_add(smpl_ar, llama_sampler_init_temp (1.0));
|
||||
// llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130));
|
||||
|
||||
llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy());
|
||||
|
||||
// prepare batch
|
||||
auto n_embd = llama_n_embd( model );
|
||||
auto n_vocab = llama_n_vocab( model );
|
||||
llama_batch batch = llama_batch_init( ctx_params.n_ctx, n_embd, ctx_params.n_ctx );
|
||||
|
||||
// grab input embeddings
|
||||
std::vector<float> embds( n_embd * n_vocab );
|
||||
|
@ -300,151 +530,61 @@ int main(int argc, char ** argv) {
|
|||
if ( ggml_is_quantized(model->tok_embd->type) ) {
|
||||
qtype->to_float(model->tok_embd->data, embds.data(), embds.size());
|
||||
}
|
||||
// update mapping
|
||||
embeddings_map.init( n_embd, n_vocab, embds.data() );
|
||||
|
||||
// to-do: derive these offsets from the tokenizer itself
|
||||
// to-do: clean this up, probably make it at parity to inputs_to_embeddings
|
||||
int text_embd_start = 0; // <unk>
|
||||
int rvq_level_embd_start = 17666; // <|RVQ:0>
|
||||
int len_embd_start = 17674; // <|len:0|>
|
||||
int lang_embd_start = 17686; // <|lang:en|>
|
||||
int task_embd_start = 17692; // <|task:tts|>
|
||||
int sep_embd_start = 17685; // <|sep|>
|
||||
int prom_embd_start[] = {
|
||||
256 + (1024 * 0), // <|P|0:0|>
|
||||
256 + (1024 * 1), // <|P|1:0|>
|
||||
256 + (1024 * 2), // <|P|2:0|>
|
||||
256 + (1024 * 3), // <|P|3:0|>
|
||||
256 + (1024 * 4), // <|P|4:0|>
|
||||
256 + (1024 * 5), // <|P|5:0|>
|
||||
256 + (1024 * 6), // <|P|6:0|>
|
||||
256 + (1024 * 7), // <|P|7:0|>
|
||||
};
|
||||
int resp_embd_start[] = {
|
||||
8448, // <|AR|0:0|>
|
||||
9473, // <|NAR|0:0|>
|
||||
10498 + (1024 * 0), // <|NAR|0:1|>
|
||||
10498 + (1024 * 1), // <|NAR|1:2|>
|
||||
10498 + (1024 * 2), // <|NAR|2:3|>
|
||||
10498 + (1024 * 3), // <|NAR|3:4|>
|
||||
10498 + (1024 * 4), // <|NAR|4:5|>
|
||||
10498 + (1024 * 5), // <|NAR|5:6|>
|
||||
10498 + (1024 * 6), // <|NAR|6:7|>
|
||||
};
|
||||
|
||||
float* text_embds = &embds[text_embd_start * n_embd];
|
||||
float* rvq_level_embd = &embds[rvq_level_embd_start * n_embd];
|
||||
float* len_embd = &embds[len_embd_start * n_embd];
|
||||
float* lang_embd = &embds[lang_embd_start * n_embd];
|
||||
float* task_embd = &embds[task_embd_start * n_embd];
|
||||
float* sep_embd = &embds[sep_embd_start * n_embd];
|
||||
|
||||
float* prom_embds[] = {
|
||||
&embds[prom_embd_start[0] * n_embd],
|
||||
&embds[prom_embd_start[1] * n_embd],
|
||||
&embds[prom_embd_start[2] * n_embd],
|
||||
&embds[prom_embd_start[3] * n_embd],
|
||||
&embds[prom_embd_start[4] * n_embd],
|
||||
&embds[prom_embd_start[5] * n_embd],
|
||||
&embds[prom_embd_start[6] * n_embd],
|
||||
&embds[prom_embd_start[7] * n_embd],
|
||||
};
|
||||
float* resps_embds[] = {
|
||||
&embds[resp_embd_start[0] * n_embd],
|
||||
&embds[resp_embd_start[1] * n_embd],
|
||||
&embds[resp_embd_start[2] * n_embd],
|
||||
&embds[resp_embd_start[3] * n_embd],
|
||||
&embds[resp_embd_start[4] * n_embd],
|
||||
&embds[resp_embd_start[5] * n_embd],
|
||||
&embds[resp_embd_start[6] * n_embd],
|
||||
&embds[resp_embd_start[7] * n_embd],
|
||||
&embds[resp_embd_start[8] * n_embd],
|
||||
};
|
||||
|
||||
// insert into batch
|
||||
{
|
||||
// keeps track of the position for each sequence
|
||||
size_t pos = 0;
|
||||
|
||||
// insert text tokens
|
||||
for ( auto& id : phoneme_tokens ) batch_add( batch, id, n_embd, text_embds, pos++, false );
|
||||
batch_add( batch, 0, n_embd, sep_embd, pos++, false );
|
||||
pos = 0;
|
||||
// insert lang token
|
||||
batch_add( batch, lang_token, n_embd, lang_embd, pos++, false );
|
||||
batch_add( batch, 0, n_embd, sep_embd, pos++, false );
|
||||
pos = 0;
|
||||
// insert rvq level token
|
||||
batch_add( batch, rvq_level_token, n_embd, rvq_level_embd, pos++, false );
|
||||
batch_add( batch, 0, n_embd, sep_embd, pos++, false );
|
||||
pos = 0;
|
||||
// insert prom tokens
|
||||
// to-do: handle summing
|
||||
for ( auto l = 0; l < prompt_tokens.size(); ++l ) {
|
||||
for ( auto& id : prompt_tokens[l] ) batch_add( batch, id, n_embd, prom_embds[l], pos++, false );
|
||||
}
|
||||
batch_add( batch, 0, n_embd, sep_embd, pos++, is_ar );
|
||||
pos = 0;
|
||||
|
||||
// fill in masked tokens
|
||||
if ( !is_ar ) {
|
||||
for ( auto i = 0; i < response_tokens[0].size(); ++i ) batch_add( batch, response_tokens[0][i], n_embd, resps_embds[1], pos++, true );
|
||||
}
|
||||
pos = 0;
|
||||
}
|
||||
|
||||
// Decoding loop
|
||||
const auto t_main_start = ggml_time_us();
|
||||
int n_decode = 0;
|
||||
|
||||
// to-do: handle other levels
|
||||
std::vector<llama_token> resps_tokens;
|
||||
while ( resps_tokens.size() < 32 ) {
|
||||
if (llama_decode(ctx, batch)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
n_decode += 1;
|
||||
|
||||
// align to AR's classifier
|
||||
// to-do: derive from tokenizer
|
||||
int range[] = { resp_embd_start[0], resp_embd_start[1] };
|
||||
auto* logits = llama_get_logits_ith( ctx, -1 );
|
||||
for ( auto i = 0; i < n_vocab; ++i ) {
|
||||
if ( i < range[0] || i >= range[1] ) {
|
||||
logits[i] = -INFINITY;
|
||||
// inference
|
||||
std::vector<llama_token> output_tokens;
|
||||
// NAR-len demasking
|
||||
if ( modality == MODALITY_NAR_LEN ) {
|
||||
// inference len
|
||||
input.task = "len";
|
||||
output_tokens = generate( ctx, model, smpl_nar, input, embeddings_map, 5, INFERENCE_MODE_LEN );
|
||||
int len = 0; {
|
||||
int digit = 1;
|
||||
for (int i = output_tokens.size() - 1; i >= 0; i--) {
|
||||
len += output_tokens[i] * digit;
|
||||
digit *= 10;
|
||||
}
|
||||
}
|
||||
// cap for now
|
||||
if ( len > MAX_DURATION ) len = MAX_DURATION;
|
||||
|
||||
// sample the next token
|
||||
auto t = llama_sampler_sample(smpl, ctx, -1);
|
||||
|
||||
// is stop token
|
||||
if ( t == resp_embd_start[1] - 1 ) { // <|AR|0:STOP|>
|
||||
break;
|
||||
// fill with mask tokens
|
||||
input.resp.resize(1);
|
||||
for ( auto i = 0; i < len; ++i ) {
|
||||
input.resp[0].emplace_back( embeddings_map.resp_embd_start[3] - 1 ); // fill with masked tokens
|
||||
}
|
||||
|
||||
char buf[256];
|
||||
llama_token_to_piece( model, t, buf, sizeof(buf), 0, true );
|
||||
printf("%s\n", buf );
|
||||
|
||||
batch_add( batch, 0, n_embd, resps_embds[0], resps_tokens.size(), true );
|
||||
resps_tokens.emplace_back(t);
|
||||
// inference NAR-len 0
|
||||
input.task = "tts";
|
||||
for ( auto l = 0; l < 8; ++l ) {
|
||||
input.rvq_l = l;
|
||||
output_tokens = generate( ctx, model, smpl_nar, input, embeddings_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
|
||||
input.resp.emplace_back( output_tokens );
|
||||
}
|
||||
// AR+NAR
|
||||
} else if ( modality == MODALITY_AR_NAR ){
|
||||
input.task = "tts";
|
||||
for ( auto l = 0; l < 8; ++l ) {
|
||||
input.rvq_l = l;
|
||||
output_tokens = generate( ctx, model, l == 0 ? smpl_ar : smpl_nar, input, embeddings_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
|
||||
input.resp.emplace_back( output_tokens );
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
// write audio to disk
|
||||
auto waveform = decode_audio( ectx, input.resp );
|
||||
write_wav_on_disk( waveform, output_response_path );
|
||||
|
||||
fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
// cleanup
|
||||
encodec_free(ectx);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
llama_perf_sampler_print(smpl);
|
||||
llama_perf_context_print(ctx);
|
||||
fprintf(stderr, "\n");
|
||||
llama_sampler_free(smpl_nar);
|
||||
llama_sampler_free(smpl_ar);
|
||||
|
||||
// encodec_free(ectx);
|
||||
llama_sampler_free(smpl);
|
||||
llama_free(ctx);
|
||||
|
||||
llama_free_model(model);
|
||||
|
||||
return 0;
|
||||
|
|
Loading…
Reference in New Issue
Block a user