more work on vall_e.cpp (some more cleanup, NAR-len demasking, but still need to iron out some kinks)
This commit is contained in:
parent
a6945f981d
commit
6ecdb715b6
|
@ -15,17 +15,8 @@ Run `make`.
|
|||
### Required Modifications
|
||||
|
||||
[`encodec.cpp`](https://github.com/e-c-k-e-r/encodec.cpp) requires updating its GGML copy to the latest version, which requires a few lines to get the CPU backend working.
|
||||
[`llama.cpp`](https://github.com/e-c-k-e-r/llama.cpp) *might* not require any modifications, but:
|
||||
* `llm.build_vall_e` can mostly copy `llm.build_llama`, but with:
|
||||
* `KQ_mask = build_inp_KQ_mask( lctx.cparams.causal_attn )`
|
||||
* a unified output head (pain)
|
||||
* OR adjusting the `model.output` to the correct classifier head (better option)
|
||||
* OR slicing that tensor with the right range (`ggml_view_2d` confuses me)
|
||||
* both require also require `*const_cast<uint32_t*>(&ctx->model.hparams.n_vocab) = output->ne[1];` because the logits are tied to `n_vocab`
|
||||
* commenting out `GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());` because grabbing embeddings/classifiers require using `bid` to trick it thinking it's part of a layer
|
||||
* some helper functions to retrieve the embeddings tensor from the model
|
||||
* some helper functions to set the target classifier head
|
||||
* some fix for `GGML_ASSERT(mask->ne[0] == a->ne[0])` when using a non-causal attention mask (or I can test on the model that had a causal NAR......)
|
||||
|
||||
[`llama.cpp`](https://github.com/e-c-k-e-r/llama.cpp) *might* not require any modifications, but implementing `LLM_ARCH_VALL_E` requires some surgery.
|
||||
|
||||
## To-Do
|
||||
|
||||
|
@ -46,11 +37,11 @@ Run `make`.
|
|||
* [x] `AR` sampling
|
||||
* currently need a model that didn't regress with the `AR:0:0` output
|
||||
* [ ] working `NAR-len` output
|
||||
* [ ] `NAR-len` sampling
|
||||
* currently cannot inference with non-causal_attn
|
||||
* [x] `NAR-len` sampling
|
||||
* need to assert that a non-causal mask is used
|
||||
* [ ] working `NAR` output
|
||||
* [x] `NAR` sampling
|
||||
* currently cannot inference with non-causal_attn
|
||||
* need to assert that a non-causal mask is used
|
||||
* [x] decode audio to disk
|
||||
* [ ] a functional CLI
|
||||
* [ ] actually make it work
|
||||
|
|
|
@ -1,14 +1,11 @@
|
|||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "vall_e.h"
|
||||
|
||||
|
||||
|
||||
#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions
|
||||
#define LLAMA_CPP_USE_VALL_E_ARCH 1 // whether the underlying llama.cpp is to use the VALL_E arch (or using LLAMA arch)
|
||||
|
||||
#if !LLAMA_CPP_EXTENDED
|
||||
#include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
|
||||
#endif
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
|
||||
ranges_t io_ranges[] = {
|
||||
{ "text", 0, 256, 9, },
|
||||
|
@ -39,7 +36,7 @@ ranges_t io_ranges[] = {
|
|||
{ "resps|NAR:0: 16677, 17702, 8,0", },
|
||||
};
|
||||
|
||||
std::vector<float> read_2d_tensor( struct ggml_tensor* tensor ) {
|
||||
std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor ) {
|
||||
size_t size = tensor->ne[0] * tensor->ne[1];
|
||||
std::vector<float> res( size );
|
||||
|
||||
|
@ -55,29 +52,29 @@ std::vector<float> read_2d_tensor( struct ggml_tensor* tensor ) {
|
|||
}
|
||||
|
||||
|
||||
struct ggml_tensor * vall_e_get_prom_embds( llama_vall_e_userdata& userdata, int32_t idx ) {
|
||||
struct ggml_tensor * VALL_E_API vall_e_get_prom_embds( llama_vall_e_userdata& userdata, int32_t idx ) {
|
||||
return userdata.prom_embds[idx];
|
||||
}
|
||||
struct ggml_tensor * vall_e_get_resp_embds( llama_vall_e_userdata& userdata, int32_t idx ) {
|
||||
struct ggml_tensor * VALL_E_API vall_e_get_resp_embds( llama_vall_e_userdata& userdata, int32_t idx ) {
|
||||
return userdata.resp_embds[idx];
|
||||
}
|
||||
struct ggml_tensor * vall_e_get_aux_embds( llama_vall_e_userdata& userdata, int32_t idx ) {
|
||||
struct ggml_tensor * VALL_E_API vall_e_get_aux_embds( llama_vall_e_userdata& userdata, int32_t idx ) {
|
||||
return userdata.aux_embds[idx];
|
||||
}
|
||||
|
||||
|
||||
const embeddings_t& vall_e_inputs_map_get_embeddings( inputs_map_t& inputs_map, const std::string& name ) {
|
||||
const embeddings_t& VALL_E_API vall_e_inputs_map_get_embeddings( inputs_map_t& inputs_map, const std::string& name ) {
|
||||
return inputs_map.embds[name];
|
||||
}
|
||||
const float* vall_e_inputs_map_get_embeddings_p( inputs_map_t& inputs_map, const std::string& name ) {
|
||||
const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( inputs_map_t& inputs_map, const std::string& name ) {
|
||||
return inputs_map.embds[name].embds.data();
|
||||
}
|
||||
|
||||
int32_t vall_e_inputs_map_get_classifier_idx( inputs_map_t& inputs_map, const std::string& name ) {
|
||||
int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( inputs_map_t& inputs_map, const std::string& name ) {
|
||||
return inputs_map.embds[name].range.classifier_idx;
|
||||
}
|
||||
|
||||
void vall_e_inputs_map_init( inputs_map_t& inputs_map, llama_model* model ) {
|
||||
void VALL_E_API vall_e_inputs_map_init( inputs_map_t& inputs_map, llama_model* model ) {
|
||||
auto n_embd = llama_n_embd( model );
|
||||
auto n_vocab = llama_n_vocab( model );
|
||||
|
||||
|
@ -146,7 +143,7 @@ void vall_e_inputs_map_init( inputs_map_t& inputs_map, llama_model* model ) {
|
|||
}
|
||||
|
||||
// maps embeddings easily
|
||||
std::vector<std::vector<float>> map_embeddings( const std::vector<llama_token>& tokens, int n_embd, const float* embds ) {
|
||||
std::vector<std::vector<float>> VALL_E_API map_embeddings( const std::vector<llama_token>& tokens, int n_embd, const float* embds ) {
|
||||
std::vector<std::vector<float>> embedded( tokens.size() );
|
||||
for ( auto i = 0; i < tokens.size(); ++i ) {
|
||||
embedded[i].insert( embedded[i].end(), embds + (tokens[i] * n_embd), embds + ((tokens[i]+1) * n_embd) );
|
||||
|
@ -156,7 +153,7 @@ std::vector<std::vector<float>> map_embeddings( const std::vector<llama_token>&
|
|||
|
||||
// handles adding either a token OR the embedding of that token into the batch
|
||||
// this really, really helps avoid needing to abuse the tokenizer
|
||||
void batch_add( llama_batch& batch, llama_token id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids ) {
|
||||
void VALL_E_API batch_add( llama_batch& batch, llama_token id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids ) {
|
||||
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
|
||||
|
||||
// insert raw embedding instead
|
||||
|
@ -181,7 +178,7 @@ void batch_add( llama_batch& batch, llama_token id, int n_embd, const float* emb
|
|||
batch.n_tokens++;
|
||||
}
|
||||
// reads a waveform from disk
|
||||
bool read_wav_from_disk(std::string in_path, std::vector<float> & audio_arr) {
|
||||
bool VALL_E_API read_wav_from_disk(std::string in_path, std::vector<float> & audio_arr) {
|
||||
uint32_t channels;
|
||||
uint32_t sample_rate;
|
||||
drwav_uint64 total_frame_count;
|
||||
|
@ -209,7 +206,7 @@ bool read_wav_from_disk(std::string in_path, std::vector<float> & audio_arr) {
|
|||
return true;
|
||||
}
|
||||
// writes a waveform to disk
|
||||
void write_wav_on_disk(std::vector<float> & audio_arr, std::string dest_path) {
|
||||
void VALL_E_API write_wav_on_disk(std::vector<float> & audio_arr, std::string dest_path) {
|
||||
drwav_data_format format;
|
||||
format.bitsPerSample = 32;
|
||||
format.sampleRate = 24000;
|
||||
|
@ -225,7 +222,7 @@ void write_wav_on_disk(std::vector<float> & audio_arr, std::string dest_path) {
|
|||
fprintf(stderr, "%s: Number of frames written = %lld.\n", __func__, frames);
|
||||
}
|
||||
// reads a waveform from disk then encodes it
|
||||
std::vector<std::vector<int32_t>> encode_audio_from_disk( struct encodec_context* ectx, const std::string& path ) {
|
||||
std::vector<std::vector<int32_t>> VALL_E_API encode_audio_from_disk( struct encodec_context* ectx, const std::string& path ) {
|
||||
// read audio from disk
|
||||
std::vector<float> wavform;
|
||||
|
||||
|
@ -258,7 +255,7 @@ std::vector<std::vector<int32_t>> encode_audio_from_disk( struct encodec_context
|
|||
return codes_2ds;
|
||||
}
|
||||
// decodes a 2D codebook into a waveform
|
||||
std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector<std::vector<int32_t>>& codes_2d ) {
|
||||
std::vector<float> VALL_E_API decode_audio( struct encodec_context* ectx, const std::vector<std::vector<int32_t>>& codes_2d ) {
|
||||
int n_codebooks = codes_2d.size();
|
||||
int n_frames = codes_2d[0].size();
|
||||
|
||||
|
@ -283,7 +280,7 @@ std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector
|
|||
}
|
||||
|
||||
// sums embeddings over a 2D "tensor"
|
||||
std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, const float** embds, int mode ) {
|
||||
std::vector<std::vector<float>> VALL_E_API sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, const float** embds, int mode ) {
|
||||
std::vector<std::vector<float>> res( input.size() );
|
||||
res.resize( input[0].size() );
|
||||
for ( auto& e : res ) e.resize( n_embd );
|
||||
|
@ -311,7 +308,22 @@ std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<ll
|
|||
return res;
|
||||
}
|
||||
|
||||
void fill_batch( llama_batch& batch, input_t& input, inputs_map_t& inputs_map, int mode ) {
|
||||
std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits ) {
|
||||
std::vector<float> res( n_logits, 0.0f );
|
||||
float denom = 0.0f;
|
||||
|
||||
for ( auto i = 0; i < n_logits; ++i ) {
|
||||
denom += exp( logits[i] );
|
||||
}
|
||||
// to-do: assert denom != 0.0f
|
||||
for ( auto i = 0; i < n_logits; ++i ) {
|
||||
res[i] = logits[i] / denom;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void VALL_E_API fill_batch( llama_batch& batch, input_t& input, inputs_map_t& inputs_map, int mode ) {
|
||||
// keeps track of the position for each sequence
|
||||
size_t pos = 0;
|
||||
auto n_embd = inputs_map.n_embd;
|
||||
|
@ -382,48 +394,42 @@ void fill_batch( llama_batch& batch, input_t& input, inputs_map_t& inputs_map, i
|
|||
}
|
||||
|
||||
// generation code, should handle all modalities easily
|
||||
std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, inputs_map_t& inputs_map, int max_tokens, int mode, bool verbose ) {
|
||||
llama_batch batch = llama_batch_init( 22500, inputs_map.n_embd, 22500 );
|
||||
|
||||
// Decoding loop
|
||||
const auto t_main_start = ggml_time_us();
|
||||
int n_decode = 0;
|
||||
std::vector<llama_token> VALL_E_API generate( llama_context* ctx, llama_model* model, llama_sampler* smpl, input_t& input, inputs_map_t& inputs_map, int max_tokens, int mode, bool verbose ) {
|
||||
int rvq_l = input.rvq_l;
|
||||
llama_token stop_token = -1;
|
||||
int n_decode = 0; // number of tokens decoded
|
||||
int n_outputs = 0; // number of output tokens to expect
|
||||
int n_vocab = 0;
|
||||
int n_embd = 0;
|
||||
bool causal = true; // sample autoregressively or not
|
||||
const float* embds = NULL; // embeddings to map output tokens through
|
||||
ranges_t range; // I/O range
|
||||
|
||||
// create batch (targetting embeddings instead of tokens)
|
||||
llama_batch batch = llama_batch_init( CTX_SIZE, inputs_map.n_embd, CTX_SIZE );
|
||||
fill_batch( batch, input, inputs_map, mode );
|
||||
|
||||
// determine how many logits we need
|
||||
int n_logits = 0;
|
||||
// determine how many outputs we need
|
||||
for ( auto i = 0; i < batch.n_tokens; ++i ) {
|
||||
if ( batch.logits[i] ) ++n_logits;
|
||||
if ( batch.logits[i] ) ++n_outputs;
|
||||
}
|
||||
if ( verbose ) printf("Prompt size: %i | Outputs: %i\n", batch.n_tokens, n_outputs);
|
||||
|
||||
if ( verbose ) printf("Prompt size: %i | Outputs: %i\n", batch.n_tokens, n_logits);
|
||||
|
||||
// NAR mode, cap at one step
|
||||
if ( n_logits > 1 ) {
|
||||
max_tokens = n_logits;
|
||||
}
|
||||
|
||||
if ( n_logits == 0 ) {
|
||||
// bail out
|
||||
if ( n_outputs == 0 ) {
|
||||
fprintf(stderr, "%s : no tokens to decode\n", __func__);
|
||||
return {};
|
||||
}
|
||||
causal = n_outputs == 1;
|
||||
|
||||
const float* embds = NULL;
|
||||
ranges_t range;
|
||||
|
||||
// AR mode
|
||||
std::string embd_name = "";
|
||||
if ( mode == INFERENCE_MODE_AR ) {
|
||||
auto& embeddings = vall_e_inputs_map_get_embeddings(inputs_map, "resps|AR:0:0");
|
||||
range = embeddings.range;
|
||||
embds = embeddings.embds.data();
|
||||
stop_token = range.end - range.start - 1;
|
||||
|
||||
printf("Generating in %s (%i) mode (%i:%i) (%i)\n", "AR", range.classifier_idx, range.start, range.end, stop_token);
|
||||
embd_name = "resps|AR:0:0";
|
||||
// NAR mode
|
||||
} else if ( mode == INFERENCE_MODE_NAR ) {
|
||||
std::string k_embds[] = {
|
||||
"resps|NAR:0:0", // invalid
|
||||
"resps|NAR:0:0", // invalid, should never be picked
|
||||
"resps|NAR:0:1",
|
||||
"resps|NAR:1:2",
|
||||
"resps|NAR:2:3",
|
||||
|
@ -432,88 +438,237 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
|
|||
"resps|NAR:5:6",
|
||||
"resps|NAR:6:7",
|
||||
};
|
||||
auto& embeddings = vall_e_inputs_map_get_embeddings(inputs_map, k_embds[rvq_l]);
|
||||
range = embeddings.range;
|
||||
embds = embeddings.embds.data();
|
||||
|
||||
printf("Generating in %s (%i) mode (%i:%i)\n", "NAR", range.classifier_idx, range.start, range.end);
|
||||
embd_name = k_embds[rvq_l];
|
||||
// duration inferencing mode
|
||||
} else if ( mode == INFERENCE_MODE_LEN ) {
|
||||
auto& embeddings = vall_e_inputs_map_get_embeddings(inputs_map, "len");
|
||||
range = embeddings.range;
|
||||
embds = embeddings.embds.data();
|
||||
stop_token = range.end - range.start - 1;
|
||||
|
||||
printf("Generating in %s (%i) mode (%i:%i) (%i)\n", "len", range.classifier_idx, range.start, range.end, stop_token);
|
||||
embd_name = "len";
|
||||
// NAR-len (demasking) inferencing mode
|
||||
} else if ( mode == INFERENCE_MODE_NAR_DEMASK ) {
|
||||
auto& embeddings = vall_e_inputs_map_get_embeddings(inputs_map, "resps|NAR:0:0");
|
||||
range = embeddings.range;
|
||||
embds = embeddings.embds.data();
|
||||
|
||||
printf("Generating in %s (%i) mode (%i:%i)\n", "NAR-len", range.classifier_idx, range.start, range.end);
|
||||
embd_name = "resps|NAR:0:0";
|
||||
}
|
||||
|
||||
auto& embeddings = vall_e_inputs_map_get_embeddings(inputs_map, embd_name);
|
||||
range = embeddings.range;
|
||||
embds = embeddings.embds.data();
|
||||
n_embd = embeddings.n_embd;
|
||||
n_vocab = embeddings.n_vocab;
|
||||
stop_token = range.end - range.start - 1;
|
||||
|
||||
printf("Generating in %s (%i) mode (%i:%i) (%i)\n", embd_name.c_str(), range.classifier_idx, range.start, range.end, stop_token);
|
||||
|
||||
// update model's output heads / causal mode
|
||||
#if LLAMA_CPP_USE_VALL_E_ARCH
|
||||
auto& userdata = *llama_get_vall_e_userdata( model );
|
||||
llama_set_output_head( model, userdata.heads[range.classifier_idx] );
|
||||
#endif
|
||||
llama_set_causal_attn( ctx, n_logits == 1 );
|
||||
llama_set_causal_attn( ctx, causal );
|
||||
// to-do: fix GGML_ASSERT(mask->ne[0] == a->ne[0])
|
||||
|
||||
std::vector<llama_token> output_tokens;
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
// if INFERENCE_MODE_AR || INFERENCE_MODE_LEN
|
||||
if ( causal ) {
|
||||
output_tokens.reserve(max_tokens);
|
||||
if ( verbose ) {
|
||||
printf("[");
|
||||
fflush(stdout);
|
||||
}
|
||||
while ( output_tokens.size() < max_tokens ) {
|
||||
if ( llama_decode(ctx, batch) ) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return output_tokens;
|
||||
}
|
||||
std::vector<llama_token> current_tokens;
|
||||
// backwards iterate to start from beginning of sequence
|
||||
for ( auto i = n_logits; i > 0; --i ) {
|
||||
// filter logits
|
||||
auto* logits = llama_get_logits_ith( ctx, -i );
|
||||
|
||||
// ensures only tokens within our designated range are used
|
||||
#if !LLAMA_CPP_USE_VALL_E_ARCH
|
||||
auto* logits = llama_get_logits_ith( ctx, -1 );
|
||||
for ( auto i = 0; i < inputs_map.n_vocab; ++i ) {
|
||||
if ( i < range.start || i >= range.end ) logits[i] = -INFINITY;
|
||||
}
|
||||
#endif
|
||||
// sample token
|
||||
auto t = llama_sampler_sample(smpl, ctx, -1);
|
||||
|
||||
// is stop token
|
||||
if ( t == stop_token ) {
|
||||
break;
|
||||
}
|
||||
|
||||
// store token
|
||||
output_tokens.emplace_back(t);
|
||||
// update batch with token
|
||||
batch_add( batch, t, inputs_map.n_embd, embds, output_tokens.size(), true );
|
||||
if ( verbose ) {
|
||||
printf("%i, ", t);
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
if ( verbose ) {
|
||||
printf("]\n");
|
||||
fflush(stdout);
|
||||
}
|
||||
} else if ( mode == INFERENCE_MODE_NAR_DEMASK ) {
|
||||
// to-do: assert n_outputs == input.resp[rvq_l-1].size()
|
||||
const llama_token MASK_TOKEN = 1024; // token value for masking
|
||||
const float PI = 3.141592653589793f;
|
||||
// to-do: derive from sampling arguments
|
||||
int32_t steps = 30; // number of demasking steps
|
||||
int32_t seq_len = n_outputs;
|
||||
float temperature = 1.5f;
|
||||
float cfg_strength = 2.5f;
|
||||
|
||||
// fill with masked tokens
|
||||
output_tokens.clear();
|
||||
output_tokens.resize(n_outputs, MASK_TOKEN);
|
||||
|
||||
// for CFG
|
||||
input_t null_input{};
|
||||
null_input.phn = {1, 2}; // <bos></eos>
|
||||
null_input.resp.resize(1);
|
||||
|
||||
llama_batch null_batch = llama_batch_init( CTX_SIZE, inputs_map.n_embd, CTX_SIZE );
|
||||
|
||||
// token scores to reference for masking
|
||||
std::vector<float> scores(n_outputs, 1.0);
|
||||
|
||||
// do one step on many tokens
|
||||
for ( auto step = 0; step < steps; ++step ) {
|
||||
if ( verbose ) {
|
||||
printf("[%i/%i] [", step, steps);
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
float timestep = (step+1) / steps; // to-do: align with torch.linspace
|
||||
float annealing = 1.0f - timestep;
|
||||
float noise_p = cos( timestep * PI * 0.5f );
|
||||
float remask_p = 0.5f / steps;
|
||||
int32_t n_masked_tokens = std::min(int(noise_p * seq_len), 1);
|
||||
float sampling_temperature = temperature * annealing;
|
||||
float sampling_cfg_strength = timestep * cfg_strength;
|
||||
|
||||
std::vector<bool> is_masked(n_outputs, false);
|
||||
std::vector<int32_t> masked_indices;
|
||||
masked_indices.reserve(n_masked_tokens);
|
||||
std::vector<float> sorted = scores;
|
||||
std::sort(sorted.begin(), sorted.end());
|
||||
masked_indices.insert( masked_indices.end(), sorted.begin(), sorted.begin() + n_masked_tokens );
|
||||
|
||||
// mask off tokens
|
||||
for ( auto& idx : masked_indices ) {
|
||||
output_tokens[idx] = MASK_TOKEN;
|
||||
}
|
||||
// update token mask
|
||||
for ( auto i = 0; i < n_outputs; ++i ) {
|
||||
is_masked[i] = output_tokens[i] == MASK_TOKEN;
|
||||
}
|
||||
|
||||
// update batch
|
||||
// to-do: only update the embeddings instead
|
||||
batch.n_tokens = 0;
|
||||
input.resp[0] = output_tokens;
|
||||
fill_batch( batch, input, inputs_map, mode );
|
||||
// update null batch
|
||||
null_input.resp[0] = output_tokens;
|
||||
null_batch.n_tokens = 0;
|
||||
fill_batch( null_batch, input, inputs_map, mode );
|
||||
|
||||
// to-do: update sampling temperature
|
||||
|
||||
// cfg decode
|
||||
if ( llama_decode(ctx, null_batch) ) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return output_tokens;
|
||||
}
|
||||
// copy null probabilities
|
||||
std::vector<float> null_logits(n_outputs * n_vocab, -INFINITY);
|
||||
// to-do: copy once
|
||||
for ( auto idx = 0; idx < n_outputs; ++idx ) {
|
||||
memcpy( &null_logits[idx * n_vocab], llama_get_logits_ith( ctx, null_batch.n_tokens - n_outputs + idx ), sizeof(float) * n_vocab );
|
||||
}
|
||||
|
||||
// decode
|
||||
if ( llama_decode(ctx, batch) ) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return output_tokens;
|
||||
}
|
||||
// to-do: figure out why all logits are the same for each token......
|
||||
// "reverse" iterate from backwards indexing
|
||||
for ( auto idx = 0; idx < n_outputs; ++idx ) {
|
||||
// skip if not masked
|
||||
if ( !is_masked[idx] )
|
||||
continue;
|
||||
// ensures only tokens within our designated range are used
|
||||
auto* logits = llama_get_logits_ith( ctx, batch.n_tokens - n_outputs + idx );
|
||||
auto* null_logit = &null_logits[idx];
|
||||
|
||||
#if !LLAMA_CPP_USE_VALL_E_ARCH
|
||||
for ( auto i = 0; i < inputs_map.n_vocab; ++i ) {
|
||||
if ( i < range.start || i >= range.end ) logits[i] = -INFINITY;
|
||||
}
|
||||
#endif
|
||||
// perform softmax before modifying logits
|
||||
std::vector<float> softmaxed = soft_max( n_vocab, logits );
|
||||
|
||||
// sample the next token
|
||||
printf("%i: %p\n [", -i, logits );
|
||||
for ( auto i = 0; i < 1025; ++i ) {
|
||||
printf("%f, ", logits[i]);
|
||||
// perform CFG sampling
|
||||
for ( auto i = 0; i < n_vocab; ++i ) {
|
||||
logits[i] = null_logit[i] + (logits[i] - null_logit[i]) * cfg_strength;
|
||||
}
|
||||
printf("]\n");
|
||||
auto t = llama_sampler_sample(smpl, ctx, -i);
|
||||
//printf("%i: [%i]: %f | %p\n", -i, t, logits[t], logits );
|
||||
|
||||
// offset back into range
|
||||
// sample ith token
|
||||
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx );
|
||||
// store token if it was masked
|
||||
output_tokens[idx] = t;
|
||||
// update score if it was masked
|
||||
scores[idx] = 1.0f - softmaxed[t]; // invert so we pick the worst tokens later
|
||||
if ( verbose ) {
|
||||
printf("%i, ", t);
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
if ( verbose ) {
|
||||
printf("\n");
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
} else if ( mode == INFERENCE_MODE_NAR ) {
|
||||
// to-do: assert n_outputs == input.resp[rvq_l-1].size()
|
||||
output_tokens.reserve(n_outputs);
|
||||
// do one step on many tokens
|
||||
if ( llama_decode(ctx, batch) ) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return output_tokens;
|
||||
}
|
||||
// to-do: figure out why all logits are the same for each token......
|
||||
// "reverse" iterate from backwards indexing
|
||||
if ( verbose ) {
|
||||
printf("[");
|
||||
fflush(stdout);
|
||||
}
|
||||
for ( auto idx = 0; idx < n_outputs; ++idx ) {
|
||||
// ensures only tokens within our designated range are used
|
||||
#if !LLAMA_CPP_USE_VALL_E_ARCH
|
||||
t -= range.start;
|
||||
auto* logits = llama_get_logits_ith( ctx, batch.n_tokens - n_outputs + idx );
|
||||
for ( auto i = 0; i < inputs_map.n_vocab; ++i ) {
|
||||
if ( i < range.start || i >= range.end ) logits[i] = -INFINITY;
|
||||
}
|
||||
#endif
|
||||
|
||||
n_decode += 1;
|
||||
|
||||
// is stop token
|
||||
if ( t == stop_token ) {
|
||||
printf("STOPPED\n");
|
||||
max_tokens = 0;
|
||||
break;
|
||||
}
|
||||
// sample ith token
|
||||
auto t = llama_sampler_sample(smpl, ctx, batch.n_tokens - n_outputs + idx);
|
||||
|
||||
// store token
|
||||
current_tokens.emplace_back(t);
|
||||
// update batch with token
|
||||
batch_add( batch, t, inputs_map.n_embd, embds, output_tokens.size(), true );
|
||||
output_tokens.emplace_back(t);
|
||||
if ( verbose ) {
|
||||
printf("%i, ", t);
|
||||
fflush(stdout);
|
||||
}
|
||||
printf("%s: Tokens: [", __func__);
|
||||
for ( auto& token : current_tokens ) {
|
||||
printf("%i, ", token);
|
||||
}
|
||||
if ( verbose ) {
|
||||
printf("]\n");
|
||||
|
||||
output_tokens.insert(output_tokens.end(), current_tokens.begin(), current_tokens.end());
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
if ( verbose ) {
|
||||
|
@ -535,7 +690,7 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
|
|||
int main( int argc, char** argv ) {
|
||||
// to-do: replace all of this with proper loading code
|
||||
int32_t ngl = 0;
|
||||
int modality = MODALITY_AR_NAR;
|
||||
int modality = MODALITY_NAR_LEN;
|
||||
input_t input{};
|
||||
inputs_map_t inputs_map{};
|
||||
|
||||
|
@ -632,7 +787,7 @@ int main( int argc, char** argv ) {
|
|||
// NAR-len demasking
|
||||
if ( modality == MODALITY_NAR_LEN ) {
|
||||
// inference len
|
||||
int len = 0;
|
||||
int len = 290;
|
||||
if ( !len ) {
|
||||
input.task = "len";
|
||||
output_tokens = generate( ctx, model, smpl_nar, input, inputs_map, 5, INFERENCE_MODE_LEN );
|
||||
|
|
|
@ -1,23 +1,24 @@
|
|||
#pragma once
|
||||
|
||||
#include "llama-vocab.h"
|
||||
#include "llama.h"
|
||||
#include "encodec.h"
|
||||
|
||||
#include "dr_wav.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include <unordered_map>
|
||||
#include <iostream>
|
||||
|
||||
// to-do: copy over import/export stuff from engine project (because I don't remember how I set it up in <uf/config.h>)
|
||||
#define VALL_E_API
|
||||
|
||||
#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions
|
||||
#define LLAMA_CPP_USE_VALL_E_ARCH 1 // whether the underlying llama.cpp is to use the VALL_E arch (or using LLAMA arch)
|
||||
|
||||
#if !LLAMA_CPP_EXTENDED
|
||||
#include "_llama.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
|
||||
#endif
|
||||
|
||||
// to-do: clean up spaghetti enums
|
||||
const int EMBEDDING_MODE_PROM = 0;
|
||||
const int EMBEDDING_MODE_RESP_AR_NAR = 1;
|
||||
|
@ -106,6 +107,7 @@ struct inputs_map_t {
|
|||
std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor );
|
||||
std::vector<std::vector<float>> VALL_E_API map_embeddings( const std::vector<llama_token>& tokens, int n_embd, const float* embds );
|
||||
std::vector<std::vector<float>> VALL_E_API sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM );
|
||||
std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits );
|
||||
|
||||
// batch and inferencing
|
||||
void VALL_E_API batch_add( llama_batch& batch, llama_token id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} );
|
||||
|
|
Loading…
Reference in New Issue
Block a user