more updates to vall_e.cpp

This commit is contained in:
mrq 2024-12-21 19:16:44 -06:00
parent 503124d0d3
commit 1b4a69ce29
2 changed files with 431 additions and 289 deletions

View File

@ -24,11 +24,13 @@ Run `make`.
* [x] juggle the output head / classifier properly
* [ ] phonemize text
* [ ] tokenize phonemes
* [ ] load audio from disk
* [ ] encode audio
* [ ] sum embeddings for the `prom` and prior `resp`s
* [x] load audio from disk
* [x] encode audio
* [x] sum embeddings for the `prom` and prior `resp`s
* [x] `AR` sampling
* [ ] `NAR-len` demasking sampling
* [ ] `NAR` sampling
* [ ] decode audio to disk
* [ ] a functional CLI
* [ ] actually make it work
* it seems naively stitching the model together isn't good enough since the output is wrong

View File

@ -13,151 +13,140 @@
#include <array>
#include <iostream>
/* Begin cringe so I can access the model's tok_embd */
// it needs to be copied so the struct layout is exactly as it is under llama.cpp
#define LLAMA_MAX_LAYERS 512
#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
#include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
enum e_model {
MODEL_UNKNOWN,
// stores the raw inputs to be fed
struct input_t {
std::string task = "tts";
std::vector<llama_token> phonemes = {};
llama_token lang = 0;
llama_token rvq_l = 0;
std::vector<std::vector<llama_token>> prom = {};
std::vector<std::vector<llama_token>> resp = {};
};
// handles all the cringe logic of slicing embeddings
struct embeddings_t {
int n_embd = 0;
int n_vocab = 0;
float* embds = NULL;
int text_embd_start = 0; // <unk>
int rvq_level_embd_start = 17666; // <|RVQ:0>
int len_embd_start = 17674; // <|len:0|>
int lang_embd_start = 17686; // <|lang:en|>
int task_embd_start = 17692; // <|task:tts|>
int sep_embd_start = 17685; // <|sep|>
int prom_embd_start[8] = {
256 + (1024 * 0), // <|P|0:0|>
256 + (1024 * 1), // <|P|1:0|>
256 + (1024 * 2), // <|P|2:0|>
256 + (1024 * 3), // <|P|3:0|>
256 + (1024 * 4), // <|P|4:0|>
256 + (1024 * 5), // <|P|5:0|>
256 + (1024 * 6), // <|P|6:0|>
256 + (1024 * 7), // <|P|7:0|>
};
int resp_embd_start[9] = {
8448, // <|AR|0:0|>
9473, // <|NAR|0:0|>
10498 + (1024 * 0), // <|NAR|0:1|>
10498 + (1024 * 1), // <|NAR|1:2|>
10498 + (1024 * 2), // <|NAR|2:3|>
10498 + (1024 * 3), // <|NAR|3:4|>
10498 + (1024 * 4), // <|NAR|4:5|>
10498 + (1024 * 5), // <|NAR|5:6|>
10498 + (1024 * 6), // <|NAR|6:7|>
};
float* text_embds = NULL; // &embds[text_embd_start * n_embd];
float* rvq_level_embd = NULL; // &embds[rvq_level_embd_start * n_embd];
float* len_embd = NULL; // &embds[len_embd_start * n_embd];
float* lang_embd = NULL; // &embds[lang_embd_start * n_embd];
float* task_embd = NULL; // &embds[task_embd_start * n_embd];
float* sep_embd = NULL; // &embds[sep_embd_start * n_embd];
float* prom_embds[8] = {
NULL, // &embds[prom_embd_start[0] * n_embd],
NULL, // &embds[prom_embd_start[1] * n_embd],
NULL, // &embds[prom_embd_start[2] * n_embd],
NULL, // &embds[prom_embd_start[3] * n_embd],
NULL, // &embds[prom_embd_start[4] * n_embd],
NULL, // &embds[prom_embd_start[5] * n_embd],
NULL, // &embds[prom_embd_start[6] * n_embd],
NULL, // &embds[prom_embd_start[7] * n_embd],
};
float* resps_embds[9] = {
NULL, // &embds[resp_embd_start[0] * n_embd],
NULL, // &embds[resp_embd_start[1] * n_embd],
NULL, // &embds[resp_embd_start[2] * n_embd],
NULL, // &embds[resp_embd_start[3] * n_embd],
NULL, // &embds[resp_embd_start[4] * n_embd],
NULL, // &embds[resp_embd_start[5] * n_embd],
NULL, // &embds[resp_embd_start[6] * n_embd],
NULL, // &embds[resp_embd_start[7] * n_embd],
NULL, // &embds[resp_embd_start[8] * n_embd],
};
embeddings_t( int n_embd = 0, int n_vocab = 0, float* embds = NULL ) {
init( n_embd, n_vocab, embds );
}
void init( int n_embd, int n_vocab, float* embds = NULL ) {
if ( !n_embd || !n_vocab || !embds ) return;
this->n_embd = n_embd;
this->n_vocab = n_vocab;
this->embds = embds;
text_embds = &embds[text_embd_start * n_embd];
rvq_level_embd = &embds[rvq_level_embd_start * n_embd];
len_embd = &embds[len_embd_start * n_embd];
lang_embd = &embds[lang_embd_start * n_embd];
task_embd = &embds[task_embd_start * n_embd];
sep_embd = &embds[sep_embd_start * n_embd];
for ( auto i = 0; i < 8; ++i ) prom_embds[i] = &embds[prom_embd_start[i] * n_embd];
for ( auto i = 0; i < 9; ++i ) resps_embds[i] = &embds[resp_embd_start[i] * n_embd];
}
};
enum llm_arch {
LLM_ARCH_UNKNOWN,
};
struct llama_hparams_posnet {
uint32_t n_embd;
uint32_t n_layer;
};
struct llama_hparams_convnext {
uint32_t n_embd;
uint32_t n_layer;
};
struct llama_hparams {
bool vocab_only;
bool rope_finetuned;
bool use_par_res;
bool swin_norm;
uint32_t n_vocab = 0;
uint32_t n_ctx_train; // context size the model was trained on
uint32_t n_embd;
uint32_t n_embd_features = 0;
uint32_t n_layer;
uint32_t n_rot;
uint32_t n_swa = 0; // sliding window attention (SWA)
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0;
uint32_t n_expert_used = 0;
uint32_t n_vocab_type = 0; // for BERT-style token types
uint32_t n_rel_attn_bkts = 0;
// for WavTokenizer
struct llama_hparams_posnet posnet;
struct llama_hparams_convnext convnext;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
uint32_t n_layer_dense_lead = 0;
uint32_t n_lora_q = 0;
uint32_t n_lora_kv = 0;
uint32_t n_ff_exp = 0;
uint32_t n_ff_shexp = 0;
uint32_t n_expert_shared = 0;
float expert_weights_scale = 0.0;
float f_norm_eps;
float f_norm_rms_eps;
float f_norm_group_eps;
uint32_t n_norm_groups;
float f_attn_logit_softcapping = 50.0f;
float f_final_logit_softcapping = 30.0f;
// for RWKV
uint32_t rescale_every_n_layers = 0;
uint32_t time_mix_extra_dim = 0;
uint32_t time_decay_extra_dim = 0;
uint32_t wkv_head_size = 0;
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_scale_train;
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul;
int rope_sections[4];
// for State Space Models
uint32_t ssm_d_conv = 0;
uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
bool ssm_dt_b_c_rms = false;
float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f;
// Additional scale factors (Granite/Granite MoE)
float f_residual_scale = 0.0f;
float f_embedding_scale = 0.0f;
float f_attention_scale = 0.0f;
bool causal_attn = true;
bool use_alibi = false;
bool attn_soft_cap = false;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
};
struct llama_model {
e_model type = MODEL_UNKNOWN;
llm_arch arch = LLM_ARCH_UNKNOWN;
llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
std::string name = "n/a";
llama_hparams hparams = {};
llama_vocab vocab;
struct ggml_tensor * tok_embd = nullptr;
};
/* End cringe code */
// maps embeddings easily
std::vector<std::vector<float>> map_embeddings( const std::vector<llama_token>& tokens, int n_embd, float* embds ) {
std::vector<std::vector<float>> embedded( tokens.size() );
for ( auto i = 0; i < tokens.size(); ++i ) {
embedded[i].insert( embedded[i].end(), embds + (tokens[i] * n_embd), embds + ((tokens[i]+1) * n_embd) );
}
return embedded;
}
// handles adding either a token OR the embedding of that token into the batch
// this really, really helps avoid needing to abuse the tokenizer
// to-do: handle summing
void batch_add( struct llama_batch& batch, llama_token id, int n_embd, float* embds, llama_pos pos, bool logits = true, const std::vector<llama_seq_id> & seq_ids = {0} ) {
void batch_add( llama_batch& batch, llama_token id, int n_embd, float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} ) {
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
// insert raw embedding instead
if ( embds ) {
for ( auto i = 0; i < n_embd; ++i ) batch.embd[batch.n_tokens + i] = embds[id * n_embd + i];
// signals to not map the embedding from the array
if ( id < 0 ) for ( auto i = 0; i < n_embd; ++i ) batch.embd[batch.n_tokens + i] = embds[i];
else for ( auto i = 0; i < n_embd; ++i ) batch.embd[batch.n_tokens + i] = embds[id * n_embd + i];
// insert token (never gets used here)
} else {
batch.token[batch.n_tokens] = id;
}
batch.pos[batch.n_tokens] = pos;
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits[batch.n_tokens] = logits;
for (size_t i = 0; i < seq_ids.size(); ++i) batch.seq_id[batch.n_tokens][i] = seq_ids[i];
batch.logits[batch.n_tokens] = output ? 1 : 0;
// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens, id, batch.pos[batch.n_tokens], &batch.embd[batch.n_tokens], batch.logits[batch.n_tokens] );
// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens-1, id, pos, embds, output );
batch.n_tokens++;
}
// reads a waveform from disk
bool read_wav_from_disk(std::string in_path, std::vector<float> & audio_arr) {
uint32_t channels;
uint32_t sample_rate;
@ -171,6 +160,11 @@ bool read_wav_from_disk(std::string in_path, std::vector<float> & audio_arr) {
return false;
}
if (sample_rate != 24000) {
fprintf(stderr, "%s: wav file is wrong sample rate\n", __func__);
return false;
}
fprintf(stderr, "\n%s: Number of frames read = %lld.\n", __func__, total_frame_count);
audio_arr.resize(total_frame_count);
@ -180,7 +174,7 @@ bool read_wav_from_disk(std::string in_path, std::vector<float> & audio_arr) {
return true;
}
// writes a waveform to disk
void write_wav_on_disk(std::vector<float> & audio_arr, std::string dest_path) {
drwav_data_format format;
format.bitsPerSample = 32;
@ -196,8 +190,8 @@ void write_wav_on_disk(std::vector<float> & audio_arr, std::string dest_path) {
fprintf(stderr, "%s: Number of frames written = %lld.\n", __func__, frames);
}
std::vector<int32_t> encode_audio( struct encodec_context* ectx, const std::string& path ) {
// reads a waveform from disk then encodes it
std::vector<std::vector<int32_t>> encode_audio_from_disk( struct encodec_context* ectx, const std::string& path ) {
// read audio from disk
std::vector<float> wavform;
@ -214,10 +208,34 @@ std::vector<int32_t> encode_audio( struct encodec_context* ectx, const std::stri
int32_t* codes_data = encodec_get_codes( ectx );
int n_codes = encodec_get_codes_size( ectx );
int n_codebooks = 8;
int n_frames = n_codes / n_codebooks;
return std::vector<int32_t>(codes_data, codes_data + n_codes);
std::vector<int32_t> flattened_codes(codes_data, codes_data + n_codes);
std::vector<std::vector<int32_t>> codes_2ds(8);
for ( auto l = 0; l < n_codebooks; ++l ) {
codes_2ds[l].resize( n_frames );
for ( auto i = 0; i < n_frames; ++i ) {
codes_2ds[l][i] = flattened_codes[i + l * n_codebooks];
}
}
return codes_2ds;
}
std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector<int32_t>& codes ) {
// decodes a 2D codebook into a waveform
std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector<std::vector<int32_t>>& codes_2d ) {
int n_codebooks = codes_2d.size();
int n_frames = codes_2d[0].size();
std::vector<int32_t> codes( n_frames * n_codebooks );
for ( auto l = 0; l < n_codebooks; ++l ) {
for ( auto i = 0; i < n_frames; ++i ) {
codes[i + l * n_codebooks] = codes_2d[l][i];
}
}
// decompress audio
if (!encodec_decompress_audio(ectx, codes.data(), codes.size(), 1)) {
printf("%s: error during decompression\n", __func__);
@ -229,22 +247,224 @@ std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector
const int audio_size = encodec_get_audio_size(ectx);
return std::vector<float>(audio_data, audio_data + audio_size);
}
const int EMBEDDING_MODE_PROM = 0;
const int EMBEDDING_MODE_RESP_AR_NAR = 0;
const int EMBEDDING_MODE_RESP_NAR_LEN = 0;
const int INFERENCE_MODE_LEN = 0;
const int INFERENCE_MODE_AR = 1;
const int INFERENCE_MODE_NAR_DEMASK = 2;
const int INFERENCE_MODE_NAR = 4;
const int MODALITY_AR_NAR = 0;
const int MODALITY_NAR_LEN = 0;
const int MAX_DURATION = 75 * 12;
// sums embeddings over a 2D "tensor"
std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, float** embds, int mode = EMBEDDING_MODE_PROM ) {
std::vector<std::vector<float>> res( input.size() );
res.resize( input[0].size() );
for ( auto& e : res ) e.resize( n_embd );
// iterate through rvq levels (only up to inclusive the target rvq level)
for ( auto l = 0; l < input.size() && l <= rvq_l; ++l ) {
int offset = 0;
// handles the cringe logic I have
if ( mode == EMBEDDING_MODE_RESP_AR_NAR ) {
offset = input.size() == 1 ? 0 : 2;
} else if ( mode == EMBEDDING_MODE_RESP_NAR_LEN ) {
offset = input.size() == 1 ? 1 : 2;
}
// get tokens
auto& tokens = input[l];
// get output buffer
auto& summed = res[l];
// embed the current level's tokens
auto embedded = map_embeddings( input[l], n_embd, embds[l + offset] );
// iterate through embedded tokens
for ( auto i = 0; i < tokens.size(); ++i ) {
// sum with buffer
for ( auto j = 0; j < n_embd; ++j ) summed[j] += embedded[i][j];
}
}
return res;
}
void fill_batch( llama_batch& batch, input_t& input, embeddings_t& embeddings_map, int mode ) {
// keeps track of the position for each sequence
size_t pos = 0;
auto n_embd = embeddings_map.n_embd;
// insert text tokens
for ( auto& id : input.phonemes ) batch_add( batch, id, n_embd, embeddings_map.text_embds, pos++, false );
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
pos = 0;
// insert lang token
batch_add( batch, input.lang, n_embd, embeddings_map.lang_embd, pos++, false );
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
pos = 0;
// insert rvq level token
batch_add( batch, input.rvq_l, n_embd, embeddings_map.rvq_level_embd, pos++, false );
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
pos = 0;
// insert prom tokens
auto summed_proms_embds = sum_embeddings( input.prom, n_embd, input.rvq_l, embeddings_map.prom_embds );
for ( auto i = 0; i < summed_proms_embds.size(); ++i ) {
batch_add( batch, -1, n_embd, &summed_proms_embds[i][0], pos++, false );
}
batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, mode == INFERENCE_MODE_AR ); // set as the last logit if AR
pos = 0;
// input starting len token
if ( input.task == "len" ) {
batch_add( batch, 0, n_embd, embeddings_map.len_embd, pos++, true );
pos = 0;
}
// insert resp tokens
if ( !input.resp.empty() ) {
auto summed_resps_embds = sum_embeddings( input.resp, n_embd, input.rvq_l, embeddings_map.resps_embds, mode == INFERENCE_MODE_AR ? EMBEDDING_MODE_RESP_AR_NAR : EMBEDDING_MODE_RESP_NAR_LEN );
for ( auto i = 0; i < summed_resps_embds.size(); ++i ) {
batch_add( batch, -1, n_embd, &summed_resps_embds[i][0], pos++, true );
}
pos = 0;
}
}
// generation code, should handle all modalities easily
std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, embeddings_t& embeddings_map, int max_tokens, int mode, bool verbose = true ) {
llama_batch batch = llama_batch_init( 22500, embeddings_map.n_embd, 22500 );
// Decoding loop
const auto t_main_start = ggml_time_us();
int n_decode = 0;
int rvq_l = input.rvq_l;
llama_token stop_token = -1;
fill_batch( batch, input, embeddings_map, mode );
// determine how many logits we need
int n_logits = 0;
for ( auto i = 0; i < batch.n_tokens; ++i ) {
if ( batch.logits[i] ) ++n_logits;
}
if ( verbose ) printf("Prompt size: %i | Logits: %i\n", batch.n_tokens, n_logits);
// NAR mode, cap at one step
if ( n_logits > 1 ) {
max_tokens = n_logits;
}
if ( n_logits == 0 ) {
fprintf(stderr, "%s : no tokens to decode\n", __func__);
return {};
}
float* embds = NULL;
int logit_range[2];
if ( mode == INFERENCE_MODE_AR ) {
logit_range[0] = embeddings_map.resp_embd_start[0];
logit_range[1] = embeddings_map.resp_embd_start[1];
embds = embeddings_map.resps_embds[0];
stop_token = embeddings_map.resp_embd_start[1] - 1; // <|AR|0:STOP|>
} else if ( mode == INFERENCE_MODE_NAR_DEMASK ) {
logit_range[0] = embeddings_map.resp_embd_start[1];
logit_range[1] = embeddings_map.resp_embd_start[2];
embds = embeddings_map.resps_embds[1];
stop_token = embeddings_map.resp_embd_start[2] - 1; // <|NAR|0:STOP|>
} else if ( mode == INFERENCE_MODE_NAR ) {
logit_range[0] = embeddings_map.resp_embd_start[2+rvq_l];
logit_range[1] = embeddings_map.resp_embd_start[3+rvq_l];
embds = embeddings_map.resps_embds[2];
} else if ( mode == INFERENCE_MODE_LEN ) {
logit_range[0] = embeddings_map.len_embd_start;
logit_range[1] = embeddings_map.len_embd_start + 11;
embds = embeddings_map.len_embd;
stop_token = embeddings_map.len_embd_start + 10;
}
llama_set_causal_attn( ctx, n_logits == 1 );
std::vector<llama_token> output_tokens;
while ( output_tokens.size() < max_tokens ) {
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return output_tokens;
}
n_decode += 1;
// backwards iterate to start from beginning of sequence
for ( auto i = n_logits; i > 0; --i ) {
// filter logits
auto* logits = llama_get_logits_ith( ctx, -i );
for ( auto i = 0; i < embeddings_map.n_vocab; ++i ) {
// out of target logit range, set to never happen
if ( i < logit_range[0] || i >= logit_range[1] ) logits[i] = -INFINITY;
}
// sample the next token
auto t = llama_sampler_sample(smpl, ctx, -i);
if ( verbose ) {
// print token for debugging
char buf[256];
llama_token_to_piece( model, t, buf, sizeof(buf), 0, true );
printf("%s\n", buf );
}
// is stop token
if ( t == stop_token ) {
max_tokens = 0;
break;
}
// offset into range
t -= logit_range[0];
output_tokens.emplace_back(t);
batch_add( batch, t, embeddings_map.n_embd, embds, output_tokens.size(), true );
}
}
const auto t_main_end = ggml_time_us();
if ( verbose ) {
printf("\n");
fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
fprintf(stderr, "\n");
llama_perf_sampler_print(smpl);
llama_perf_context_print(ctx);
fprintf(stderr, "\n");
}
llama_batch_free(batch);
return output_tokens;
}
int main(int argc, char ** argv) {
bool is_ar = true;
// to-do: replace all of this with proper loading code
std::vector<llama_token> phoneme_tokens = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2};
llama_token lang_token = 0;
llama_token rvq_level_token = 0;
std::vector<std::vector<llama_token>> prompt_tokens = {
{780,835,835,835,339,395,798,537,537,537,537,222,76,989,548,65,705,375,261,375,297,503,529,571,707,346,464,862,148,496,574,115,115,438,934,339,865,876,63,40,779,461,602,794,10,220,398,869,639,705,869,917,705,893,215,705,869,938,439,175,139,506,375,529,297,705,651,238,962,461,195,441,377,581,473,795,644,626,459,981,767,670,696,73,779,257,408,1017,1019,133,133,1017,835,604,699,626,67,92,707,92,179,179,772,869,441,799,917,238,745,904,904,904,106,133,1019,1017,1017,395,883,87,519,594,1002,682,996,540,186,1019,430,202,347,889,61,92,542,297,67,669,571,707,346,67,359,571,707,669,604,25,1008,810,35,621,67,600,333,123,284,568,817,243,778,464,638,610,359,538,464,975,321,700,377,484,179,284,284,621,538,464,745,171,171,159,744,159,287,461,69,15,529,67,92,669,464,515,605,24,822,865,293,62,172,638,359,562,138,839,846,775,556,688,1006,917,297,312,148,331,496,646,67,314,15,705,131,855,662,287,172,85,538,519,762,450,391,609,643,778,80,287,794,794,115,785,794,461,699,519,932,522,652,262,508,902,932,932,391,769,18,507,90,442,762,610,610,669,605,310,855,56,989,863,195,464,604,257,904,632,786,951,461,239,195,878,771,146,481,146,481,434,643,917,280,67,464,115,744,744,115,115,115,819,709,63,368,359,519,996,616,464,996,616,519,762,917,841,772,568,954,600,422,893,592,464,626,86,143,615,171,744,744,196,115,821,415,521,799,654,839,644,473,592,953,523,855,738,855,876,876,1017,63,329},
};
std::vector<std::vector<llama_token>> response_tokens = {
{922,395,869,869,354,989,762,762,762,610,975,626,626,866,609,442,762,762,762,610,610,610,610,212,869,869,51,336,352,352,352,570,148,893,76,535,568,568,270,568,568,560,597,86,744,744,744,203,738,408,1019,700,707,92,707,464,744,171,171,159,196,192,697,261,261,568,638,605,904,904,779,832,570,519,223,459,459,459,459,90,90,570,700,53,372,621,610,869,473,869,917,654,473,917,893,654,644,384,558,911,864,521,1,19,665},
};
int32_t ngl = 0;
int modality = MODALITY_AR_NAR;
input_t input{};
embeddings_t embeddings_map{};
input.phonemes = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2};
std::string vall_e_model_path = "./data/vall_e-q8_0.gguf";
std::string encodec_model_path = "./data/encodec.bin";
int32_t ngl = 0;
std::string input_prompt_path = "./data/prom.wav";
std::string output_response_path = "./data/resp.wav";
// load dynamic backends
ggml_backend_load_all();
@ -255,7 +475,12 @@ int main(int argc, char ** argv) {
return 1;
}
encodec_set_target_bandwidth(ectx, 24);
encodec_set_target_bandwidth(ectx, 6);
encodec_set_sample_rate(ectx, 24000);
// load wavform
input.prom = encode_audio_from_disk(ectx, input_prompt_path);
//input.resp = encode_audio_from_disk(ectx, output_response_path);
// initialize the models
llama_model_params model_params = llama_model_default_params();
@ -272,9 +497,9 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = 22500;
ctx_params.n_batch = 22500;
ctx_params.no_perf = false;
ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL;
ctx_params.attention_type = is_ar ? LLAMA_ATTENTION_TYPE_CAUSAL : LLAMA_ATTENTION_TYPE_NON_CAUSAL;
// create two contexts, one's that causally, the other that isn't, because pain
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@ -284,14 +509,19 @@ int main(int argc, char ** argv) {
// initialize the sampler
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler * smpl_ar = llama_sampler_chain_init(sparams);
llama_sampler * smpl_nar = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_k(20));
llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_p(0.9, 20));
llama_sampler_chain_add(smpl_ar, llama_sampler_init_temp (1.0));
// llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130));
llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy());
// prepare batch
auto n_embd = llama_n_embd( model );
auto n_vocab = llama_n_vocab( model );
llama_batch batch = llama_batch_init( ctx_params.n_ctx, n_embd, ctx_params.n_ctx );
// grab input embeddings
std::vector<float> embds( n_embd * n_vocab );
@ -300,151 +530,61 @@ int main(int argc, char ** argv) {
if ( ggml_is_quantized(model->tok_embd->type) ) {
qtype->to_float(model->tok_embd->data, embds.data(), embds.size());
}
// update mapping
embeddings_map.init( n_embd, n_vocab, embds.data() );
// to-do: derive these offsets from the tokenizer itself
// to-do: clean this up, probably make it at parity to inputs_to_embeddings
int text_embd_start = 0; // <unk>
int rvq_level_embd_start = 17666; // <|RVQ:0>
int len_embd_start = 17674; // <|len:0|>
int lang_embd_start = 17686; // <|lang:en|>
int task_embd_start = 17692; // <|task:tts|>
int sep_embd_start = 17685; // <|sep|>
int prom_embd_start[] = {
256 + (1024 * 0), // <|P|0:0|>
256 + (1024 * 1), // <|P|1:0|>
256 + (1024 * 2), // <|P|2:0|>
256 + (1024 * 3), // <|P|3:0|>
256 + (1024 * 4), // <|P|4:0|>
256 + (1024 * 5), // <|P|5:0|>
256 + (1024 * 6), // <|P|6:0|>
256 + (1024 * 7), // <|P|7:0|>
};
int resp_embd_start[] = {
8448, // <|AR|0:0|>
9473, // <|NAR|0:0|>
10498 + (1024 * 0), // <|NAR|0:1|>
10498 + (1024 * 1), // <|NAR|1:2|>
10498 + (1024 * 2), // <|NAR|2:3|>
10498 + (1024 * 3), // <|NAR|3:4|>
10498 + (1024 * 4), // <|NAR|4:5|>
10498 + (1024 * 5), // <|NAR|5:6|>
10498 + (1024 * 6), // <|NAR|6:7|>
};
float* text_embds = &embds[text_embd_start * n_embd];
float* rvq_level_embd = &embds[rvq_level_embd_start * n_embd];
float* len_embd = &embds[len_embd_start * n_embd];
float* lang_embd = &embds[lang_embd_start * n_embd];
float* task_embd = &embds[task_embd_start * n_embd];
float* sep_embd = &embds[sep_embd_start * n_embd];
float* prom_embds[] = {
&embds[prom_embd_start[0] * n_embd],
&embds[prom_embd_start[1] * n_embd],
&embds[prom_embd_start[2] * n_embd],
&embds[prom_embd_start[3] * n_embd],
&embds[prom_embd_start[4] * n_embd],
&embds[prom_embd_start[5] * n_embd],
&embds[prom_embd_start[6] * n_embd],
&embds[prom_embd_start[7] * n_embd],
};
float* resps_embds[] = {
&embds[resp_embd_start[0] * n_embd],
&embds[resp_embd_start[1] * n_embd],
&embds[resp_embd_start[2] * n_embd],
&embds[resp_embd_start[3] * n_embd],
&embds[resp_embd_start[4] * n_embd],
&embds[resp_embd_start[5] * n_embd],
&embds[resp_embd_start[6] * n_embd],
&embds[resp_embd_start[7] * n_embd],
&embds[resp_embd_start[8] * n_embd],
};
// insert into batch
{
// keeps track of the position for each sequence
size_t pos = 0;
// insert text tokens
for ( auto& id : phoneme_tokens ) batch_add( batch, id, n_embd, text_embds, pos++, false );
batch_add( batch, 0, n_embd, sep_embd, pos++, false );
pos = 0;
// insert lang token
batch_add( batch, lang_token, n_embd, lang_embd, pos++, false );
batch_add( batch, 0, n_embd, sep_embd, pos++, false );
pos = 0;
// insert rvq level token
batch_add( batch, rvq_level_token, n_embd, rvq_level_embd, pos++, false );
batch_add( batch, 0, n_embd, sep_embd, pos++, false );
pos = 0;
// insert prom tokens
// to-do: handle summing
for ( auto l = 0; l < prompt_tokens.size(); ++l ) {
for ( auto& id : prompt_tokens[l] ) batch_add( batch, id, n_embd, prom_embds[l], pos++, false );
}
batch_add( batch, 0, n_embd, sep_embd, pos++, is_ar );
pos = 0;
// fill in masked tokens
if ( !is_ar ) {
for ( auto i = 0; i < response_tokens[0].size(); ++i ) batch_add( batch, response_tokens[0][i], n_embd, resps_embds[1], pos++, true );
}
pos = 0;
}
// Decoding loop
const auto t_main_start = ggml_time_us();
int n_decode = 0;
// to-do: handle other levels
std::vector<llama_token> resps_tokens;
while ( resps_tokens.size() < 32 ) {
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
n_decode += 1;
// align to AR's classifier
// to-do: derive from tokenizer
int range[] = { resp_embd_start[0], resp_embd_start[1] };
auto* logits = llama_get_logits_ith( ctx, -1 );
for ( auto i = 0; i < n_vocab; ++i ) {
if ( i < range[0] || i >= range[1] ) {
logits[i] = -INFINITY;
// inference
std::vector<llama_token> output_tokens;
// NAR-len demasking
if ( modality == MODALITY_NAR_LEN ) {
// inference len
input.task = "len";
output_tokens = generate( ctx, model, smpl_nar, input, embeddings_map, 5, INFERENCE_MODE_LEN );
int len = 0; {
int digit = 1;
for (int i = output_tokens.size() - 1; i >= 0; i--) {
len += output_tokens[i] * digit;
digit *= 10;
}
}
// cap for now
if ( len > MAX_DURATION ) len = MAX_DURATION;
// sample the next token
auto t = llama_sampler_sample(smpl, ctx, -1);
// is stop token
if ( t == resp_embd_start[1] - 1 ) { // <|AR|0:STOP|>
break;
// fill with mask tokens
input.resp.resize(1);
for ( auto i = 0; i < len; ++i ) {
input.resp[0].emplace_back( embeddings_map.resp_embd_start[3] - 1 ); // fill with masked tokens
}
char buf[256];
llama_token_to_piece( model, t, buf, sizeof(buf), 0, true );
printf("%s\n", buf );
batch_add( batch, 0, n_embd, resps_embds[0], resps_tokens.size(), true );
resps_tokens.emplace_back(t);
// inference NAR-len 0
input.task = "tts";
for ( auto l = 0; l < 8; ++l ) {
input.rvq_l = l;
output_tokens = generate( ctx, model, smpl_nar, input, embeddings_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
input.resp.emplace_back( output_tokens );
}
// AR+NAR
} else if ( modality == MODALITY_AR_NAR ){
input.task = "tts";
for ( auto l = 0; l < 8; ++l ) {
input.rvq_l = l;
output_tokens = generate( ctx, model, l == 0 ? smpl_ar : smpl_nar, input, embeddings_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
input.resp.emplace_back( output_tokens );
}
}
printf("\n");
const auto t_main_end = ggml_time_us();
// write audio to disk
auto waveform = decode_audio( ectx, input.resp );
write_wav_on_disk( waveform, output_response_path );
fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
// cleanup
encodec_free(ectx);
fprintf(stderr, "\n");
llama_perf_sampler_print(smpl);
llama_perf_context_print(ctx);
fprintf(stderr, "\n");
llama_sampler_free(smpl_nar);
llama_sampler_free(smpl_ar);
// encodec_free(ectx);
llama_sampler_free(smpl);
llama_free(ctx);
llama_free_model(model);
return 0;