diff --git a/vall_e.cpp/README.md b/vall_e.cpp/README.md
index 443fe45..a5e63d4 100644
--- a/vall_e.cpp/README.md
+++ b/vall_e.cpp/README.md
@@ -24,11 +24,13 @@ Run `make`.
 * [x] juggle the output head / classifier properly
 * [ ] phonemize text
 * [ ] tokenize phonemes
-* [ ] load audio from disk
-* [ ] encode audio
-* [ ] sum embeddings for the `prom` and prior `resp`s
+* [x] load audio from disk
+* [x] encode audio
+* [x] sum embeddings for the `prom` and prior `resp`s
 * [x] `AR` sampling
 * [ ] `NAR-len` demasking sampling
 * [ ] `NAR` sampling
 * [ ] decode audio to disk
-* [ ] a functional CLI
\ No newline at end of file
+* [ ] a functional CLI
+* [ ] actually make it work
+  * it seems naively stitching the model together isn't enough, as the output is still wrong
\ No newline at end of file
diff --git a/vall_e.cpp/vall_e.cpp b/vall_e.cpp/vall_e.cpp
index c11651b..96007d2 100644
--- a/vall_e.cpp/vall_e.cpp
+++ b/vall_e.cpp/vall_e.cpp
@@ -13,151 +13,140 @@
 #include
 #include
 
-/* Begin cringe so I can access the model's tok_embd */
-// it needs to be copied so the struct layout is exactly as it is under llama.cpp
-#define LLAMA_MAX_LAYERS 512
-#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
+#include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
 
-enum e_model {
-	MODEL_UNKNOWN,
+// stores the raw inputs to be fed to the model
+struct input_t {
+	std::string task = "tts";
+
+	std::vector<llama_token> phonemes = {};
+	llama_token lang = 0;
+	llama_token rvq_l = 0;
+	std::vector<std::vector<llama_token>> prom = {};
+	std::vector<std::vector<llama_token>> resp = {};
+};
+// handles all the cringe logic of slicing embeddings
+struct embeddings_t {
+	int n_embd = 0;
+	int n_vocab = 0;
+	float* embds = NULL;
+
+	int text_embd_start = 0; //
+	int rvq_level_embd_start = 17666; // <|RVQ:0|>
+	int len_embd_start = 17674; // <|len:0|>
+	int lang_embd_start = 17686; // <|lang:en|>
+	int task_embd_start = 17692; // <|task:tts|>
+	int sep_embd_start = 17685; // <|sep|>
+	int prom_embd_start[8] = {
+		256 + (1024 * 0), // <|P|0:0|>
+		256 + (1024 * 1), // <|P|1:0|>
+		256 + (1024 * 2), // <|P|2:0|>
+		256 + (1024 * 3), // <|P|3:0|>
+		256 + (1024 * 4), // <|P|4:0|>
+		256 + (1024 * 5), // <|P|5:0|>
+		256 + (1024 * 6), // <|P|6:0|>
+		256 + (1024 * 7), // <|P|7:0|>
+	};
+	int resp_embd_start[9] = {
+		8448, // <|AR|0:0|>
+		9473, // <|NAR|0:0|>
+		10498 + (1024 * 0), // <|NAR|0:1|>
+		10498 + (1024 * 1), // <|NAR|1:2|>
+		10498 + (1024 * 2), // <|NAR|2:3|>
+		10498 + (1024 * 3), // <|NAR|3:4|>
+		10498 + (1024 * 4), // <|NAR|4:5|>
+		10498 + (1024 * 5), // <|NAR|5:6|>
+		10498 + (1024 * 6), // <|NAR|6:7|>
+	};
+
+	float* text_embds = NULL; // &embds[text_embd_start * n_embd];
+	float* rvq_level_embd = NULL; // &embds[rvq_level_embd_start * n_embd];
+	float* len_embd = NULL; // &embds[len_embd_start * n_embd];
+	float* lang_embd = NULL; // &embds[lang_embd_start * n_embd];
+	float* task_embd = NULL; // &embds[task_embd_start * n_embd];
+	float* sep_embd = NULL; // &embds[sep_embd_start * n_embd];
+
+	float* prom_embds[8] = {
+		NULL, // &embds[prom_embd_start[0] * n_embd],
+		NULL, // &embds[prom_embd_start[1] * n_embd],
+		NULL, // &embds[prom_embd_start[2] * n_embd],
+		NULL, // &embds[prom_embd_start[3] * n_embd],
+		NULL, // &embds[prom_embd_start[4] * n_embd],
+		NULL, // &embds[prom_embd_start[5] * n_embd],
+		NULL, // &embds[prom_embd_start[6] * n_embd],
+		NULL, // &embds[prom_embd_start[7] * n_embd],
+	};
+	float* resps_embds[9] = {
+		NULL, // &embds[resp_embd_start[0] * n_embd],
+		NULL, // &embds[resp_embd_start[1] * n_embd],
+		NULL, // &embds[resp_embd_start[2] * n_embd],
+		NULL, // &embds[resp_embd_start[3] * n_embd],
+		NULL, // &embds[resp_embd_start[4] * n_embd],
+		NULL, // &embds[resp_embd_start[5] * n_embd],
+		NULL, // &embds[resp_embd_start[6] * n_embd],
+		NULL, // &embds[resp_embd_start[7] * n_embd],
+		NULL, // &embds[resp_embd_start[8] * n_embd],
+	};
+
+	embeddings_t( int n_embd = 0, int n_vocab = 0, float* embds = NULL ) {
+		init( n_embd, n_vocab, embds );
+	}
+
+	void init( int n_embd, int n_vocab, float* embds = NULL ) {
+		if ( !n_embd || !n_vocab || !embds ) return;
+
+		this->n_embd = n_embd;
+		this->n_vocab = n_vocab;
+		this->embds = embds;
+
+		text_embds = &embds[text_embd_start * n_embd];
+		rvq_level_embd = &embds[rvq_level_embd_start * n_embd];
+		len_embd = &embds[len_embd_start * n_embd];
+		lang_embd = &embds[lang_embd_start * n_embd];
+		task_embd = &embds[task_embd_start * n_embd];
+		sep_embd = &embds[sep_embd_start * n_embd];
+
+		for ( auto i = 0; i < 8; ++i ) prom_embds[i] = &embds[prom_embd_start[i] * n_embd];
+		for ( auto i = 0; i < 9; ++i ) resps_embds[i] = &embds[resp_embd_start[i] * n_embd];
+	}
 };
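// [editor's note] a minimal sketch, not part of the diff, of how the offsets above
// are meant to be used: a section-relative token id selects one row of the flattened
// [n_vocab x n_embd] tok_embd table by adding that section's starting offset.
// `embd_row`, `section_start`, and `rel_token` are hypothetical names for illustration.
//
//     static inline const float* embd_row( const float* embds, int n_embd, int section_start, int rel_token ) {
//         return embds + ( section_start + rel_token ) * n_embd; // row-major lookup
//     }
//
// e.g. code 5 of prom codebook 2 lives at embd_row( embds, n_embd, 256 + (1024 * 2), 5 ).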
 
-enum llm_arch {
-	LLM_ARCH_UNKNOWN,
-};
-
-struct llama_hparams_posnet {
-	uint32_t n_embd;
-	uint32_t n_layer;
-};
-
-struct llama_hparams_convnext {
-	uint32_t n_embd;
-	uint32_t n_layer;
-};
-
-struct llama_hparams {
-	bool vocab_only;
-	bool rope_finetuned;
-	bool use_par_res;
-	bool swin_norm;
-
-	uint32_t n_vocab = 0;
-	uint32_t n_ctx_train; // context size the model was trained on
-	uint32_t n_embd;
-	uint32_t n_embd_features = 0;
-	uint32_t n_layer;
-	uint32_t n_rot;
-	uint32_t n_swa = 0; // sliding window attention (SWA)
-	uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
-	uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
-	uint32_t n_expert = 0;
-	uint32_t n_expert_used = 0;
-	uint32_t n_vocab_type = 0; // for BERT-style token types
-	uint32_t n_rel_attn_bkts = 0;
-
-	// for WavTokenizer
-	struct llama_hparams_posnet posnet;
-	struct llama_hparams_convnext convnext;
-
-	std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
-	std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
-	std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
-
-	uint32_t n_layer_dense_lead = 0;
-	uint32_t n_lora_q = 0;
-	uint32_t n_lora_kv = 0;
-	uint32_t n_ff_exp = 0;
-	uint32_t n_ff_shexp = 0;
-	uint32_t n_expert_shared = 0;
-	float expert_weights_scale = 0.0;
-
-	float f_norm_eps;
-	float f_norm_rms_eps;
-	float f_norm_group_eps;
-
-	uint32_t n_norm_groups;
-
-	float f_attn_logit_softcapping = 50.0f;
-	float f_final_logit_softcapping = 30.0f;
-
-	// for RWKV
-	uint32_t rescale_every_n_layers = 0;
-	uint32_t time_mix_extra_dim = 0;
-	uint32_t time_decay_extra_dim = 0;
-	uint32_t wkv_head_size = 0;
-
-	float rope_attn_factor = 1.0f;
-	float rope_freq_base_train;
-	float rope_freq_scale_train;
-	uint32_t n_ctx_orig_yarn;
-	float rope_yarn_log_mul;
-	int rope_sections[4];
-
-	// for State Space Models
-	uint32_t ssm_d_conv = 0;
-	uint32_t ssm_d_inner = 0;
-	uint32_t ssm_d_state = 0;
-	uint32_t ssm_dt_rank = 0;
-	bool ssm_dt_b_c_rms = false;
-
-	float f_clamp_kqv = 0.0f;
-	float f_max_alibi_bias = 0.0f;
-	float f_logit_scale = 0.0f;
-
-	// Additional scale factors (Granite/Granite MoE)
-	float f_residual_scale = 0.0f;
-	float f_embedding_scale = 0.0f;
-	float f_attention_scale = 0.0f;
-
-	bool causal_attn = true;
-	bool use_alibi = false;
-	bool attn_soft_cap = false;
-
-	// needed by encoder-decoder models (e.g. T5, FLAN-T5)
-	// ref: https://github.com/ggerganov/llama.cpp/pull/8141
-	llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
-
-	enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
-	enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
-	enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
-};
-
-struct llama_model {
-	e_model type = MODEL_UNKNOWN;
-	llm_arch arch = LLM_ARCH_UNKNOWN;
-	llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
-
-	std::string name = "n/a";
-
-	llama_hparams hparams = {};
-	llama_vocab vocab;
-
-	struct ggml_tensor * tok_embd = nullptr;
-};
-/* End cringe code */
+// maps token ids to their embedding vectors
+std::vector<std::vector<float>> map_embeddings( const std::vector<llama_token>& tokens, int n_embd, float* embds ) {
+	std::vector<std::vector<float>> embedded( tokens.size() );
+	for ( auto i = 0; i < tokens.size(); ++i ) {
+		embedded[i].insert( embedded[i].end(), embds + (tokens[i] * n_embd), embds + ((tokens[i]+1) * n_embd) );
+	}
+	return embedded;
+}
 
 // handles adding either a token OR the embedding of that token into the batch
 // this really, really helps avoid needing to abuse the tokenizer
-// to-do: handle summing
-void batch_add( struct llama_batch& batch, llama_token id, int n_embd, float* embds, llama_pos pos, bool logits = true, const std::vector<llama_seq_id> & seq_ids = {0} ) {
+void batch_add( llama_batch& batch, llama_token id, int n_embd, float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} ) {
 	GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
 
+	// insert a raw embedding instead of a token
 	if ( embds ) {
-		for ( auto i = 0; i < n_embd; ++i ) batch.embd[batch.n_tokens + i] = embds[id * n_embd + i];
+		// a negative id signals that embds already points at a single raw embedding rather than a table to index
+		if ( id < 0 ) for ( auto i = 0; i < n_embd; ++i ) batch.embd[batch.n_tokens * n_embd + i] = embds[i];
+		else for ( auto i = 0; i < n_embd; ++i ) batch.embd[batch.n_tokens * n_embd + i] = embds[id * n_embd + i];
+	// insert a token (never gets used here)
 	} else {
 		batch.token[batch.n_tokens] = id;
 	}
+
 	batch.pos[batch.n_tokens] = pos;
+
 	batch.n_seq_id[batch.n_tokens] = seq_ids.size();
-	for (size_t i = 0; i < seq_ids.size(); ++i) {
-		batch.seq_id[batch.n_tokens][i] = seq_ids[i];
-	}
-	batch.logits[batch.n_tokens] = logits;
+	for (size_t i = 0; i < seq_ids.size(); ++i) batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+	batch.logits[batch.n_tokens] = output ? 1 : 0;
+
+	// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens, id, batch.pos[batch.n_tokens], &batch.embd[batch.n_tokens * n_embd], batch.logits[batch.n_tokens] );
+	// printf("[%i] Adding: %i | %i | %p | %i\n", batch.n_tokens-1, id, pos, embds, output );
 
 	batch.n_tokens++;
 }
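// [editor's note] a hedged usage sketch, not part of the diff: batch_add() above is fed
// either a section-relative id (mapped through a section's embedding slice) or, with a
// negative id, a raw pre-summed embedding. the batch must come from
// llama_batch_init( n_tokens_max, n_embd, n_seq_max ) so that batch.embd is allocated;
// the ids, positions, and the `summed` vector below are made up for illustration.
//
//     llama_batch batch = llama_batch_init( 32, n_embd, 32 );
//     batch_add( batch, 12, n_embd, embeddings_map.text_embds, 0, false ); // phoneme token 12, no logits
//     batch_add( batch, -1, n_embd, summed.data(), 1, true );              // raw embedding, request logits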
 
-
+// reads a waveform from disk
 bool read_wav_from_disk(std::string in_path, std::vector<float> & audio_arr) {
 	uint32_t channels;
 	uint32_t sample_rate;
@@ -171,6 +160,11 @@ bool read_wav_from_disk(std::string in_path, std::vector<float> & audio_arr) {
 		return false;
 	}
 
+	if (sample_rate != 24000) {
+		fprintf(stderr, "%s: wav file has the wrong sample rate (expected 24000 Hz)\n", __func__);
+		return false;
+	}
+
 	fprintf(stderr, "\n%s: Number of frames read = %lld.\n", __func__, total_frame_count);
 
 	audio_arr.resize(total_frame_count);
@@ -180,7 +174,7 @@ bool read_wav_from_disk(std::string in_path, std::vector<float> & audio_arr) {
 	return true;
 }
 
-
+// writes a waveform to disk
 void write_wav_on_disk(std::vector<float> & audio_arr, std::string dest_path) {
 	drwav_data_format format;
 	format.bitsPerSample = 32;
@@ -196,8 +190,8 @@ void write_wav_on_disk(std::vector<float> & audio_arr, std::string dest_path) {
 	fprintf(stderr, "%s: Number of frames written = %lld.\n", __func__, frames);
 }
 
-
-std::vector<int32_t> encode_audio( struct encodec_context* ectx, const std::string& path ) {
+// reads a waveform from disk then encodes it
+std::vector<std::vector<int32_t>> encode_audio_from_disk( struct encodec_context* ectx, const std::string& path ) {
 	// read audio from disk
 	std::vector<float> wavform;
 
@@ -214,10 +208,34 @@ std::vector<int32_t> encode_audio( struct encodec_context* ectx, const std::string& path ) {
 	int32_t* codes_data = encodec_get_codes( ectx );
 	int n_codes = encodec_get_codes_size( ectx );
+	int n_codebooks = 8;
+	int n_frames = n_codes / n_codebooks;
 
-	return std::vector<int32_t>(codes_data, codes_data + n_codes);
+	std::vector<int32_t> flattened_codes(codes_data, codes_data + n_codes);
+	std::vector<std::vector<int32_t>> codes_2ds(n_codebooks);
+
+	// unflatten: codebook l owns a contiguous run of n_frames codes
+	for ( auto l = 0; l < n_codebooks; ++l ) {
+		codes_2ds[l].resize( n_frames );
+		for ( auto i = 0; i < n_frames; ++i ) {
+			codes_2ds[l][i] = flattened_codes[i + l * n_frames];
+		}
+	}
+
+	return codes_2ds;
 }
 
-std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector<int32_t>& codes ) {
+// decodes a 2D codebook into a waveform
+std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector<std::vector<int32_t>>& codes_2d ) {
+	int n_codebooks = codes_2d.size();
+	int n_frames = codes_2d[0].size();
+
+	std::vector<int32_t> codes( n_frames * n_codebooks );
+
+	// flatten back into the codebook-major layout used above
+	for ( auto l = 0; l < n_codebooks; ++l ) {
+		for ( auto i = 0; i < n_frames; ++i ) {
+			codes[i + l * n_frames] = codes_2d[l][i];
+		}
+	}
+
 	// decompress audio
 	if (!encodec_decompress_audio(ectx, codes.data(), codes.size(), 1)) {
 		printf("%s: error during decompression\n", __func__);
@@ -229,22 +247,224 @@ std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector<int32_t>& codes ) {
 	const int audio_size = encodec_get_audio_size(ectx);
 	return std::vector<float>(audio_data, audio_data + audio_size);
 }
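// [editor's note] a worked example, not part of the diff, of the codebook-major layout
// the two functions above assume: codebook l owns one contiguous run of n_frames codes,
// so flattened[i + l * n_frames] == codes_2d[l][i]. with 2 codebooks and 3 frames,
// {0,1,2, 10,11,12} unflattens to {{0,1,2},{10,11,12}} and flattens back unchanged
// (the stride must be n_frames, not n_codebooks, for the mapping to cover all codes).
// whether encodec.cpp emits codebook-major codes is an assumption here; if its codes
// were frame-major the index would instead be l + i * n_codebooks.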
 
+const int EMBEDDING_MODE_PROM = 0;
+const int EMBEDDING_MODE_RESP_AR_NAR = 1;
+const int EMBEDDING_MODE_RESP_NAR_LEN = 2;
+
+const int INFERENCE_MODE_LEN = 0;
+const int INFERENCE_MODE_AR = 1;
+const int INFERENCE_MODE_NAR_DEMASK = 2;
+const int INFERENCE_MODE_NAR = 3;
+
+const int MODALITY_AR_NAR = 0;
+const int MODALITY_NAR_LEN = 1;
+
+const int MAX_DURATION = 75 * 12; // 12 seconds at 75 EnCodec frames per second
+
+// sums embeddings over a 2D "tensor"
+std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, float** embds, int mode = EMBEDDING_MODE_PROM ) {
+	// one summed embedding per frame, zero-initialized
+	std::vector<std::vector<float>> res( input[0].size() );
+	for ( auto& e : res ) e.resize( n_embd );
+	// iterate through rvq levels (only up to and including the target rvq level)
+	for ( auto l = 0; l < input.size() && l <= rvq_l; ++l ) {
+		int offset = 0;
+		// handles the cringe logic of picking the right resp embedding table
+		if ( mode == EMBEDDING_MODE_RESP_AR_NAR ) {
+			offset = input.size() == 1 ? 0 : 2;
+		} else if ( mode == EMBEDDING_MODE_RESP_NAR_LEN ) {
+			offset = input.size() == 1 ? 1 : 2;
+		}
+		// embed the current level's tokens
+		auto embedded = map_embeddings( input[l], n_embd, embds[l + offset] );
+		// accumulate each frame's embedding into the output buffer
+		for ( auto i = 0; i < input[l].size(); ++i ) {
+			for ( auto j = 0; j < n_embd; ++j ) res[i][j] += embedded[i][j];
+		}
+	}
+	return res;
+}
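// [editor's note] a toy trace, not part of the diff, of sum_embeddings() with n_embd = 2
// and rvq_l = 1: if level 0's two tokens embed to {1,1},{2,2} and level 1's embed to
// {10,10},{20,20}, the result is res[0] = {11,11} and res[1] = {22,22} -- one vector per
// frame, accumulated across levels, mirroring how VALL-E sums its codebook embeddings.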
 
+// fills a batch with the inputs: text, lang, rvq level, prom, then (optionally) len and resp
+void fill_batch( llama_batch& batch, input_t& input, embeddings_t& embeddings_map, int mode ) {
+	// keeps track of the position for each sequence
+	size_t pos = 0;
+	auto n_embd = embeddings_map.n_embd;
+
+	// insert text tokens
+	for ( auto& id : input.phonemes ) batch_add( batch, id, n_embd, embeddings_map.text_embds, pos++, false );
+	batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
+	pos = 0;
+	// insert lang token
+	batch_add( batch, input.lang, n_embd, embeddings_map.lang_embd, pos++, false );
+	batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
+	pos = 0;
+	// insert rvq level token
+	batch_add( batch, input.rvq_l, n_embd, embeddings_map.rvq_level_embd, pos++, false );
+	batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, false );
+	pos = 0;
+	// insert prom tokens
+	auto summed_proms_embds = sum_embeddings( input.prom, n_embd, input.rvq_l, embeddings_map.prom_embds );
+	for ( auto i = 0; i < summed_proms_embds.size(); ++i ) {
+		batch_add( batch, -1, n_embd, &summed_proms_embds[i][0], pos++, false );
+	}
+	batch_add( batch, 0, n_embd, embeddings_map.sep_embd, pos++, mode == INFERENCE_MODE_AR ); // set as the last logit if AR
+	pos = 0;
+
+	// insert starting len token
+	if ( input.task == "len" ) {
+		batch_add( batch, 0, n_embd, embeddings_map.len_embd, pos++, true );
+		pos = 0;
+	}
+
+	// insert resp tokens
+	if ( !input.resp.empty() ) {
+		auto summed_resps_embds = sum_embeddings( input.resp, n_embd, input.rvq_l, embeddings_map.resps_embds, mode == INFERENCE_MODE_AR ? EMBEDDING_MODE_RESP_AR_NAR : EMBEDDING_MODE_RESP_NAR_LEN );
+		for ( auto i = 0; i < summed_resps_embds.size(); ++i ) {
+			batch_add( batch, -1, n_embd, &summed_resps_embds[i][0], pos++, true );
+		}
+		pos = 0;
+	}
+}
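// [editor's note] for reference, not part of the diff: the prompt fill_batch() assembles,
// with positions restarting at 0 for every segment (matching the pos = 0 resets above):
//
//     <phonemes...> <|sep|>  <|lang|> <|sep|>  <|RVQ level|> <|sep|>  <summed prom embds...> <|sep|>  [<|len|>] [<summed resp embds...>]
//
// only the trailing <|sep|> (in AR mode), the <|len|> token, or the resp embeddings request logits.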
 
+// generation code, should handle all modalities easily
+std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, embeddings_t& embeddings_map, int max_tokens, int mode, bool verbose = true ) {
+	llama_batch batch = llama_batch_init( 22500, embeddings_map.n_embd, 22500 );
+
+	// decoding loop
+	const auto t_main_start = ggml_time_us();
+	int n_decode = 0;
+	int rvq_l = input.rvq_l;
+	llama_token stop_token = -1;
+
+	fill_batch( batch, input, embeddings_map, mode );
+
+	// determine how many logits we need
+	int n_logits = 0;
+	for ( auto i = 0; i < batch.n_tokens; ++i ) {
+		if ( batch.logits[i] ) ++n_logits;
+	}
+
+	if ( verbose ) printf("Prompt size: %i | Logits: %i\n", batch.n_tokens, n_logits);
+
+	// NAR mode: all logits are produced in a single pass, so cap the loop at one step
+	if ( n_logits > 1 ) {
+		max_tokens = n_logits;
+	}
+
+	if ( n_logits == 0 ) {
+		fprintf(stderr, "%s : no tokens to decode\n", __func__);
+		return {};
+	}
+
+	// select the classifier window and embedding table for the requested mode
+	float* embds = NULL;
+	int logit_range[2];
+	if ( mode == INFERENCE_MODE_AR ) {
+		logit_range[0] = embeddings_map.resp_embd_start[0];
+		logit_range[1] = embeddings_map.resp_embd_start[1];
+
+		embds = embeddings_map.resps_embds[0];
+
+		stop_token = embeddings_map.resp_embd_start[1] - 1; // <|AR|0:STOP|>
+	} else if ( mode == INFERENCE_MODE_NAR_DEMASK ) {
+		logit_range[0] = embeddings_map.resp_embd_start[1];
+		logit_range[1] = embeddings_map.resp_embd_start[2];
+
+		embds = embeddings_map.resps_embds[1];
+
+		stop_token = embeddings_map.resp_embd_start[2] - 1; // <|NAR|0:STOP|>
+	} else if ( mode == INFERENCE_MODE_NAR ) {
+		logit_range[0] = embeddings_map.resp_embd_start[2+rvq_l];
+		logit_range[1] = embeddings_map.resp_embd_start[3+rvq_l];
+
+		embds = embeddings_map.resps_embds[2];
+	} else if ( mode == INFERENCE_MODE_LEN ) {
+		logit_range[0] = embeddings_map.len_embd_start;
+		logit_range[1] = embeddings_map.len_embd_start + 11;
+
+		embds = embeddings_map.len_embd;
+
+		stop_token = embeddings_map.len_embd_start + 10; // len stop token
+	}
+
+	// causal for single-logit (AR / len) passes, non-causal for parallel (NAR) passes
+	llama_set_causal_attn( ctx, n_logits == 1 );
+
+	std::vector<llama_token> output_tokens;
+	while ( output_tokens.size() < max_tokens ) {
+		if (llama_decode(ctx, batch)) {
+			fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+			return output_tokens;
+		}
+		n_decode += 1;
+
+		// iterate backwards so the logits are visited in sequence order
+		for ( auto i = n_logits; i > 0; --i ) {
+			auto* logits = llama_get_logits_ith( ctx, -i );
+			// mask out logits that fall outside the target classifier window
+			for ( auto v = 0; v < embeddings_map.n_vocab; ++v ) {
+				if ( v < logit_range[0] || v >= logit_range[1] ) logits[v] = -INFINITY;
+			}
+
+			// sample the next token
+			auto t = llama_sampler_sample(smpl, ctx, -i);
+
+			if ( verbose ) {
+				// print token for debugging
+				char buf[256];
+				llama_token_to_piece( model, t, buf, sizeof(buf), 0, true );
+				printf("%s\n", buf );
+			}
+
+			// is stop token
+			if ( t == stop_token ) {
+				max_tokens = 0;
+				break;
+			}
+
+			// rebase into the window so the output is a section-relative id
+			t -= logit_range[0];
+
+			output_tokens.emplace_back(t);
+			batch_add( batch, t, embeddings_map.n_embd, embds, output_tokens.size(), true );
+		}
+	}
+	const auto t_main_end = ggml_time_us();
+
+	if ( verbose ) {
+		printf("\n");
+		fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+				__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+
+		fprintf(stderr, "\n");
+		llama_perf_sampler_print(smpl);
+		llama_perf_context_print(ctx);
+		fprintf(stderr, "\n");
+	}
+
+	llama_batch_free(batch);
+
+	return output_tokens;
+}
+
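// [editor's note] a minimal sketch, not part of the diff, of the classifier windowing
// generate() performs: everything outside the target section's [lo, hi) slice of the
// vocab is masked before sampling, and the winner is rebased to a section-relative id.
// `lo` and `hi` are hypothetical names for logit_range[0] and logit_range[1].
//
//     for ( auto v = 0; v < n_vocab; ++v )
//         if ( v < lo || v >= hi ) logits[v] = -INFINITY;   // mask out other sections
//     llama_token t = llama_sampler_sample( smpl, ctx, -i ); // now picks from [lo, hi)
//     t -= lo;                                               // 0..1023 for a codebook level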
 int main(int argc, char ** argv) {
-	bool is_ar = true;
 	// to-do: replace all of this with proper loading code
-	std::vector<llama_token> phoneme_tokens = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2};
-	llama_token lang_token = 0;
-	llama_token rvq_level_token = 0;
-	std::vector<std::vector<llama_token>> prompt_tokens = {
-		{780,835,835,835,339,395,798,537,537,537,537,222,76,989,548,65,705,375,261,375,297,503,529,571,707,346,464,862,148,496,574,115,115,438,934,339,865,876,63,40,779,461,602,794,10,220,398,869,639,705,869,917,705,893,215,705,869,938,439,175,139,506,375,529,297,705,651,238,962,461,195,441,377,581,473,795,644,626,459,981,767,670,696,73,779,257,408,1017,1019,133,133,1017,835,604,699,626,67,92,707,92,179,179,772,869,441,799,917,238,745,904,904,904,106,133,1019,1017,1017,395,883,87,519,594,1002,682,996,540,186,1019,430,202,347,889,61,92,542,297,67,669,571,707,346,67,359,571,707,669,604,25,1008,810,35,621,67,600,333,123,284,568,817,243,778,464,638,610,359,538,464,975,321,700,377,484,179,284,284,621,538,464,745,171,171,159,744,159,287,461,69,15,529,67,92,669,464,515,605,24,822,865,293,62,172,638,359,562,138,839,846,775,556,688,1006,917,297,312,148,331,496,646,67,314,15,705,131,855,662,287,172,85,538,519,762,450,391,609,643,778,80,287,794,794,115,785,794,461,699,519,932,522,652,262,508,902,932,932,391,769,18,507,90,442,762,610,610,669,605,310,855,56,989,863,195,464,604,257,904,632,786,951,461,239,195,878,771,146,481,146,481,434,643,917,280,67,464,115,744,744,115,115,115,819,709,63,368,359,519,996,616,464,996,616,519,762,917,841,772,568,954,600,422,893,592,464,626,86,143,615,171,744,744,196,115,821,415,521,799,654,839,644,473,592,953,523,855,738,855,876,876,1017,63,329},
-	};
-	std::vector<std::vector<llama_token>> response_tokens = {
-		{922,395,869,869,354,989,762,762,762,610,975,626,626,866,609,442,762,762,762,610,610,610,610,212,869,869,51,336,352,352,352,570,148,893,76,535,568,568,270,568,568,560,597,86,744,744,744,203,738,408,1019,700,707,92,707,464,744,171,171,159,196,192,697,261,261,568,638,605,904,904,779,832,570,519,223,459,459,459,459,90,90,570,700,53,372,621,610,869,473,869,917,654,473,917,893,654,644,384,558,911,864,521,1,19,665},
-	};
+	int32_t ngl = 0;
+	int modality = MODALITY_AR_NAR;
+	input_t input{};
+	embeddings_t embeddings_map{};
+
+	input.phonemes = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2};
+
 	std::string vall_e_model_path = "./data/vall_e-q8_0.gguf";
 	std::string encodec_model_path = "./data/encodec.bin";
-	int32_t ngl = 0;
-
+	std::string input_prompt_path = "./data/prom.wav";
+	std::string output_response_path = "./data/resp.wav";
 
 	// load dynamic backends
 	ggml_backend_load_all();
 
@@ -255,7 +475,12 @@ int main(int argc, char ** argv) {
 		return 1;
 	}
 
-	encodec_set_target_bandwidth(ectx, 24);
+	encodec_set_target_bandwidth(ectx, 6);
+	encodec_set_sample_rate(ectx, 24000);
+
+	// load the prompt's waveform and encode it
+	input.prom = encode_audio_from_disk(ectx, input_prompt_path);
+	//input.resp = encode_audio_from_disk(ectx, output_response_path);
 
 	// initialize the models
 	llama_model_params model_params = llama_model_default_params();
@@ -272,9 +497,9 @@ int main(int argc, char ** argv) {
 	ctx_params.n_ctx = 22500;
 	ctx_params.n_batch = 22500;
 	ctx_params.no_perf = false;
+	ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL;
-	ctx_params.attention_type = is_ar ? LLAMA_ATTENTION_TYPE_CAUSAL : LLAMA_ATTENTION_TYPE_NON_CAUSAL;
-
+	// a single context suffices: generate() toggles causal attention at runtime via llama_set_causal_attn
 	llama_context* ctx = llama_new_context_with_model(model, ctx_params);
 	if (ctx == NULL) {
 		fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@@ -284,14 +509,19 @@ int main(int argc, char ** argv) {
 	// initialize the sampler
 	auto sparams = llama_sampler_chain_default_params();
 	sparams.no_perf = false;
-	llama_sampler * smpl = llama_sampler_chain_init(sparams);
+	llama_sampler * smpl_ar = llama_sampler_chain_init(sparams);
+	llama_sampler * smpl_nar = llama_sampler_chain_init(sparams);
 
-	llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+	llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_k(20));
+	llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_p(0.9, 20));
+	llama_sampler_chain_add(smpl_ar, llama_sampler_init_temp (1.0));
+	// llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130));
+
+	llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy());
 
 	// prepare batch
 	auto n_embd = llama_n_embd( model );
 	auto n_vocab = llama_n_vocab( model );
-	llama_batch batch = llama_batch_init( ctx_params.n_ctx, n_embd, ctx_params.n_ctx );
 
 	// grab input embeddings
 	std::vector<float> embds( n_embd * n_vocab );
@@ -300,151 +530,61 @@ int main(int argc, char ** argv) {
 	auto* qtype = ggml_get_type_traits(model->tok_embd->type);
 	if ( ggml_is_quantized(model->tok_embd->type) ) {
 		qtype->to_float(model->tok_embd->data, embds.data(), embds.size());
 	}
+	// update the embeddings mapping
+	embeddings_map.init( n_embd, n_vocab, embds.data() );
 
-	// to-do: derive these offsets from the tokenizer itself
-	// to-do: clean this up, probably make it at parity to inputs_to_embeddings
-	int text_embd_start = 0; //
-	int rvq_level_embd_start = 17666; // <|RVQ:0|>
-	int len_embd_start = 17674; // <|len:0|>
-	int lang_embd_start = 17686; // <|lang:en|>
-	int task_embd_start = 17692; // <|task:tts|>
-	int sep_embd_start = 17685; // <|sep|>
-	int prom_embd_start[] = {
-		256 + (1024 * 0), // <|P|0:0|>
-		256 + (1024 * 1), // <|P|1:0|>
-		256 + (1024 * 2), // <|P|2:0|>
-		256 + (1024 * 3), // <|P|3:0|>
-		256 + (1024 * 4), // <|P|4:0|>
-		256 + (1024 * 5), // <|P|5:0|>
-		256 + (1024 * 6), // <|P|6:0|>
-		256 + (1024 * 7), // <|P|7:0|>
-	};
-	int resp_embd_start[] = {
-		8448, // <|AR|0:0|>
-		9473, // <|NAR|0:0|>
-		10498 + (1024 * 0), // <|NAR|0:1|>
-		10498 + (1024 * 1), // <|NAR|1:2|>
-		10498 + (1024 * 2), // <|NAR|2:3|>
-		10498 + (1024 * 3), // <|NAR|3:4|>
-		10498 + (1024 * 4), // <|NAR|4:5|>
-		10498 + (1024 * 5), // <|NAR|5:6|>
-		10498 + (1024 * 6), // <|NAR|6:7|>
-	};
-
-	float* text_embds = &embds[text_embd_start * n_embd];
-	float* rvq_level_embd = &embds[rvq_level_embd_start * n_embd];
-	float* len_embd = &embds[len_embd_start * n_embd];
-	float* lang_embd = &embds[lang_embd_start * n_embd];
-	float* task_embd = &embds[task_embd_start * n_embd];
-	float* sep_embd = &embds[sep_embd_start * n_embd];
-
-	float* prom_embds[] = {
-		&embds[prom_embd_start[0] * n_embd],
-		&embds[prom_embd_start[1] * n_embd],
-		&embds[prom_embd_start[2] * n_embd],
-		&embds[prom_embd_start[3] * n_embd],
-		&embds[prom_embd_start[4] * n_embd],
-		&embds[prom_embd_start[5] * n_embd],
-		&embds[prom_embd_start[6] * n_embd],
-		&embds[prom_embd_start[7] * n_embd],
-	};
-	float* resps_embds[] = {
-		&embds[resp_embd_start[0] * n_embd],
-		&embds[resp_embd_start[1] * n_embd],
-		&embds[resp_embd_start[2] * n_embd],
-		&embds[resp_embd_start[3] * n_embd],
-		&embds[resp_embd_start[4] * n_embd],
-		&embds[resp_embd_start[5] * n_embd],
-		&embds[resp_embd_start[6] * n_embd],
-		&embds[resp_embd_start[7] * n_embd],
-		&embds[resp_embd_start[8] * n_embd],
-	};
-
-	// insert into batch
-	{
-		// keeps track of the position for each sequence
-		size_t pos = 0;
-
-		// insert text tokens
-		for ( auto& id : phoneme_tokens ) batch_add( batch, id, n_embd, text_embds, pos++, false );
-		batch_add( batch, 0, n_embd, sep_embd, pos++, false );
-		pos = 0;
-		// insert lang token
-		batch_add( batch, lang_token, n_embd, lang_embd, pos++, false );
-		batch_add( batch, 0, n_embd, sep_embd, pos++, false );
-		pos = 0;
-		// insert rvq level token
-		batch_add( batch, rvq_level_token, n_embd, rvq_level_embd, pos++, false );
-		batch_add( batch, 0, n_embd, sep_embd, pos++, false );
-		pos = 0;
-		// insert prom tokens
-		// to-do: handle summing
-		for ( auto l = 0; l < prompt_tokens.size(); ++l ) {
-			for ( auto& id : prompt_tokens[l] ) batch_add( batch, id, n_embd, prom_embds[l], pos++, false );
-		}
-		batch_add( batch, 0, n_embd, sep_embd, pos++, is_ar );
-		pos = 0;
-
-		// fill in masked tokens
-		if ( !is_ar ) {
-			for ( auto i = 0; i < response_tokens[0].size(); ++i ) batch_add( batch, response_tokens[0][i], n_embd, resps_embds[1], pos++, true );
-		}
-		pos = 0;
-	}
-
-	// Decoding loop
-	const auto t_main_start = ggml_time_us();
-	int n_decode = 0;
-
-	// to-do: handle other levels
-	std::vector<llama_token> resps_tokens;
-	while ( resps_tokens.size() < 32 ) {
-		if (llama_decode(ctx, batch)) {
-			fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
-			return 1;
-		}
-		n_decode += 1;
-
-		// align to AR's classifier
-		// to-do: derive from tokenizer
-		int range[] = { resp_embd_start[0], resp_embd_start[1] };
-		auto* logits = llama_get_logits_ith( ctx, -1 );
-		for ( auto i = 0; i < n_vocab; ++i ) {
-			if ( i < range[0] || i >= range[1] ) {
-				logits[i] = -INFINITY;
+	// inference
+	std::vector<llama_token> output_tokens;
+	// NAR-len demasking modality
+	if ( modality == MODALITY_NAR_LEN ) {
+		// inference len
+		input.task = "len";
+		output_tokens = generate( ctx, model, smpl_nar, input, embeddings_map, 5, INFERENCE_MODE_LEN );
+		// fold the emitted decimal digits into a frame count
+		int len = 0; {
+			int digit = 1;
+			for (int i = output_tokens.size() - 1; i >= 0; i--) {
+				len += output_tokens[i] * digit;
+				digit *= 10;
 			}
 		}
+		// cap for now
+		if ( len > MAX_DURATION ) len = MAX_DURATION;
 
-		// sample the next token
-		auto t = llama_sampler_sample(smpl, ctx, -1);
-
-		// is stop token
-		if ( t == resp_embd_start[1] - 1 ) { // <|AR|0:STOP|>
-			break;
+		// fill the first level with mask tokens
+		input.resp.resize(1);
+		for ( auto i = 0; i < len; ++i ) {
+			input.resp[0].emplace_back( embeddings_map.resp_embd_start[3] - 1 ); // fill with masked tokens
 		}
 
-		char buf[256];
-		llama_token_to_piece( model, t, buf, sizeof(buf), 0, true );
-		printf("%s\n", buf );
-
-		batch_add( batch, 0, n_embd, resps_embds[0], resps_tokens.size(), true );
-		resps_tokens.emplace_back(t);
+		// inference NAR-len level 0, then the remaining NAR levels
+		input.task = "tts";
+		for ( auto l = 0; l < 8; ++l ) {
+			input.rvq_l = l;
+			output_tokens = generate( ctx, model, smpl_nar, input, embeddings_map, 5, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR );
+			// level 0's output replaces the mask-filled placeholder; later levels append
+			if ( l == 0 ) input.resp[0] = output_tokens;
+			else input.resp.emplace_back( output_tokens );
+		}
+	// AR+NAR modality
+	} else if ( modality == MODALITY_AR_NAR ){
+		input.task = "tts";
+		for ( auto l = 0; l < 8; ++l ) {
+			input.rvq_l = l;
+			output_tokens = generate( ctx, model, l == 0 ? smpl_ar : smpl_nar, input, embeddings_map, l == 0 ? MAX_DURATION : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR );
+			input.resp.emplace_back( output_tokens );
+		}
 	}
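// [editor's note] a worked example, not part of the diff, of the digit folding above:
// the len head emits decimal digits most-significant first, so output_tokens = {1, 7, 5}
// accumulates from the right as 5*1 + 7*10 + 1*100 = 175 frames, i.e. ~2.3 s of audio
// at EnCodec's 75 frames per second (MAX_DURATION = 75 * 12 caps this at 12 s).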
 
-	printf("\n");
-	const auto t_main_end = ggml_time_us();
+	// write the synthesized audio to disk
+	auto waveform = decode_audio( ectx, input.resp );
+	write_wav_on_disk( waveform, output_response_path );
 
-	fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
-			__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+	// cleanup
+	encodec_free(ectx);
 
-	fprintf(stderr, "\n");
-	llama_perf_sampler_print(smpl);
-	llama_perf_context_print(ctx);
-	fprintf(stderr, "\n");
-
-	// encodec_free(ectx);
-	llama_sampler_free(smpl);
+	llama_sampler_free(smpl_nar);
+	llama_sampler_free(smpl_ar);
+
+	llama_free(ctx);
+	llama_free_model(model);
 
 	return 0;
 }