This commit is contained in:
parent 25a02f2c3f
commit 9b0d2ccbe1
.gitignore (vendored): 4 changes
@ -10,6 +10,6 @@ __pycache__
/.nltk
/vall_e.cpp/data
/vall_e.cpp/include
/vall_e.cpp/libs
/vall_e.cpp/lib
/vall_e.cpp/*.o
/vall_e.cpp/vall_e

@ -121,6 +121,9 @@ With attention-based transformers, most embeddings can serve as a token itself a

Other solutions, such as TorToiSe, make use of additional embeddings/classifiers for each portion of the sequence as well.

Other solutions will rely on conditioning latents or extracted features as the input. This *technically* isn't necessary, since portions of the model seem to be allocated as an encoder anyway (from the embeddings up to some arbitrary depth) and as a decoder (from some arbitrary depth to the output heads).
* This might also mean it makes more sense to increase the model's size in-post by injecting new layers in the middle, outside these pseudo-encoder/decoder layers, where doing so won't make any difference (a sketch follows below).
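
A minimal sketch of what such in-post growth could look like, assuming a HuggingFace-style decoder whose blocks live in an `nn.ModuleList` named `layers` (the attribute name, insertion depth, and initialization here are illustrative, not the actual implementation):

```python
import copy
import torch.nn as nn

def inject_layers(model: nn.Module, insert_at: int, count: int = 2) -> nn.Module:
    """Grow a transformer in-post by splicing extra blocks into the middle of the stack."""
    layers = model.layers  # nn.ModuleList of decoder blocks
    new_blocks = [copy.deepcopy(layers[insert_at]) for _ in range(count)]
    model.layers = nn.ModuleList(
        list(layers[:insert_at]) + new_blocks + list(layers[insert_at:])
    )
    return model
```

The spliced blocks would still need finetuning (or an identity-leaning initialization) before they contribute anything useful; this only shows where the insertion would happen.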

### Classifiers

Classifiers are the final output heads / projection layers that process the last hidden states of a model into a probability distribution for each token.
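
As a rough illustration (the class and attribute names are made up, not the actual implementation's), a set of split classifiers amounts to one projection per output space:

```python
import torch.nn as nn

class Classifiers(nn.Module):
    """One output head per token space (text, each RVQ level, len, etc.)."""
    def __init__(self, d_model: int, vocab_sizes: list[int]):
        super().__init__()
        self.proj = nn.ModuleList(nn.Linear(d_model, n) for n in vocab_sizes)

    def forward(self, last_hidden_states, head_idx: int):
        # last hidden states -> logits; a softmax over these gives the per-token distribution
        return self.proj[head_idx](last_hidden_states)
```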
@ -152,7 +155,7 @@ In reality, this seems to help govern the accent / general mannerisms associated

* Consequently, since this does tie to accents more, ***extreme*** attention is to be paid to the dialects being trained against, instead of naively grouping, say, all of Spanish under one language code.
* unfortunately, this does mean that audio annotated as English is dialect/accent-agnostic, per the dataset.

This embedding probably helps the model with being able to perform cross-lingual outputs, but I did not run any experiments on a model without it, as the reference `ar+nar-llama-8` was trained with this from the beginning with the small amount of Japanese in my dataset anyhow (and maybe the `ar+nar-retnet-8` experiment).
Some checkpoints of the model need this for cross-lingual output, but the current checkpoints don't seem to, as the attention heads derive the language/accent from the phoneme sequences themselves rather than from the language token (the result of a careless oversight).

#### Tone Embedding

@ -162,6 +165,8 @@ Should, since I do not actually make use of this anywhere, and the model is not

This should most definitely help the model identify tone strongly even without needing to annotate for it, but it already does an adequate job of maintaining tone from a given input prompt.

I imagine that, like language/accent, this gets derived from the phoneme sequence itself rather than from a guidance token.

### Audio Embeddings

However, due to the nature of the encoded audio, embedding the audio tokens requires the dark arts, as we use audio both as an input prompt (`prom`) for guidance and as an output response (`resp`).
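
To make the `prom`/`resp` asymmetry concrete, here is a simplified sketch (class and method names are made up for illustration): prompt codes have every RVQ level summed into one vector per frame, while the response only embeds the levels generated so far.

```python
import torch
import torch.nn as nn

class AudioEmbedding(nn.Module):
    def __init__(self, n_tokens: int, d_model: int, n_levels: int = 8):
        super().__init__()
        self.levels = nn.ModuleList(nn.Embedding(n_tokens, d_model) for _ in range(n_levels))

    def embed_prom(self, prom: torch.Tensor) -> torch.Tensor:
        # prom: [n_levels, seq_len]; all codebook levels collapse into one embedding per frame
        return sum(self.levels[l](prom[l]) for l in range(prom.shape[0]))

    def embed_resp(self, resp: torch.Tensor) -> torch.Tensor:
        # resp: [levels_so_far, seq_len]; only the levels generated so far are embedded
        return sum(self.levels[l](resp[l]) for l in range(resp.shape[0]))
```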
@ -230,12 +235,16 @@ In practice, this task is already implemented by providing the input audio to de

I imagine training for this task will better help the model understand what is and isn't noise, and more strongly map utterances from the input audio prompt for use in the output, delivering better prompt adherence.
* This might also help the model identify effects applied to an utterance, such as reverb or the audio quality itself (the "acoustic environment"), and maintain them in normal `tts` tasks.

This task can be briefly trained for decent results in-post.

##### Speech Removal

This task `sr` aims to remove speech from a given audio clip, effectively serving as the reverse of denoising.

As stated above, this should help the model better identify what is noise and what isn't.

This task can be briefly trained for decent results in-post.

##### Target Speech Extraction

This task `tse` aims to "extract" an utterance from audio containing other speakers, effectively diarizing an utterance.

@ -258,6 +267,8 @@ The length predictor `len` task is required for a pure NAR model.

This task will naively output a zero, then the length in base-10, followed by a stop token.

This works because the model can already derive the length of a sequence when decoding autoregressively, through the probability of emitting a `<stop>` token.
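
As a concrete illustration of that output format, turning the emitted base-10 digits back into a frame count is just positional accumulation (a sketch; the helper name is made up):

```python
def decode_len(digits: list[int]) -> int:
    """Turn the `len` task's digit tokens (e.g. [0, 2, 2, 5]) into an integer length (225)."""
    length = 0
    for d in digits:
        length = length * 10 + d
    return length

assert decode_len([0, 2, 2, 5]) == 225
```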
#### Speech-to-Text

The speech-to-text `stt` task transcribes a given piece of audio: it takes encoded audio as input and outputs the text transcription.

@ -274,11 +285,13 @@ This task will follow a reverse sequence of `<audio><language><RVQ level><output

The model can be prompted in creative ways to yield some interesting behaviors:
* prompting without an input audio prompt will have the model generate a random voice ~~at the "cost" of some unintelligible utterance at the beginning of the output response (despite doing no promptless training)~~.
  * classifier-free-guidance-aware training does fix this, but this property emerges without it.
  * the AR is much better with this property, as the `NAR-len` gets crusty sometimes.
  * the AR is much better with this property, as the `NAR-len` gets crusty sometimes, since it will keep demasking on crust.
* prompting with an input text prompt that is the transcription of the input audio prompt will have the response follow the input prompt very closely (despite not doing input=output training).
  * this should allow for easy transcription editing without much fuss.
  * the `NAR-len` greatly exhibits this property, although it sometimes keeps any noise in the background.
  * extra care is required when doing this, as some checkpoints of the model will degrade completely the moment the prompt can't be directly referenced.
* training without a language token will have the model derive the target language/accent from the phoneme sequence itself (it is a language model, after all).
* voice conversion is *possible* through demasking with the source prompt as the mask, but the current inferencing mechanism yields crust at the end of the output.

# `models/*`

@ -288,7 +301,7 @@ This folder contains scripts relating to models and code for VALL-E use, from th

This script implements Low-Rank Adapters (LoRA), to allow for cheaper and easier finetuning of existing modules.

At the moment, two approaches are offered, through replacing `nn.Linear` outright, or parameterizing a `nn.Liner`. The latter is used by default(?).
At the moment, two approaches are offered, through replacing `nn.Linear` outright, or parameterizing an `nn.Linear`. The latter is used by default(?).
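
For reference, the parametrization approach can be sketched with PyTorch's built-in parametrization hooks; this is a generic LoRA sketch, not the script's actual classes or hyperparameters:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class LoRAParametrization(nn.Module):
    """Adds a low-rank update W + (alpha / r) * B @ A on top of the frozen weight."""
    def __init__(self, out_features: int, in_features: int, rank: int = 8, alpha: int = 16):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))  # zero-init: no change at start
        self.scale = alpha / rank

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight + self.scale * (self.lora_B @ self.lora_A)

# usage: wrap an existing projection in-place instead of replacing the module
linear = nn.Linear(1024, 1024)
parametrize.register_parametrization(linear, "weight", LoRAParametrization(1024, 1024))
```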

## `models/base.py`

@ -303,6 +316,8 @@ This script implements the core underlying model for VALL-E. This handle:

This script aims to implement everything required for VALL-E agnostically, so that different implementations need to contain little extra code.

A very naive implementation of using the model can be found under the `__main__` invocation.

## `models/ar_nar.py`

This script implements VALL-E as a unified autoregressive and non-autoregressive model, where RVQ level 0 is inferenced autoregressively and the remaining levels are inferenced non-autoregressively, if requested.
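
At a high level, inference under this unified model amounts to the loop below (an illustrative sketch with made-up method names, not the script's real API):

```python
def infer(model, phonemes, prom, n_levels: int = 8, max_duration: int = 500):
    # RVQ level 0: sampled token by token, autoregressively
    resp = [model.generate_ar(phonemes, prom, level=0, max_tokens=max_duration)]
    # remaining levels: predicted in parallel, non-autoregressively,
    # each conditioned on the levels generated before it
    for level in range(1, n_levels):
        resp.append(model.generate_nar(phonemes, prom, resp, level=level))
    return resp  # n_levels token sequences to hand to the audio codec's decoder
```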
@ -312,14 +327,6 @@ For training, this model handles preparing the batch provided through the datalo

For inferencing, this will dynamically infer depending on the arguments provided.

## `models/experimental.py`

This script implements VALL-E as a mostly-HuggingFace-compatible model, where it handles processing tokens as a uniform sequence of IDs.

This mostly serves as an experiment to see what is required to do so, for possible future implementations requiring just `llama.cpp` and `encodec.cpp`, and to provide a pure HF-compatible implementation.

Use of this is governed through `cfg.model.experimental.hf = True`.
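
Flattening everything into one ID space boils down to offsetting each input type into its own slice of a shared vocabulary, roughly as sketched below (the offsets and names are illustrative, not the layout the code actually uses):

```python
# hypothetical offsets into a single flat vocabulary
TEXT_START = 0                   # phoneme tokens
PROM_START = 256                 # audio prompt tokens, one slice per RVQ level
RESP_START = 256 + 1024 * 8      # response tokens

def fold(phonemes: list[int], prom: list[int], level: int = 0) -> list[int]:
    """Map heterogeneous token types into one uniform sequence of IDs."""
    folded  = [TEXT_START + p for p in phonemes]
    folded += [PROM_START + level * 1024 + t for t in prom]
    return folded
```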

## `models/arch/*`

This folder contains scripts, either written by myself or properly attributed, that provide or modify existing modules of a given model.

@ -394,7 +401,7 @@ If I remember right, it just simply provides gradient checkpointing.

### `models/arch/mixtral.py`

Like `llama.py`, this provides modifications to Mixtral through `transformers`.
Like `llama.py`, this provides modifications to Mixtral through `transformers`. However, most of the niceties from `llama.py` are not available here, as it's not the core backend.

Primarily, this is to address a bug with batch sizes > 1, and to use a different attention mechanism.
* to-do: this is out of date from `llama.py`'s modified attention class.

@ -1,15 +1,22 @@

ifeq ($(PREFIX),)
PREFIX := /usr/local
endif

CXX = g++

INCS += -I./include
LIBS += -L./libs
LIBS += -L./lib

LINKS += -lggml -lggml-base -lllama -lencodec -lespeak-ng
FLAGS += -march=native -O3
FLAGS += -march=native -O3 -DVALL_E_EXPORTS

SRCS := $(shell find ./ -name "*.cpp")
OBJS += $(patsubst %.cpp,%.o,$(SRCS))

TARGET = vall_e
TARGET_LIB = lib$(TARGET).so
TARGET_HEADER = $(TARGET).h

%.o: %.cpp
	$(CXX) $(FLAGS) $(INCS) -c $< -o $@

@ -17,6 +24,19 @@ TARGET = vall_e

$(TARGET): $(OBJS)
	$(CXX) $(FLAGS) $(OBJS) $(LIBS) $(INCS) $(LINKS) -o $(TARGET)

$(TARGET_LIB): $(OBJS)
	$(CXX) $(FLAGS) $(OBJS) $(LIBS) $(INCS) $(LINKS) -o $(TARGET_LIB)

all: $(TARGET_LIB) $(TARGET)

lib: $(TARGET_LIB)

install:
	cp $(TARGET) $(PREFIX)/bin/$(TARGET)
	-cp $(TARGET_LIB) $(PREFIX)/lib/$(TARGET_LIB)
	cp $(TARGET_HEADER) $(PREFIX)/include/$(TARGET_HEADER)

clean:
	@-rm -f $(OBJS)
	@-rm -f $(TARGET)
	@-rm -f $(TARGET)
	@-rm -f $(TARGET_LIB)

@ -2,15 +2,15 @@

This is an implementation that makes use of [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [encodec.cpp](https://github.com/PABannier/encodec.cpp).

At the moment it's ***very*** work in progress.

Model weights can be found at [`ecker/vall-e@gguf`](https://huggingface.co/ecker/vall-e/tree/gguf).
Model weights can:
* be found at [`ecker/vall-e@gguf`](https://huggingface.co/ecker/vall-e/tree/gguf)
* be converted with `vall_e.export --yaml=./model_path/config.yaml --hf`, then by running `python3 /path/to/your/llama.cpp/convert_hf_to_gguf ./model_path/hf/`

## Build

Populate `./include/` with the `ggml`, `llama.cpp`, and `encodec.cpp` headers.

Populate `./libs/` with the compiled libraries of `llama.cpp`, `encodec.cpp`, and `espeak-ng`.
Populate `./lib/` with the compiled libraries of `llama.cpp`, `encodec.cpp`, and `espeak-ng` (if not already in your `LD_LIBRARY_PATH`).

Run `make`.

@ -23,7 +23,7 @@ Run `make`.

## To-Do

* [x] converted model to GGUF
  * [ ] convert it without modifying any of the existing code, as the tokenizer requires some care
  * [x] convert it without modifying any of the existing code, as the tokenizer requires some care
* [x] basic framework
  * [x] load the quantized model
  * [x] orchestrate the required embeddings

@ -45,6 +45,11 @@ Run `make`.

* [x] a functional CLI
* [x] actually make it work
* [x] clean up to make the code usable elsewhere
* [x] configured to allow for being used as a lib
  * (I do need to validate this in my engine project, but that's in MSYS2)
* [ ] feature parity with the PyTorch version
  * [ ] vocos
  * [ ] additional tasks (`stt`, `ns`, `sr`, samplers)
  * [ ] additional tasks
    * [ ] `stt`
    * [x] `ns` / `sr`
    * [ ] samplers

vall_e.cpp/vall_e-impl.h (new file): 93 lines
@ -0,0 +1,93 @@

#pragma once

// stores all the backend stuff

// external deps
#include <llama.h>
#include <encodec.h>
#include <dr_wav.h>
#include <espeak-ng/speak_lib.h>

#define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
#define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch (or using LLAMA arch)

#if !LLAMA_CPP_EXTENDED
	#include "llama_hack.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
#endif

// to-do: clean up spaghetti enums
const int EMBEDDING_MODE_PROM = 0;
const int EMBEDDING_MODE_RESP_AR_NAR = 1;
const int EMBEDDING_MODE_RESP_NAR_LEN = 2;

const int INFERENCE_MODE_LEN = 0;
const int INFERENCE_MODE_AR = 1;
const int INFERENCE_MODE_NAR_DEMASK = 2;
const int INFERENCE_MODE_NAR = 3;

// stores metadata for inputs/outputs
struct io_t {
	std::string name;
	uint32_t start;
	uint32_t end;
	int32_t head_idx = -1;

	int32_t n_embd = 0;
	int32_t n_vocab = 0;

	std::vector<float> embds = {};
	ggml_tensor* head = NULL;
};

// stores the mappings between tokens, input embeddings, and output heads
struct io_map_t {
	// model's original params
	int32_t n_embd = 0;
	int32_t n_vocab = 0;

	// mapping
	std::unordered_map<std::string, io_t> io = {};
	// context to store slices
	ggml_context* ctx = NULL;
};
// used for top-k (mainly for demasking)
struct score_t {
	int32_t idx;
	float value;

	bool operator<( const score_t& that ) const { return this->value < that.value; }
};
// handles storing metadata for token merges
struct merge_entry_t {
	std::u32string pre;
	std::u32string post;
	std::u32string resolved;

	token_t pre_token;
	token_t post_token;
	token_t resolved_token;
};

// helper tensor functions
std::vector<float> read_2d_tensor( struct ggml_tensor* tensor );
//ggml_tensor* view_2d_tensor( ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 ); // cringe method to keep in my pocket
ggml_tensor* view_2d_tensor( ggml_context* ctx, ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 );
void print_tokens( const std::vector<token_t>& tokens, const std::string& prefix = "Tokens: " );

std::vector<std::vector<float>> map_embeddings( const std::vector<token_t>& tokens, int n_embd, const float* embds );
std::vector<std::vector<float>> sum_embeddings( const vall_e_audio_codes_t& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM );
std::vector<float> soft_max( int n_logits, const float* logits );

// batch and inferencing
void batch_add( llama_batch& batch, token_t id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} );
void fill_batch( llama_batch& batch, vall_e_inputs_t& input, io_map_t& inputs_map, int mode );
std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& input, int max_tokens, int mode, bool verbose = true );

// (handles text)
std::vector<token_t> phonemize( vall_e_context_t* ctx, const std::string& text, const std::string& language = "auto" );

// model-accessing helpers
const io_t& vall_e_inputs_map_get_embeddings( io_map_t& inputs_map, const std::string& name );
const float* vall_e_inputs_map_get_embeddings_p( io_map_t& inputs_map, const std::string& name );
int32_t vall_e_inputs_map_get_classifier_idx( io_map_t& inputs_map, const std::string& name );
void vall_e_inputs_map_init( io_map_t&, llama_model* model );

@ -1,5 +1,6 @@
|
|||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "vall_e.h"
|
||||
#include "vall_e-impl.h" // stores everything that isn't necessary for exporting
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
|
@ -39,19 +40,43 @@ io_t io_ranges[] = {
|
|||
{ "resps|NAR:0:0", 16677, 17702, 8 },
|
||||
};
|
||||
|
||||
// stored here because I tokenize the merges
|
||||
// I can't be assed to figure out the tokenizer right now
|
||||
// lang map
|
||||
std::unordered_map<std::string, token_t> lang_map = {
|
||||
{ "en", 0 },
|
||||
{ "ja", 1 },
|
||||
{ "de", 2 },
|
||||
{ "fr", 3 },
|
||||
{ "zh", 4 },
|
||||
{ "ko", 5 },
|
||||
};
|
||||
std::unordered_map<std::string, token_t> task_map = {
|
||||
{ "tts", 0 },
|
||||
{ "tts-c", 1 },
|
||||
{ "ns", 2 },
|
||||
{ "sr", 3 },
|
||||
{ "tse", 4 },
|
||||
{ "soe", 5 },
|
||||
{ "mask", 6 },
|
||||
{ "eoe", 7 },
|
||||
{ "stt", 8 },
|
||||
|
||||
{ "len", 0 },
|
||||
{ "nse", 6 },
|
||||
{ "cse", 6 },
|
||||
};
|
||||
|
||||
// u32string because encoding agony
|
||||
std::unordered_map<std::u32string, token_t> vocab = {
|
||||
{U"<unk>",0},{U"<bos>",1},{U"</eos>",2},{U"<mask>",3},{U" ",4},{U"ᵝ",4},{U"!",5},{U"\"",6},{U"(",7},{U"{",7},{U"[",7},{U")",8},{U"}",8},{U"]",8},{U",",9},{U"-",10},{U".",11},{U"1",211},{U"—",10},{U"“",6},{U"”",81},{U"ˇ",6},{U"ˉ",12},{U"ˊ",79},{U"ˋ",80},{U"_",81},{U":",13},{U";",14},{U"?",15},{U"a",16},{U"ä",16},{U"ɒ",16},{U"b",17},{U"c",18},{U"d",19},{U"e",20},{U"f",21},{U"h",22},{U"i",23},{U"ĩ",23},{U"j",24},{U"k",25},{U"l",26},{U"m",27},{U"n",28},{U"ɴ",28},{U"ɲ",28},{U"o",29},{U"̞",29},{U"p",30},{U"ɸ",30},{U"q",31},{U"r",32},{U"ɽ",32},{U"ʁ",32},{U"s",33},{U"t",34},{U"u",35},{U"ø",35},{U"œ",35},{U"y",35},{U"ɣ",35},{U"ũ",35},{U"v",36},{U"w",37},{U"ʍ",37},{U"x",38},{U"z",39},{U"¡",40},{U"«",41},{U"»",42},{U"¿",43},{U"æ",44},{U"ç",45},{U"ð",46},{U"ŋ",47},{U"ɐ",48},{U"ɑ",49},{U"ɔ",50},{U"ɕ",51},{U"ə",52},{U"ɚ",53},{U"ɛ",54},{U"ɜ",55},{U"ɟ",56},{U"ɡ",57},{U"ɪ",58},{U"ɬ",59},{U"ɯ",60},{U"ɹ",61},{U"ɾ",62},{U"ʃ",63},{U"ʈ",64},{U"ʊ",65},{U"ʋ",66},{U"ʌ",67},{U"ʑ",68},{U"ʒ",69},{U"ʔ",70},{U"ʲ",71},{U"ˈ",72},{U"ˌ",73},{U"ː",74},{U"̃",75},{U"̩",76},{U"θ",77},{U"ᵻ",78},{U"…",82},{U"ˈɛ",83},{U"iː",84},{U"aɪ",85},{U"nd",86},{U"ˈɪ",87},{U"eɪ",88},{U"ˈæ",89},{U"ðə",90},{U"oʊ",91},{U"ɑː",92},{U"ˈeɪ",93},{U"ən",94},{U"uː",95},{U"ˈʌ",96},{U"ˈaɪ",97},{U"st",98},{U"ˈɔ",99},{U"ˈoʊ",100},{U"ˈiː",101},{U"ˈɑː",102},{U"ænd",103},{U"ːɹ",104},{U"ɪŋ",105},{U"ɜː",106},{U"ɪn",107},{U"tə",108},{U"ʌv",109},{U"aʊ",110},{U"əl",111},{U"ˈuː",112},{U"tʃ",113},{U"ɪz",114},{U"ˈɜː",115},{U"ˌʌ",116},{U"æt",117},{U"dʒ",118},{U"ˈɔː",119},{U"ɪt",120},{U"ˈaʊ",121},{U"ɚɹ",122},{U"ˈɛn",123},{U"wʌ",124},{U"li",125},{U"hiː",126},{U"ˌɛ",127},{U"wɪ",128},{U"wʌz",129},{U"ðæt",130},{U"juː",131},{U"oːɹ",132},{U"ðɪ",133},{U"sˈɛ",134},{U"ˌɪ",135},{U"ˈɑːɹ",136},{U"nt",137},{U"ˈʊ",138},{U"ənt",139},{U"hɪz",140},{U"ˌɑː",141},{U"hæ",142},{U"ɔːɹ",143},{U"ˈɛɹ",144},{U"wɪð",145},{U"ᵻd",146},{U"ˈoːɹ",147},{U"pɹ",148},{U"ˈɔːl",149},{U"mˌ",150},{U"ʃən",151},{U"kt",152},{U"ˌoʊ",153},{U"ˈɔːɹ",154},{U"fɹ",155},{U"æz",156},{U"ˌʌt",157},{U"ʃiː",158},{U"ˈɛl",159},{U"ˌaʊ",160},{U"ˈʌn",161},{U"əs",162},{U"hɜː",163},{U"lˈaɪ",164},{U"ˈæn",165},{U"ˈɪɹ",166},{U"ʊd",167},{U"ɹᵻ",168},{U"ld",169},{U"bˌʌt",170},{U"ks",171},{U"nˈoʊ",172},{U"hæd",173},{U"ɾɚ",174},{U"ɛɹ",175},{U"ˈɪŋ",176},{U"ɡɹ",177},{U"nˌɑː",178},{U"ɔn",179},{U"vɚ",180},{U"maɪ",181},{U"fɔːɹ",182},{U"ðɚ",183},{U"tʊ",184},{U"ðɛɹ",185},{U"nˌɑːt",186},{U"ˈʌm",187},{U"tɹ",188},{U"sˈiː",189},{U"ʌvðə",190},{U"mˈɪ",191},{U"hˈæ",192},{U"ˌɪm",193},{U"lˈeɪ",194},{U"ɪk",195},{U"sp",196},{U"hˌɪm",197},{U"ɐn",198},{U"ðeɪ",199},{U"lˈɪ",200},{U"ɾi",201},{U"lˈɛ",202},{U"bɹ",203},{U"kɹ",204},{U"lˈæ",205},{U"ˈɪl",206},{U"jˈuː",207},{U"ʌm",208},{U"mˌiː",209},{U"bᵻ",210},{U"wˈʌn",211},{U"ˌɪn",212},{U"ˈɪn",213},{U"ˈoʊn",214},{U"sˈɛd",215},{U"biː",216},{U"ˈɛd",217},{U"ˈaɪt",218},{U"baɪ",219},{U"fɹʌm",220},{U"ɪs",221},{U"ɚz",222},{U"ðɪs",223},{U"əns",224},{U"bəl",225},{U"ɪf",226},{U"ɪnðə",227},{U"əm",228},{U"ᵻz",229},{U"ˌuː",230},{U"wˈeɪ",231},{U"ft",232},{U"wiː",233},{U"stɹ",234},{U"lˈiː",235},{U"iːz",236},{U"pt",237},{U"jʊ",238},{U"ɚd",239},{U"ˌaɪ",240},{U"kw",241},{U"ˌɔn",242},{U"ˈaɪd",243},{U"ɪm",244},{U"ˈʌst",245},{U"ˈoʊld",246},{U"ts",247},{U"ˌɪtʃ",248},{U"sˌoʊ",249},{U"dˈɪ",250},{U"ɑːɹ",251},{U"hɐ",252},{U"sˈeɪ",253},{U"ɾᵻd",254},{U"wˌɪtʃ",255},
|
||||
};
|
||||
|
||||
// cringe list of merges to later process and fill out the map for referencing merges
|
||||
std::vector<merge_entry_t> vocab_merges = {
|
||||
{U"ˈ", U"ɛ"},{U"i", U"ː"},{U"a", U"ɪ"},{U"n", U"d"},{U"ˈ", U"ɪ"},{U"e", U"ɪ"},{U"ˈ", U"æ"},{U"ð", U"ə"},{U"o", U"ʊ"},{U"ɑ", U"ː"},{U"ˈ", U"eɪ"},{U"ə", U"n"},{U"u", U"ː"},{U"ˈ", U"ʌ"},{U"ˈ", U"aɪ"},{U"s", U"t"},{U"ˈ", U"ɔ"},{U"ˈ", U"oʊ"},{U"ˈ", U"iː"},{U"ˈ", U"ɑː"},{U"æ", U"nd"},{U"ː", U"ɹ"},{U"ɪ", U"ŋ"},{U"ɜ", U"ː"},{U"ɪ", U"n"},{U"t", U"ə"},{U"ʌ", U"v"},{U"a", U"ʊ"},{U"ə", U"l"},{U"ˈ", U"uː"},{U"t", U"ʃ"},{U"ɪ", U"z"},{U"ˈ", U"ɜː"},{U"ˌ", U"ʌ"},{U"æ", U"t"},{U"d", U"ʒ"},{U"ˈɔ", U"ː"},{U"ɪ", U"t"},{U"ˈ", U"aʊ"},{U"ɚ", U"ɹ"},{U"ˈɛ", U"n"},{U"w", U"ʌ"},{U"l", U"i"},{U"h", U"iː"},{U"ˌ", U"ɛ"},{U"w", U"ɪ"},{U"wʌ", U"z"},{U"ð", U"æt"},{U"j", U"uː"},{U"o", U"ːɹ"},{U"ð", U"ɪ"},{U"s", U"ˈɛ"},{U"ˌ", U"ɪ"},{U"ˈɑː", U"ɹ"},{U"n", U"t"},{U"ˈ", U"ʊ"},{U"ən", U"t"},{U"h", U"ɪz"},{U"ˌ", U"ɑː"},{U"h", U"æ"},{U"ɔ", U"ːɹ"},{U"ˈɛ", U"ɹ"},{U"wɪ", U"ð"},{U"ᵻ", U"d"},{U"ˈ", U"oːɹ"},{U"p", U"ɹ"},{U"ˈɔː", U"l"},{U"m", U"ˌ"},{U"ʃ", U"ən"},{U"k", U"t"},{U"ˌ", U"oʊ"},{U"ˈɔ", U"ːɹ"},{U"f", U"ɹ"},{U"æ", U"z"},{U"ˌʌ", U"t"},{U"ʃ", U"iː"},{U"ˈɛ", U"l"},{U"ˌ", U"aʊ"},{U"ˈʌ", U"n"},{U"ə", U"s"},{U"h", U"ɜː"},{U"l", U"ˈaɪ"},{U"ˈæ", U"n"},{U"ˈɪ", U"ɹ"},{U"ʊ", U"d"},{U"ɹ", U"ᵻ"},{U"l", U"d"},{U"b", U"ˌʌt"},{U"k", U"s"},{U"n", U"ˈoʊ"},{U"hæ", U"d"},{U"ɾ", U"ɚ"},{U"ɛ", U"ɹ"},{U"ˈɪ", U"ŋ"},{U"ɡ", U"ɹ"},{U"n", U"ˌɑː"},{U"ɔ", U"n"},{U"v", U"ɚ"},{U"m", U"aɪ"},{U"f", U"ɔːɹ"},{U"ð", U"ɚ"},{U"t", U"ʊ"},{U"ð", U"ɛɹ"},{U"nˌɑː", U"t"},{U"ˈʌ", U"m"},{U"t", U"ɹ"},{U"s", U"ˈiː"},{U"ʌv", U"ðə"},{U"m", U"ˈɪ"},{U"h", U"ˈæ"},{U"ˌɪ", U"m"},{U"l", U"ˈeɪ"},{U"ɪ", U"k"},{U"s", U"p"},{U"h", U"ˌɪm"},{U"ɐ", U"n"},{U"ð", U"eɪ"},{U"l", U"ˈɪ"},{U"ɾ", U"i"},{U"l", U"ˈɛ"},{U"b", U"ɹ"},{U"k", U"ɹ"},{U"l", U"ˈæ"},{U"ˈɪ", U"l"},{U"j", U"ˈuː"},{U"ʌ", U"m"},{U"mˌ", U"iː"},{U"b", U"ᵻ"},{U"w", U"ˈʌn"},{U"ˌ", U"ɪn"},{U"ˈɪ", U"n"},{U"ˈoʊ", U"n"},{U"sˈɛ", U"d"},{U"b", U"iː"},{U"ˈɛ", U"d"},{U"ˈaɪ", U"t"},{U"b", U"aɪ"},{U"fɹ", U"ʌm"},{U"ɪ", U"s"},{U"ɚ", U"z"},{U"ðɪ", U"s"},{U"ən", U"s"},{U"b", U"əl"},{U"ɪ", U"f"},{U"ɪn", U"ðə"},{U"ə", U"m"},{U"ᵻ", U"z"},{U"ˌ", U"uː"},{U"w", U"ˈeɪ"},{U"f", U"t"},{U"w", U"iː"},{U"st", U"ɹ"},{U"l", U"ˈiː"},{U"iː", U"z"},{U"p", U"t"},{U"j", U"ʊ"},{U"ɚ", U"d"},{U"ˌ", U"aɪ"},{U"k", U"w"},{U"ˌ", U"ɔn"},{U"ˈaɪ", U"d"},{U"ɪ", U"m"},{U"ˈʌ", U"st"},{U"ˈoʊ", U"ld"},{U"t", U"s"},{U"ˌɪ", U"tʃ"},{U"s", U"ˌoʊ"},{U"d", U"ˈɪ"},{U"ɑː", U"ɹ"},{U"h", U"ɐ"},{U"s", U"ˈeɪ"},{U"ɾ", U"ᵻd"},{U"w", U"ˌɪtʃ"},
|
||||
};
|
||||
// merge map to reference when tokenizing text
|
||||
std::unordered_map<std::string, merge_entry_t> vocab_merge_map = {};
|
||||
|
||||
std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor ) {
|
||||
std::vector<float> read_2d_tensor( struct ggml_tensor* tensor ) {
|
||||
size_t size = tensor->ne[0] * tensor->ne[1];
|
||||
std::vector<float> res( size );
|
||||
|
||||
|
@ -65,7 +90,7 @@ std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor ) {
|
|||
return res;
|
||||
}
|
||||
/*
|
||||
ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim ) {
|
||||
ggml_tensor* view_2d_tensor( struct ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim ) {
|
||||
// to-do: implement other dim
|
||||
if ( start < 0 ) start = tensor->ne[1] + start;
|
||||
if ( end < 0 ) end = tensor->ne[1] + end;
|
||||
|
@ -86,7 +111,7 @@ ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_tensor* tensor, int32_t star
|
|||
return res;
|
||||
}
|
||||
*/
|
||||
ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_context* ctx, struct ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim ) {
|
||||
ggml_tensor* view_2d_tensor( struct ggml_context* ctx, struct ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim ) {
|
||||
// to-do: implement other dim
|
||||
if ( start < 0 ) start = tensor->ne[1] + start;
|
||||
if ( end < 0 ) end = tensor->ne[1] + end;
|
||||
|
@ -96,7 +121,7 @@ ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_context* ctx, struct ggml_te
|
|||
return res;
|
||||
}
|
||||
|
||||
void VALL_E_API print_tokens( const std::vector<token_t>& tokens, const std::string& prefix ) {
|
||||
void print_tokens( const std::vector<token_t>& tokens, const std::string& prefix ) {
|
||||
printf("%s[", prefix.c_str());
|
||||
for ( auto i = 0; i < tokens.size(); ++i ) {
|
||||
printf("%i%s", tokens[i], i + 1 < tokens.size() ? ", " : "");
|
||||
|
@ -104,18 +129,18 @@ void VALL_E_API print_tokens( const std::vector<token_t>& tokens, const std::str
|
|||
printf("]\n");
|
||||
}
|
||||
|
||||
const io_t& VALL_E_API vall_e_inputs_map_get( io_map_t& io_map, const std::string& name ) {
|
||||
const io_t& vall_e_inputs_map_get( io_map_t& io_map, const std::string& name ) {
|
||||
return io_map.io[name];
|
||||
}
|
||||
const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( io_map_t& io_map, const std::string& name ) {
|
||||
const float* vall_e_inputs_map_get_embeddings_p( io_map_t& io_map, const std::string& name ) {
|
||||
return io_map.io[name].embds.data();
|
||||
}
|
||||
|
||||
int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( io_map_t& io_map, const std::string& name ) {
|
||||
int32_t vall_e_inputs_map_get_classifier_idx( io_map_t& io_map, const std::string& name ) {
|
||||
return io_map.io[name].head_idx;
|
||||
}
|
||||
|
||||
void VALL_E_API vall_e_inputs_map_init( io_map_t& io_map, llama_model* model ) {
|
||||
void vall_e_inputs_map_init( io_map_t& io_map, llama_model* model ) {
|
||||
auto n_embd = llama_n_embd( model );
|
||||
auto n_vocab = llama_n_vocab( model );
|
||||
|
||||
|
@ -187,7 +212,7 @@ void VALL_E_API vall_e_inputs_map_init( io_map_t& io_map, llama_model* model ) {
|
|||
}
|
||||
|
||||
// maps embeddings easily
|
||||
std::vector<std::vector<float>> VALL_E_API map_embeddings( const std::vector<token_t>& tokens, int n_embd, const float* embds ) {
|
||||
std::vector<std::vector<float>> map_embeddings( const std::vector<token_t>& tokens, int n_embd, const float* embds ) {
|
||||
std::vector<std::vector<float>> embedded( tokens.size() );
|
||||
for ( auto i = 0; i < tokens.size(); ++i ) {
|
||||
embedded[i].insert( embedded[i].end(), embds + (tokens[i] * n_embd), embds + ((tokens[i]+1) * n_embd) );
|
||||
|
@ -197,7 +222,7 @@ std::vector<std::vector<float>> VALL_E_API map_embeddings( const std::vector<tok
|
|||
|
||||
// handles adding either a token OR the embedding of that token into the batch
|
||||
// this really, really helps avoid needing to abuse the tokenizer
|
||||
void VALL_E_API batch_add( llama_batch& batch, token_t id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids ) {
|
||||
void batch_add( llama_batch& batch, token_t id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids ) {
|
||||
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
|
||||
|
||||
// insert raw embedding instead
|
||||
|
@ -219,7 +244,7 @@ void VALL_E_API batch_add( llama_batch& batch, token_t id, int n_embd, const flo
|
|||
batch.n_tokens++;
|
||||
}
|
||||
// reads a waveform from disk
|
||||
std::vector<float> VALL_E_API read_audio_from_disk( const std::string& path ) {
|
||||
std::vector<float> read_audio_from_disk( const std::string& path ) {
|
||||
std::vector<float> res;
|
||||
|
||||
uint32_t channels;
|
||||
|
@ -248,7 +273,7 @@ std::vector<float> VALL_E_API read_audio_from_disk( const std::string& path ) {
|
|||
return res;
|
||||
}
|
||||
// writes a waveform to disk
|
||||
void VALL_E_API write_audio_to_disk( const std::vector<float>& wavform, const std::string& path ) {
|
||||
void write_audio_to_disk( const std::vector<float>& wavform, const std::string& path ) {
|
||||
drwav_data_format format;
|
||||
format.bitsPerSample = 32;
|
||||
format.sampleRate = 24000;
|
||||
|
@ -264,7 +289,7 @@ void VALL_E_API write_audio_to_disk( const std::vector<float>& wavform, const st
|
|||
fprintf(stderr, "%s: Number of frames written = %lld.\n", __func__, frames);
|
||||
}
|
||||
// reads a waveform from disk then encodes it
|
||||
std::vector<std::vector<int32_t>> VALL_E_API encode_audio( struct encodec_context* ectx, const std::vector<float>& wavform ) {
|
||||
std::vector<std::vector<int32_t>> encode_audio( struct encodec_context* ectx, const std::vector<float>& wavform ) {
|
||||
// compress audio
|
||||
if (!encodec_compress_audio(ectx, wavform.data(), wavform.size(), 1)) {
|
||||
fprintf(stderr, "%s: error during compression \n", __func__);
|
||||
|
@ -285,7 +310,7 @@ std::vector<std::vector<int32_t>> VALL_E_API encode_audio( struct encodec_contex
|
|||
return res;
|
||||
}
|
||||
// decodes a 2D codebook into a waveform
|
||||
std::vector<float> VALL_E_API decode_audio( struct encodec_context* ectx, const std::vector<std::vector<int32_t>>& codes ) {
|
||||
std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector<std::vector<int32_t>>& codes ) {
|
||||
int n_codebooks = codes.size();
|
||||
int n_frames = codes[0].size();
|
||||
|
||||
|
@ -310,7 +335,7 @@ std::vector<float> VALL_E_API decode_audio( struct encodec_context* ectx, const
|
|||
}
|
||||
|
||||
// sums embeddings over a 2D "tensor"
|
||||
std::vector<std::vector<float>> VALL_E_API sum_embeddings( const std::vector<std::vector<token_t>>& inputs, int n_embd, int rvq_l, const float** embds, int mode ) {
|
||||
std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<token_t>>& inputs, int n_embd, int rvq_l, const float** embds, int mode ) {
|
||||
auto n_tokens = inputs[0].size();
|
||||
|
||||
std::vector<std::vector<float>> res( n_tokens, std::vector<float>( n_embd, 0.0 ) );
|
||||
|
@ -336,7 +361,7 @@ std::vector<std::vector<float>> VALL_E_API sum_embeddings( const std::vector<std
|
|||
return res;
|
||||
}
|
||||
|
||||
std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits ) {
|
||||
std::vector<float> soft_max( int n_logits, const float* logits ) {
|
||||
std::vector<float> res( n_logits, 0.0f );
|
||||
std::vector<float> expd( n_logits, 0.0f );
|
||||
float denom = 0.0f;
|
||||
|
@ -353,7 +378,7 @@ std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits ) {
|
|||
return res;
|
||||
}
|
||||
|
||||
std::vector<float> VALL_E_API log_soft_max( int n_logits, const float* logits ) {
|
||||
std::vector<float> log_soft_max( int n_logits, const float* logits ) {
|
||||
std::vector<float> res( n_logits, 0.0f );
|
||||
float denom = 0.0f;
|
||||
|
||||
|
@ -368,7 +393,7 @@ std::vector<float> VALL_E_API log_soft_max( int n_logits, const float* logits )
|
|||
return res;
|
||||
}
|
||||
|
||||
void VALL_E_API fill_batch( llama_batch& batch, vall_e_inputs_t& inputs, io_map_t& io_map, int mode ) {
|
||||
void fill_batch( llama_batch& batch, vall_e_inputs_t& inputs, io_map_t& io_map, int mode ) {
|
||||
// keeps track of the position for each sequence
|
||||
size_t pos = 0;
|
||||
auto n_embd = io_map.n_embd;
|
||||
|
@ -402,18 +427,25 @@ void VALL_E_API fill_batch( llama_batch& batch, vall_e_inputs_t& inputs, io_map_
|
|||
vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:0:0"),
|
||||
};
|
||||
|
||||
token_t lang_token = lang_map[inputs.lang];
|
||||
token_t task_token = task_map[inputs.task];
|
||||
|
||||
// insert text tokens
|
||||
for ( auto& id : inputs.phn ) batch_add( batch, id, n_embd, text_embds, pos++, false );
|
||||
batch_add( batch, 0, n_embd, sep_embds, pos++, false );
|
||||
pos = 0;
|
||||
// insert lang token
|
||||
batch_add( batch, inputs.lang, n_embd, lang_embds, pos++, false );
|
||||
batch_add( batch, lang_token, n_embd, lang_embds, pos++, false );
|
||||
batch_add( batch, 0, n_embd, sep_embds, pos++, false );
|
||||
pos = 0;
|
||||
// insert rvq level token
|
||||
batch_add( batch, inputs.rvq_l, n_embd, rvq_l_embds, pos++, false );
|
||||
batch_add( batch, 0, n_embd, sep_embds, pos++, false );
|
||||
pos = 0;
|
||||
// input task token if needed
|
||||
if ( task_token > 0 ) {
|
||||
batch_add( batch, task_token, n_embd, task_embds, pos++, false );
|
||||
}
|
||||
// insert prom tokens
|
||||
auto summed_proms_embds = sum_embeddings( inputs.prom, n_embd, inputs.rvq_l, prom_embds );
|
||||
for ( auto i = 0; i < summed_proms_embds.size(); ++i ) {
|
||||
|
@ -439,13 +471,13 @@ void VALL_E_API fill_batch( llama_batch& batch, vall_e_inputs_t& inputs, io_map_
|
|||
}
|
||||
|
||||
// generation code, should handle all modalities easily
|
||||
std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int max_tokens, int mode, bool verbose ) {
|
||||
std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int max_tokens, int mode, bool verbose ) {
|
||||
bool causal = true; // sample autoregressively or not
|
||||
int n_outputs = 0; // number of output tokens to expect
|
||||
|
||||
// create batch (targetting embeddings instead of tokens)
|
||||
llama_batch batch = llama_batch_init( ctx->params.ctx_size, ctx->io_map.n_embd, ctx->params.ctx_size );
|
||||
fill_batch( batch, inputs, ctx->io_map, mode );
|
||||
llama_batch batch = llama_batch_init( ctx->params.ctx_size, ctx->io_map->n_embd, ctx->params.ctx_size );
|
||||
fill_batch( batch, inputs, *ctx->io_map, mode );
|
||||
|
||||
// determine how many outputs we need
|
||||
for ( auto i = 0; i < batch.n_tokens; ++i ) {
|
||||
|
@ -485,7 +517,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
|
|||
embd_name = "resps|NAR:0:0";
|
||||
}
|
||||
|
||||
auto& io = vall_e_inputs_map_get(ctx->io_map, embd_name);
|
||||
auto& io = vall_e_inputs_map_get(*ctx->io_map, embd_name);
|
||||
const float* embds = io.embds.data();
|
||||
|
||||
int32_t n_embd = io.n_embd;
|
||||
|
@ -535,7 +567,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
|
|||
// store token
|
||||
output_tokens.emplace_back(t);
|
||||
// update batch with token
|
||||
batch_add( batch, t, ctx->io_map.n_embd, embds, output_tokens.size(), true );
|
||||
batch_add( batch, t, ctx->io_map->n_embd, embds, output_tokens.size(), true );
|
||||
|
||||
if ( verbose ) print_tokens( output_tokens );
|
||||
}
|
||||
|
@ -560,7 +592,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
|
|||
null_input.phn = {1, 2}; // <bos></eos>
|
||||
null_input.resp.resize(1);
|
||||
|
||||
llama_batch null_batch = llama_batch_init( ctx->params.ctx_size, ctx->io_map.n_embd, ctx->params.ctx_size );
|
||||
llama_batch null_batch = llama_batch_init( ctx->params.ctx_size, ctx->io_map->n_embd, ctx->params.ctx_size );
|
||||
|
||||
// token scores to reference for masking
|
||||
std::vector<float> scores(n_outputs, 1.0);
|
||||
|
@ -607,11 +639,11 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
|
|||
// to-do: only update the embeddings instead
|
||||
batch.n_tokens = 0;
|
||||
inputs.resp[0] = output_tokens;
|
||||
fill_batch( batch, inputs, ctx->io_map, mode );
|
||||
fill_batch( batch, inputs, *ctx->io_map, mode );
|
||||
// update null batch
|
||||
null_input.resp[0] = output_tokens;
|
||||
null_batch.n_tokens = 0;
|
||||
fill_batch( null_batch, inputs, ctx->io_map, mode );
|
||||
fill_batch( null_batch, inputs, *ctx->io_map, mode );
|
||||
|
||||
// cfg decode
|
||||
if ( llama_decode(ctx->llama.ctx, null_batch) ) {
|
||||
|
@ -717,7 +749,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
|
|||
return output_tokens;
|
||||
}
|
||||
|
||||
std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::string& text, const std::string& language ) {
|
||||
std::vector<token_t> phonemize( vall_e_context_t* ctx, const std::string& text, const std::string& language ) {
|
||||
std::vector<token_t> tokens;
|
||||
|
||||
// phonemize text
|
||||
|
@ -767,6 +799,7 @@ std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::str
|
|||
}
|
||||
tokens.emplace_back(2);
|
||||
|
||||
if ( ctx->params.verbose ) print_tokens( tokens, "Phonemes: " );
|
||||
|
||||
/*
|
||||
// to-do: fix terminate called after throwing an instance of 'std::out_of_range'
|
||||
|
@ -782,7 +815,7 @@ std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::str
|
|||
return tokens;
|
||||
}
|
||||
|
||||
void VALL_E_API vall_e_print_usage( char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
|
||||
void vall_e_print_usage( char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
|
@ -803,6 +836,8 @@ void VALL_E_API vall_e_print_usage( char** argv, vall_e_context_params_t& params
|
|||
fprintf(stderr, " Input text prompt (default: %s)\n", args.text.c_str());
|
||||
fprintf(stderr, " -l TEXT, --language TEXT\n");
|
||||
fprintf(stderr, " Language for input text / output response (default: %s)\n", args.language.c_str());
|
||||
fprintf(stderr, " -ts TASK, --task TASK\n");
|
||||
fprintf(stderr, " Inferencing task (default: %s, accepts ['tts', 'stt', 'ns', 'sr'])\n", args.task);
|
||||
fprintf(stderr, " -mode MODE, --modality MODE\n");
|
||||
fprintf(stderr, " Modality for inferencing (default: %s, accepts ['ar+nar', 'nar-len'])\n", args.modality == MODALITY_NAR_LEN ? "nar-len" : "ar+nar");
|
||||
fprintf(stderr, " -ms N, --max-steps N\n");
|
||||
|
@ -815,7 +850,7 @@ void VALL_E_API vall_e_print_usage( char** argv, vall_e_context_params_t& params
|
|||
fprintf(stderr, " Output audio wav (default: %s)\n", args.output_path.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
|
||||
bool vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
|
||||
for ( int i = 1; i < argc; i++ ) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
|
@ -835,6 +870,8 @@ bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_
|
|||
args.text = argv[++i];
|
||||
} else if (arg == "-l" || arg == "--language") {
|
||||
args.language = argv[++i];
|
||||
} else if (arg == "-ts" || arg == "--task") {
|
||||
args.task = argv[++i];
|
||||
} else if (arg == "-mode" || arg == "--modality") {
|
||||
args.modality = argv[++i] == "ar+nar" ? MODALITY_AR_NAR : MODALITY_NAR_LEN;
|
||||
} else if (arg == "-ms" || arg == "--max-steps") {
|
||||
|
@ -859,8 +896,9 @@ bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_
|
|||
return true;
|
||||
}
|
||||
|
||||
vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params ) {
|
||||
vall_e_context_t* vall_e_load( const vall_e_context_params_t& params ) {
|
||||
vall_e_context_t* ctx = new vall_e_context_t();
|
||||
ctx->io_map = new io_map_t();
|
||||
ctx->params = params;
|
||||
|
||||
// setup ggml
|
||||
|
@ -905,7 +943,7 @@ vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params
|
|||
espeak_Initialize(AUDIO_OUTPUT_SYNCHRONOUS, 0, NULL, 0);
|
||||
|
||||
// setup vall_e.cpp
|
||||
vall_e_inputs_map_init( ctx->io_map, ctx->llama.model );
|
||||
vall_e_inputs_map_init( *ctx->io_map, ctx->llama.model );
|
||||
|
||||
// setup vocab things
|
||||
for ( auto& entry : vocab_merges ) {
|
||||
|
@ -921,17 +959,15 @@ vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params
|
|||
|
||||
return ctx;
|
||||
}
|
||||
vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& language ) {
|
||||
vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& language, const std::string& task ) {
|
||||
// to-do: set members in initializer rather than in post
|
||||
vall_e_inputs_t inputs;
|
||||
|
||||
inputs.task = task;
|
||||
inputs.rvq_l = 0;
|
||||
inputs.phn = phonemize( ctx, text, language );
|
||||
inputs.prom = encode_audio( ctx->encodec.ctx, read_audio_from_disk( prompt_path ) );
|
||||
if ( language == "en" ) inputs.lang = 0;
|
||||
else if ( language == "ja" ) inputs.lang = 1;
|
||||
else if ( language == "de" ) inputs.lang = 2;
|
||||
else if ( language == "fr" ) inputs.lang = 3;
|
||||
else if ( language == "zh" ) inputs.lang = 4;
|
||||
else if ( language == "ko" ) inputs.lang = 5;
|
||||
inputs.lang = language;
|
||||
|
||||
return inputs;
|
||||
}
|
||||
|
@ -943,17 +979,20 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
|
|||
// inference len
|
||||
int len = 0;
|
||||
if ( !len ) {
|
||||
auto task = inputs.task;
|
||||
inputs.task = "len";
|
||||
output_tokens = generate( ctx, inputs, 5, INFERENCE_MODE_LEN, ctx->params.verbose );
|
||||
{
|
||||
// to-do: one liner this
|
||||
int digit = 1;
|
||||
for (auto it = output_tokens.rbegin(); it < output_tokens.rend(); ++it) {
|
||||
len += (*it) * digit;
|
||||
digit *= 10;
|
||||
}
|
||||
}
|
||||
// cap for now
|
||||
// cap duration
|
||||
if ( len <= 0 || len > max_duration ) len = max_duration;
|
||||
inputs.task = task;
|
||||
}
|
||||
// fill with mask tokens
|
||||
inputs.resp.resize(1);
|
||||
|
@ -962,7 +1001,6 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
|
|||
}
|
||||
|
||||
// inference NAR-len 0
|
||||
inputs.task = "tts";
|
||||
for ( auto l = 0; l < 8; ++l ) {
|
||||
inputs.rvq_l = l;
|
||||
output_tokens = generate( ctx, inputs, max_steps, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR, ctx->params.verbose );
|
||||
|
@ -971,7 +1009,6 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
|
|||
}
|
||||
// AR+NAR
|
||||
} else if ( modality == MODALITY_AR_NAR ){
|
||||
inputs.task = "tts";
|
||||
for ( auto l = 0; l < 8; ++l ) {
|
||||
inputs.rvq_l = l;
|
||||
output_tokens = generate( ctx, inputs, l == 0 ? max_duration : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR, ctx->params.verbose );
|
||||
|
@ -981,18 +1018,17 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
|
|||
|
||||
return inputs.resp;
|
||||
}
|
||||
void VALL_E_API vall_e_free( vall_e_context_t* ctx ) {
|
||||
void vall_e_free( vall_e_context_t* ctx ) {
|
||||
espeak_Terminate();
|
||||
encodec_free(ctx->encodec.ctx);
|
||||
llama_free(ctx->llama.ctx);
|
||||
llama_free_model(ctx->llama.model);
|
||||
ggml_free(ctx->io_map.ctx);
|
||||
ggml_free(ctx->io_map->ctx);
|
||||
delete ctx->io_map;
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
int main( int argc, char** argv ) {
|
||||
// to-do: parse CLI args
|
||||
|
||||
vall_e_context_params_t params;
|
||||
vall_e_args_t args;
|
||||
|
||||
|
|
|
@ -5,34 +5,100 @@
|
|||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
// external deps
|
||||
#include <llama.h>
|
||||
#include <encodec.h>
|
||||
#include <dr_wav.h>
|
||||
#include <espeak-ng/speak_lib.h>
|
||||
|
||||
// to-do: copy over import/export stuff from engine project (because I don't remember how I set it up in <uf/config.h>)
|
||||
#define VALL_E_API
|
||||
// handles defining platform specific macros and import/export decorators (copied from my engine's uf/config.h)
|
||||
#if defined(_WIN32) || defined(__WIN32__) || defined(__CYGWIN__)
|
||||
// Windows
|
||||
#define VALL_E_ENV "Windows"
|
||||
#define VALL_E_ENV_WINDOWS 1
|
||||
#define VALL_E_ENV_HEADER "windows.h"
|
||||
#if defined(__CYGWIN__)
|
||||
#define to_string(var) string(var)
|
||||
#endif
|
||||
#ifndef _WIN32_WINNT
|
||||
#define _WIN32_WINNT 0x0600
|
||||
#endif
|
||||
#ifndef WINVER
|
||||
#define WINVER 0x0600
|
||||
#endif
|
||||
|
||||
#define VALL_E_IO_ROOT "./data/"
|
||||
#elif defined(linux) || defined(__linux)
|
||||
// Linux
|
||||
#define VALL_E_ENV "Linux"
|
||||
#define VALL_E_ENV_LINUX 1
|
||||
#define VALL_E_ENV_HEADER "linux.h"
|
||||
|
||||
#define VALL_E_IO_ROOT "./data/"
|
||||
#elif defined(__APPLE__) || defined(MACOSX) || defined(macintosh) || defined(Macintosh)
|
||||
// MacOS
|
||||
#define VALL_E_ENV "OSX"
|
||||
#define VALL_E_ENV_OSX 1
|
||||
#define VALL_E_ENV_HEADER "osx.h"
|
||||
|
||||
#define VALL_E_IO_ROOT "./data/"
|
||||
#elif defined(__FreeBSD__) || defined(__FreeBSD_kernel__)
|
||||
// FreeBSD
|
||||
#define VALL_E_ENV "FreeBSD"
|
||||
#define VALL_E_ENV_FREEBSD 1
|
||||
#define VALL_E_ENV_HEADER "freebsd.h"
|
||||
|
||||
#define VALL_E_IO_ROOT "./data/"
|
||||
#elif defined(__sh__)
|
||||
// Dreamcast
|
||||
#define VALL_E_ENV "Dreamcast"
|
||||
#define VALL_E_ENV_DREAMCAST 1
|
||||
#define VALL_E_ENV_HEADER "dreamcast.h"
|
||||
#include VALL_E_ENV_HEADER
|
||||
|
||||
#define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
|
||||
#define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch (or using LLAMA arch)
|
||||
#define _arch_dreamcast
|
||||
|
||||
#if !LLAMA_CPP_EXTENDED
|
||||
#include "llama_hack.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
|
||||
#define VALL_E_IO_ROOT "/cd/"
|
||||
#else
|
||||
// Unsupported system
|
||||
#define VALL_E_ENV "Unknown"
|
||||
#define VALL_E_ENV_UNKNOWN 1
|
||||
#define VALL_E_ENV_HEADER "unknown.h"
|
||||
#warning Using "unknown"
|
||||
#error No support
|
||||
#endif
|
||||
|
||||
// to-do: clean up spaghetti enums
|
||||
const int EMBEDDING_MODE_PROM = 0;
|
||||
const int EMBEDDING_MODE_RESP_AR_NAR = 1;
|
||||
const int EMBEDDING_MODE_RESP_NAR_LEN = 2;
|
||||
#if !defined(VALL_E_STATIC)
|
||||
#if defined(VALL_E_ENV_WINDOWS)
|
||||
// Windows compilers need specific (and different) keywords for export and import
|
||||
#define VALL_E_API_EXPORT __declspec(dllexport)
|
||||
#define VALL_E_API_IMPORT __declspec(dllimport)
|
||||
// For Visual C++ compilers, we also need to turn off this annoying C4251 warning
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(disable : 4251)
|
||||
#endif
|
||||
#else // Linux, FreeBSD, Mac OS X
|
||||
#if __GNUC__ >= 4
|
||||
// GCC 4 has special keywords for showing/hiding symbols,
|
||||
// the same keyword is used for both importing and exporting
|
||||
#define VALL_E_API_EXPORT __attribute__ ((__visibility__ ("default")))
|
||||
#define VALL_E_API_IMPORT __attribute__ ((__visibility__ ("default")))
|
||||
#else
|
||||
// GCC < 4 has no mechanism to explicitly hide symbols, everything's exported
|
||||
#define VALL_E_API_EXPORT
|
||||
#define VALL_E_API_IMPORT
|
||||
#endif
|
||||
#endif
|
||||
#else
|
||||
// Static build doesn't need import/export macros
|
||||
#define VALL_E_API_EXPORT
|
||||
#define VALL_E_API_IMPORT
|
||||
#endif
|
||||
|
||||
const int INFERENCE_MODE_LEN = 0;
|
||||
const int INFERENCE_MODE_AR = 1;
|
||||
const int INFERENCE_MODE_NAR_DEMASK = 2;
|
||||
const int INFERENCE_MODE_NAR = 3;
|
||||
#ifdef VALL_E_EXPORTS
|
||||
#define VALL_E_API VALL_E_API_EXPORT
|
||||
#else
|
||||
#define VALL_E_API VALL_E_API_IMPORT
|
||||
#endif
|
||||
|
||||
const int MODALITY_AR_NAR = 0;
|
||||
const int MODALITY_NAR_LEN = 1;
|
||||
typedef llama_token token_t;
|
||||
typedef std::vector<std::vector<token_t>> vall_e_audio_codes_t;
|
||||
|
||||
const int ENCODEC_FRAMES_PER_SECOND = 75;
|
||||
const int MAX_DURATION = ENCODEC_FRAMES_PER_SECOND * 12;
|
||||
|
@ -40,52 +106,16 @@ const int CTX_SIZE = 2048;
|
|||
const int N_THREADS = 8;
|
||||
const int N_GPU_LAYERS = 99;
|
||||
|
||||
typedef llama_token token_t;
|
||||
typedef std::vector<std::vector<token_t>> vall_e_audio_codes_t;
|
||||
const int MODALITY_AR_NAR = 0;
|
||||
const int MODALITY_NAR_LEN = 1;
|
||||
|
||||
// stores embeddings + metadata for an embedding range
|
||||
struct io_t {
|
||||
std::string name;
|
||||
uint32_t start;
|
||||
uint32_t end;
|
||||
int32_t head_idx = -1;
|
||||
|
||||
int32_t n_embd = 0;
|
||||
int32_t n_vocab = 0;
|
||||
|
||||
std::vector<float> embds = {};
|
||||
ggml_tensor* head = NULL;
|
||||
};
|
||||
|
||||
// stores the mappings between tokens, input embeddings, and output heads
|
||||
struct io_map_t {
|
||||
// model's original params
|
||||
int32_t n_embd = 0;
|
||||
int32_t n_vocab = 0;
|
||||
|
||||
// mapping
|
||||
std::unordered_map<std::string, io_t> io = {};
|
||||
// context to store slices
|
||||
ggml_context* ctx = NULL;
|
||||
};
|
||||
|
||||
struct score_t {
|
||||
int32_t idx;
|
||||
float value;
|
||||
|
||||
bool operator<( const score_t& that ) const { return this->value < that.value; }
|
||||
};
|
||||
|
||||
struct merge_entry_t {
|
||||
std::u32string pre;
|
||||
std::u32string post;
|
||||
std::u32string resolved;
|
||||
|
||||
token_t pre_token;
|
||||
token_t post_token;
|
||||
token_t resolved_token;
|
||||
};
|
||||
// forward declarations
|
||||
struct io_map_t;
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
struct encodec_context;
|
||||
|
||||
// model-specific parameters
|
||||
struct vall_e_context_params_t {
|
||||
std::string model_path = "./data/vall_e.gguf";
|
||||
std::string encodec_path = "./data/encodec.bin";
|
||||
|
@ -94,20 +124,22 @@ struct vall_e_context_params_t {
|
|||
int32_t ctx_size = CTX_SIZE;
|
||||
bool verbose = false;
|
||||
};
|
||||
// inference-specific arguments
|
||||
struct vall_e_args_t {
|
||||
std::string text = "Hello world.";
|
||||
std::string prompt_path = "./data/prom.wav";
|
||||
std::string output_path = "./data/resp.wav";
|
||||
std::string language = "en";
|
||||
std::string task = "tts";
|
||||
int modality = MODALITY_NAR_LEN;
|
||||
int max_steps = 30;
|
||||
int max_duration = MAX_DURATION;
|
||||
};
|
||||
// stores everything needed for vall_e.cpp
|
||||
// stores everything needed for vall_e.cpp at runtime
|
||||
struct vall_e_context_t {
|
||||
vall_e_context_params_t params;
|
||||
|
||||
io_map_t io_map;
|
||||
io_map_t* io_map = NULL; // pointer for reasons
|
||||
|
||||
struct {
|
||||
llama_model* model = NULL;
|
||||
|
@ -121,49 +153,26 @@ struct vall_e_context_t {
|
|||
// stores the raw inputs to be fed
|
||||
struct vall_e_inputs_t {
|
||||
std::string task = "tts";
|
||||
std::string lang = "en";
|
||||
|
||||
token_t rvq_l = 0;
|
||||
|
||||
std::vector<token_t> phn = {};
|
||||
token_t lang = 0;
|
||||
token_t rvq_l = 0;
|
||||
vall_e_audio_codes_t prom = {};
|
||||
vall_e_audio_codes_t resp = {};
|
||||
};
|
||||
|
||||
// helper tensor functions
|
||||
std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor );
|
||||
//ggml_tensor* VALL_E_API view_2d_tensor( ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 ); // cringe method to keep in my pocket
|
||||
ggml_tensor* VALL_E_API view_2d_tensor( ggml_context* ctx, ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 );
|
||||
void VALL_E_API print_tokens( const std::vector<token_t>& tokens, const std::string& prefix = "Tokens: " );
|
||||
|
||||
std::vector<std::vector<float>> VALL_E_API map_embeddings( const std::vector<token_t>& tokens, int n_embd, const float* embds );
|
||||
std::vector<std::vector<float>> VALL_E_API sum_embeddings( const vall_e_audio_codes_t& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM );
|
||||
std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits );
|
||||
|
||||
// batch and inferencing
|
||||
void VALL_E_API batch_add( llama_batch& batch, token_t id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} );
|
||||
void VALL_E_API fill_batch( llama_batch& batch, vall_e_inputs_t& input, io_map_t& inputs_map, int mode );
|
||||
std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t& input, int max_tokens, int mode, bool verbose = true );
|
||||
|
||||
//
|
||||
std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::string& text, const std::string& language = "auto" );
|
||||
|
||||
// encodec helpers
|
||||
std::vector<float> VALL_E_API read_audio_from_disk( const std::string& path );
|
||||
void VALL_E_API write_audio_to_disk( const std::vector<float>& waveform, const std::string& path );
|
||||
VALL_E_API std::vector<float> read_audio_from_disk( const std::string& path );
|
||||
VALL_E_API void write_audio_to_disk( const std::vector<float>& waveform, const std::string& path );
|
||||
|
||||
std::vector<std::vector<int32_t>> VALL_E_API encode_audio( struct encodec_context* ectx, const std::vector<float>& waveform );
|
||||
std::vector<float> VALL_E_API decode_audio( struct encodec_context* ectx, const std::vector<std::vector<int32_t>>& codes_2d );
|
||||
|
||||
// model-accessing helpers
|
||||
const io_t& VALL_E_API vall_e_inputs_map_get_embeddings( io_map_t& inputs_map, const std::string& name );
|
||||
const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( io_map_t& inputs_map, const std::string& name );
|
||||
int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( io_map_t& inputs_map, const std::string& name );
|
||||
void VALL_E_API vall_e_inputs_map_init( io_map_t&, llama_model* model );
|
||||
VALL_E_API std::vector<std::vector<int32_t>> encode_audio( struct encodec_context* ectx, const std::vector<float>& waveform );
|
||||
VALL_E_API std::vector<float> decode_audio( struct encodec_context* ectx, const vall_e_audio_codes_t& codes_2d );
|
||||
|
||||
// context management
|
||||
void VALL_E_API vall_e_print_usage( char** argv, const vall_e_context_params_t& params, const vall_e_args_t& args );
|
||||
bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args );
|
||||
vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params );
|
||||
vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& lang );
|
||||
vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int modality = MODALITY_NAR_LEN );
|
||||
void VALL_E_API vall_e_free( vall_e_context_t* ctx );
|
||||
VALL_E_API void vall_e_print_usage( char** argv, const vall_e_context_params_t& params, const vall_e_args_t& args );
|
||||
VALL_E_API bool vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args );
|
||||
VALL_E_API vall_e_context_t* vall_e_load( const vall_e_context_params_t& params );
|
||||
VALL_E_API vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& lang = "auto", const std::string& task = "tts" );
|
||||
VALL_E_API vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int modality = MODALITY_NAR_LEN );
|
||||
VALL_E_API void vall_e_free( vall_e_context_t* ctx );
|
||||
|
|
|
@ -309,7 +309,7 @@ def fold_inputs(
|
|||
return 0
|
||||
|
||||
if isinstance(prom, str):
|
||||
task = get_task_symmap()[f'<{input}>']
|
||||
task = get_task_symmap()[input]
|
||||
seq = torch.tensor([task_start + task], device=device, dtype=dtype)
|
||||
|
||||
input_ids[i].append( seq )
|
||||
|
@ -664,19 +664,19 @@ def get_tone_symmap():

def get_task_symmap():
	return {
		"<tts>": 0,
		"<tts-c>": 1,
		"<ns>": 2,
		"<sr>": 3,
		"<tse>": 4,
		"<soe>": 5,
		"<mask>": 6,
		"<eoe>": 7,
		"<stt>": 8,
		"tts": 0,
		"tts-c": 1,
		"ns": 2,
		"sr": 3,
		"tse": 4,
		"soe": 5,
		"mask": 6,
		"eoe": 7,
		"stt": 8,

		"<len>": 0, # fake
		"<nse>": 6, # fake
		"<cse>": 6, # fake
		"len": 0, # fake
		"nse": 6, # fake
		"cse": 6, # fake
	}

def _replace_file_extension(path, suffix):

@@ -1330,7 +1330,7 @@ class Dataset(_Dataset):
task = random.choice(self.tasks)

if f'<{task}>' not in self.task_symmap:
if task not in self.task_symmap:
raise Exception(f'Task not defined: {task}')

# Base TTS (<text><prompt> => <resp>)
148 vall_e/export.py
@@ -40,6 +40,20 @@ def convert_to_hf_llama( state_dict, config = None, save_path = None ):
}
}

# cleanup duplicate IDs because convert_hf_to_gguf.py does not like this
# get unique tokens
for k, v in tokenizer["model"]["vocab"].items():
if k not in tokenizer_vocab:
tokenizer_vocab[k] = v
# override if its in a merge
for k, v in tokenizer["model"]["vocab"].items():
for m in tokenizer["model"]["merges"]:
if k in m:
tokenizer_vocab[k] = v
break
tokenizer["model"]["vocab"] = {}

lang_map = [
"en",
"ja",
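Because the hunk above is rendered without indentation, here is a self-contained, properly indented sketch of the same dedup pass on toy data (editorial, not part of the commit; the tokenizer layout below is an illustrative stand-in for the real tokenizer.json):

# toy stand-in for the loaded tokenizer.json
tokenizer = {
    "model": {
        "vocab": { "a": 1, "b": 2, "ab": 3 },
        "merges": [ "a b" ],  # "a" and "b" participate in a merge rule
    }
}
tokenizer_vocab = {}

# get unique tokens
for k, v in tokenizer["model"]["vocab"].items():
    if k not in tokenizer_vocab:
        tokenizer_vocab[k] = v
# override if it's in a merge
for k, v in tokenizer["model"]["vocab"].items():
    for m in tokenizer["model"]["merges"]:
        if k in m:
            tokenizer_vocab[k] = v
            break
tokenizer["model"]["vocab"] = {}  # emptied here; merged back later via |=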
@@ -165,7 +179,7 @@ def convert_to_hf_llama( state_dict, config = None, save_path = None ):
model_dict['lm_head.weight'] = classifier_dict['weight']
if classifier_bias:
model_dict['lm_head.bias'] = classifier_dict['bias']

# write files in an HF compatible way
out_dir = cfg.rel_path / "hf"
out_dir.mkdir(parents=True, exist_ok=True)
@@ -217,131 +231,6 @@ def convert_to_hf_llama( state_dict, config = None, save_path = None ):

return state_dict

# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
@torch.no_grad()
def convert_to_hf_custom( state_dict, config = None, save_path = None ):
n_text_tokens, model_dim = state_dict['module']['text_emb.weight'].shape

n_audio_tokens = state_dict['module']['proms_emb.embeddings.0.weight'].shape[0]
n_resp_levels = state_dict['module']['rvq_l_emb.weight'].shape[0]
n_len_tokens = 11
n_lang_tokens = state_dict['module']['langs_emb.weight'].shape[0]
n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]

classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias
split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers

# the new tokenizer to use
tokenizer = {}
tokenizer_vocab = {}

tokenizer_path = cfg.rel_path / cfg.tokenizer_path
if not tokenizer_path.exists():
tokenizer_path = Path("./data/") / cfg.tokenizer_path
if tokenizer_path.exists():
tokenizer = json_read( tokenizer_path )
else:
tokenizer = {
"model": {
"vocab": get_phone_symmap()
}
}

lang_map = [
"en",
"ja",
"de",
"fr",
"zh",
"ko",
]
task_map = [
"tts",
"tts-c",
"ns",
"sr",
"tse",
"soe",
"mask",
"eoe",
"stt",
]

model_dict = {}
# filter out the underlying model weights and extract them
for k in state_dict['module'].keys():
if not k.startswith('model.'):
continue
model_dict[k] = state_dict['module'][k].clone()

# cringe
for l in range(11):
model_dict[f'classifiers.{l}.weight'] = state_dict['module'][f'classifiers.proj.{l}.weight']
for l in range(8):
model_dict[f"embeddings.proms.{l}.weight"] = state_dict['module'][f"proms_emb.embeddings.{l}.weight"]
for l in range(9):
model_dict[f"embeddings.resps.{l}.weight"] = state_dict['module'][f"resps_emb.embeddings.{l}.weight"]

model_dict["embeddings.aux.0.weight"] = state_dict['module']["text_emb.weight"]
model_dict["embeddings.aux.1.weight"] = state_dict['module']["rvq_l_emb.weight"]
model_dict["embeddings.aux.2.weight"] = state_dict['module']["langs_emb.weight"]
model_dict["embeddings.aux.3.weight"] = state_dict['module']["tasks_emb.weight"]
model_dict["embeddings.aux.4.weight"] = state_dict['module']["len_emb.weight"]
model_dict["embeddings.aux.5.weight"] = state_dict['module']["tones_emb.weight"]
model_dict["embeddings.aux.6.weight"] = state_dict['module']["sep"].unsqueeze(0)

# write files in an HF compatible way
out_dir = cfg.rel_path / "hf"
out_dir.mkdir(parents=True, exist_ok=True)
# write weights
torch_save( { "module": model_dict, "format": "pt" }, out_dir / "model.safetensors" )
# write tokenizer.json
tokenizer['model']['vocab'] |= tokenizer_vocab
json_write(tokenizer, out_dir / "tokenizer.json", pretty=True)
# write tokenizer_config.json
json_write({
"added_tokens": tokenizer['added_tokens'],
"bos_token": "<bos>",
"eos_token": "</eos>",
"clean_up_tokenization_spaces": True,
"model_input_names": [
"input_ids",
"attention_mask"
],
"tokenizer_class": "PreTrainedTokenizerFast"
}, out_dir / "tokenizer_config.json", pretty=True)
# write config.json
json_write({
"architectures": [
"ValleLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": model_dim,
"initializer_range": 0.02,
"intermediate_size": model_dim * 4,
"max_position_embeddings": 75 * 60 * 5,
"model_type": "llama",
"num_attention_heads": 16,
"num_hidden_layers": 12,
"num_key_value_heads": 16,
"pretraining_tp": 1,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 10000.0,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.40.0",
"use_cache": False,
"vocab_size": 256
}, out_dir / "config.json", pretty=True )

return state_dict

# yanks a LoRA from the training checkpoint
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
if dtype is None:
@@ -412,8 +301,7 @@ def moe_ify( state_dict, config = cfg.model, save_path = None, dtype = None ):
def main():
parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--module-only", action='store_true')
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
parser.add_argument("--hf-llama", action='store_true', default=None) # convert to HF-style llama model
parser.add_argument("--hf", action='store_true', default=None) # convert to HF LLaMA
parser.add_argument("--export-lora", action='store_true', default=None) # exports LoRA
parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
parser.add_argument("--moe-ify", action='store_true', default=None) # splits classifier heads
@@ -441,9 +329,7 @@ def main():
engines = load_engines(training=False) # to ignore loading optimizer state

callback = None
if args.hf_llama:
callback = convert_to_hf_llama
elif args.hf:
if args.hf:
callback = convert_to_hf_custom
elif args.export_lora:
callback = extract_lora
@@ -59,40 +59,24 @@ def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ):

def get_model(config, training=True, **model_kwargs):
from .ar_nar import AR_NAR # import here because reasons
name = config.name

if config.experimental.hf:
from .experimental import Model as Experimental
model = Experimental(
n_text_tokens=config.text_tokens,
n_audio_tokens=config.audio_tokens,

d_model=config.dim,
n_layers=config.layers,
n_heads=config.heads,
p_dropout=config.dropout,

config = config,
**model_kwargs
)
else:
from .ar_nar import AR_NAR
model = AR_NAR(
n_text_tokens=config.text_tokens,
n_audio_tokens=config.audio_tokens,
d_model=config.dim,
n_heads=config.heads,
n_layers=config.layers,
n_experts=config.experts,

p_dropout=config.dropout,

l_padding = config.input_alignment,

training = training,
config = config,
**model_kwargs
)
model = AR_NAR(
n_text_tokens=config.text_tokens,
n_audio_tokens=config.audio_tokens,
d_model=config.dim,
n_heads=config.heads,
n_layers=config.layers,
n_experts=config.experts,

p_dropout=config.dropout,

l_padding = config.input_alignment,

training = training,
config = config,
**model_kwargs
)

_logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
@@ -694,7 +694,6 @@ class Base(nn.Module):
attn_implementation=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
)
print( config )
self.model = LlamaClass(config)

# replace with desired attention
@@ -951,7 +950,7 @@ class Base(nn.Module):
# Base-line TTS task
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
# prom /may/ include <task> tokens inside to help guide things, per SpeechX
if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
if task_type in get_task_symmap() and task_type not in special_tasks:
# insert the text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
@@ -1092,7 +1091,7 @@ class Base(nn.Module):
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_embedding( input, quant_level ):
if isinstance(input, str):
return self.tasks_emb( torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16) )
return self.tasks_emb( torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) )

# get RVQ level 0, or up to targetted RVQ level inference
if self.version <= 4:
@@ -1348,14 +1347,8 @@ class Base(nn.Module):

# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
"""
if isinstance(input, str):
return torch.tensor( [ self.ignore_index ], device=device, dtype=torch.int16)

return torch.tensor( [ self.ignore_index ] * input.shape[0], device=device, dtype=torch.int16)
"""
if isinstance(input, str):
return torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16)
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)

# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
@@ -1,574 +0,0 @@
"""
This is an experiment to:
* entertain a thought to try and abide by HF's transformers API (to benefit from caching better)
* conform to a single embedding (instead of a bunch of them) by folding/unfolding inputs
* stop trying to make a mixed AR+NAR model work since it seems lobotomized if I keep trying to enforce both recurrent and parallel inferencing (despite a penalty cost)
+ I will not cave and go with codebook patterns, not yet.
"""

from ..config import cfg

from ..data import fold_inputs, unfold_outputs

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision

import random
import math
import logging

_logger = logging.getLogger(__name__)

from einops import rearrange
from tqdm import trange

from .arch import *

if cfg.model.arch_type not in AVAILABLE_ARCHES:
raise ValueError(f"Requesting arch `{cfg.model.arch_type}` but not available")

if cfg.model.arch_type in ["mamba","mamba2"]:
LlmArchClass = MambaLMHeadModel
elif cfg.model.arch_type == "llama":
LlmArchClass = LlamaForCausalLM
elif cfg.model.arch_type == "retnet":
LlmArchClass = RetNetForCausalLM
else:
raise ValueError(f"Requesting arch `{cfg.model.arch_type}` but not available")

class Model(LlmArchClass):
def __init__(
self,

n_text_tokens = 256,
n_audio_tokens = 1024,

d_model=1024,
n_layers=12,
n_heads=16,
p_dropout=0.1,

config = cfg.model,
):
self.hyper_config = config

hf_attention = config.attention if config is not None else None
gradient_checkpointing = config.gradient_checkpointing if config is not None else True
# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
# vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1

if hf_attention == "auto":
if AVAILABLE_ATTENTIONS:
hf_attention = AVAILABLE_ATTENTIONS[0]
else:
hf_attention = "eager"

if hf_attention == "xformers":
hf_attention = "mem_efficient"

text_start = 0
text_end = text_start + config.text_tokens

lang_start = text_end
lang_end = lang_start + config.langs

rvq_start = lang_end
rvq_end = rvq_start + config.resp_levels

prom_start = rvq_end
prom_end = prom_start + config.audio_tokens * config.resp_levels

task_start = prom_end
task_end = task_start + config.tasks

tone_start = task_end
tone_end = tone_start + config.tones

resp_start = tone_end
resp_end = resp_start + config.audio_tokens * config.resp_levels

vocab_size = resp_end

if config.arch_type == "llama":
super().__init__(config=LlamaConfig(
vocab_size=vocab_size,
hidden_size=d_model,
max_position_embeddings=cfg.dataset.frames_per_second * config.max_levels * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout,
num_key_value_heads=n_heads,
sliding_window=cfg.dataset.frames_per_second * config.max_levels * 12,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=hf_attention,
))

if gradient_checkpointing:
self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
elif config.arch_type == "retnet":
super().__init__(config=RetNetConfig(
vocab_size=vocab_size,
decoder_embed_dim=d_model,
decoder_value_embed_dim =d_model * 2,
decoder_retention_heads=n_heads,
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout,
checkpoint_activations=gradient_checkpointing,
activation_fn="gelu",
use_layernorm=False,
use_biases=False,
use_glu=True,

#chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
#recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
#no_output_layer=True,
#rotary_embedding_base=self.rotary_embedding_base, # 10000

decoder_normalize_before=True,
))
elif config.arch_type in ["mamba","mamba2"]:
super().__init__(config=MambaConfig(
vocab_size=vocab_size,
d_model=d_model,
n_layer=n_layers*2,
d_intermediate=0, # d_model*4,
ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True} if config.arch_type == "mamba2" else {},
rms_norm=True,
fused_add_norm=True,
residual_in_fp32=False,
))

self.backbone.gradient_checkpointing = gradient_checkpointing

self.accuracy_metric = None if True else MulticlassAccuracy(
vocab_size,
top_k=10,
average="micro",
multidim_average="global",
ignore_index=-100,
)
def generate(
self,
*args,
**kwargs
):
if cfg.model.arch_type in ["mamba","mamba2"]:
kwargs["cg"] = True

if "attention_mask" in kwargs:
kwargs.pop("attention_mask")

if "do_sample" in kwargs:
kwargs.pop("do_sample")

if "min_length" in kwargs:
kwargs.pop("min_length")

"""
if "position_ids" in kwargs:
kwargs.pop("position_ids")

if "max_new_tokens" in kwargs:
kwargs.pop("max_new_tokens")

if "max_length" not in kwargs:
kwargs["max_length"] = 500 * (self.hyper_config.resp_levels if self.hyper_config.experimental.interleave else 1)

if "num_last_tokens" not in kwargs:
kwargs["num_last_tokens"] = self.hyper_config.experimental.causal_size
"""

input_ids = kwargs.pop("input_ids")
attention_mask = kwargs.pop("attention_mask", None)
position_ids = kwargs.pop("position_ids", None)

stop_token = kwargs.pop("eos_token_id", 3)
max_steps = kwargs.pop("max_new_tokens", 500)

device = input_ids.device
batch_size = input_ids.shape[0]

sequence_list = [ inputs for inputs in input_ids ]
position_list = [ positions for positions in position_ids ]

start_positions = [ inputs.shape[0] for inputs in input_ids ]

stopped = torch.zeros(batch_size, device=device).bool()

config = self.hyper_config
state = None
disable_tqdm = False
causal_size = config.experimental.causal_size

# get next in sequence
for n in trange(max_steps // max(1, causal_size), desc="AR", disable=disable_tqdm):
output = super().forward(
input_ids=torch.stack(sequence_list),
#attention_mask=attention_mask,
#past_key_values=state,
#position_ids=torch.stack(position_list),
#use_cache=False,
#return_dict=False
)

logits = output[0]
# state = output[1]

r = [ logit[-causal_size:].argmax(dim=1) for logit in logits ]

# append tokens
for i, ri in enumerate(r):
if stop_token in ri:
stopped[i] = True

last_position_id = position_list[i][-1].item() + 1
sequence_list[i] = torch.cat([ sequence_list[i], ri.to(device) ], dim=0)
#position_list[i] = torch.cat([ position_list[i], torch.tensor([ last_position_id + _ for _ in range( ri.shape[0] ) ], device=device, dtype=torch.int32) ])

# stop token found
stopped |= r == stop_token
if stopped.all().item():
break

def _prune(l: Tensor, stop = stop_token):
indices = (l == stop).nonzero()

if len(indices) == 0:
return l

return l[: indices.min().item()]

sequence_list = [ _prune(seq[start_positions[i]:], stop_token) for i, seq in enumerate(sequence_list) ]
return torch.stack(sequence_list)

"""
return super().generate(*args, **kwargs)
"""
def forward(
self,
*args,
**kwargs,
):
config = self.hyper_config

if "text_list" in kwargs:
text_list = kwargs.pop("text_list", None)
proms_list = kwargs.pop("proms_list", None)
resps_list = kwargs.pop("resps_list", None)
lang_list = kwargs.pop("lang_list", None)
tone_list = kwargs.pop("tone_list", None)

training = kwargs.pop("training", False)
steps = kwargs.pop("steps", 500)

batch_size = len(text_list)

if training:
quant_levels = None if config.experimental.interleave else [ random.randint( 0 if "ar" in config.capabilities else 1, config.max_levels - 1) for _ in range(batch_size) ]

input_ids, attention_mask, position_ids = fold_inputs(
text_list=text_list,
prom_list=proms_list,
resp_list=resps_list,
targ_list=resps_list,
quant_levels=quant_levels,
)
target_ids, target_attention_mask, target_position_ids = fold_inputs(
text_list=text_list,
prom_list=proms_list,
resp_list=resps_list,
targ_list=resps_list,
quant_levels=quant_levels,
ignore_index=-100
)
return self.forward(
input_ids=input_ids,
labels=target_ids,
position_ids=position_ids,

quant_levels=quant_levels,
)

if config.experimental.interleave:
input_ids, attention_mask, position_ids = fold_inputs( text_list=text_list, prom_list=proms_list )
output = self.generate(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
eos_token_id=3,
do_sample=True,
max_new_tokens=steps*config.max_levels,
)
return unfold_outputs( output )["resp_list"]

resps_list = [ [] for _ in range(batch_size) ]
for l in range(config.max_levels):
quant_levels = [ l for _ in range(batch_size) ]

input_ids, attention_mask, position_ids = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
min_length = 1
for batch in input_ids:
min_length = max( min_length, batch.shape[0] + 1 )

# to-do: figure out a way to do one forward pass but sample N tokens to replicate the NAR sample pass
output = self.generate(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
eos_token_id=3,
do_sample=True,
max_new_tokens=steps,
)

unfolded = unfold_outputs( output, quant_levels=quant_levels )

if l == 0:
steps = 0

for batch, resp in enumerate(unfolded["resp_list"]):
length = resp.shape[-1]

# store length
if l == 0:
steps = max( steps, length )
# pad
else:
resp = resp[:steps]
if length < steps:
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
resps_list[batch].append( resp )

for i, resp in enumerate( resps_list ):
resps_list[i] = torch.stack( resp ).t()

return resps_list

if config.arch_type in ["mamba","mamba2"]:
kwargs.pop("attention_mask", None)

labels = kwargs.pop("labels", None)
quant_levels = kwargs.pop("quant_levels", None)

output = super().forward(*args, **kwargs)
logits = output.logits

# i HATE the correct way
if labels is not None:
if quant_levels is None:
quant_levels = [0 for _ in range(labels.shape[0])]

# predict the next token for AR, else predict in place
loss = sum([ F.cross_entropy(
logit[:-config.experimental.causal_size, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
label[config.experimental.causal_size:] if quant_level == 0 or "nar" not in config.capabilities else label,
ignore_index=-100
) for logit, label, quant_level in zip( logits, labels, quant_levels ) ])

self.loss = dict(
nll = loss,
)

if self.accuracy_metric is not None:
self.stats = dict(
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
)

"""
if config.loss_factors:
sep = 3
# determine specific sections to focus on
indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]

text_index = 0
resp_index = 1 # 1 includes everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)

labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]

logits_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( logits ) ]
logits_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( logits ) ]

loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) / len(logits_text) * self.hyper_config.loss_factor("text")
loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) / len(logits_resp) * self.hyper_config.loss_factor("resp")

self.loss = dict(
text = loss_text,
resp = loss_resp,
)

if self.accuracy_metric is not None:
self.stats = dict(
acc = dict(
text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
)
)
"""

return output
def example_usage():
cfg.trainer.backend = "local"
cfg.hyperparameters.gradient_accumulation_steps = 1
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_100

from functools import partial
from einops import repeat
from tqdm import tqdm

from ..emb.qnt import decode_to_file, unload_model
from ..engines import Engine
from ..utils import wrapper as ml

import numpy as np
import re

device = "cuda"

def tokenize(content):
return torch.tensor( cfg.tokenizer.encode(content) )

def _load_quants(path) -> Tensor:
qnt = np.load(path, allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.max_levels, :].t().to(torch.int16)

qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")

text_list = [
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
]
prom_list = [
qnt[:cfg.dataset.frames_per_second, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
resp_list = [
qnt[:, :].to(device),
#qnt[cfg.dataset.frames_per_second:, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]

text_list = text_list[:1]
prom_list = prom_list[:1]
resp_list = resp_list[:1]

kwargs = {}
model = Model(**kwargs).to(device)
steps = 50 # 100 if cfg.model.experimental.interleave else 300

optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None

if cfg.optimizations.dadaptation:
# do not combine the two
if scheduler == "schedulefree":
scheduler = ""

learning_rate = 1.0

if optimizer == "prodigy":
if learning_rate is None:
learning_rate = 1.0

optimizer = ml.Prodigy
elif optimizer == "adagrad":
if learning_rate is None:
learning_rate = 1.0e-2

optimizer = ml.Adagrad
elif optimizer == "adamw":
if learning_rate is None:
learning_rate = 1.0e-4

optimizer = ml.AdamW
elif optimizer == "sdg":
if learning_rate is None:
learning_rate = 1.0e-4

optimizer = ml.SGD
else:
raise ValueError(f"Unrecognized optimizer: {optimizer}")

_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

optimizer = optimizer(model.parameters(), lr=learning_rate)

if scheduler == "schedulefree":
if isinstance(optimizer, ml.AdamW):
scheduler = ml.schedulefree.AdamWScheduleFree
elif isinstance(optimizer, ml.SGD):
scheduler = ml.schedulefree.SGDScheduleFree
else:
scheduler = None

if scheduler is not None:
_logger.info(f"Scheduler: {scheduler}")
optimizer = scheduler( model.parameters(), lr = learning_rate )

if cfg.optimizations.replace and cfg.optimizations.linear:
model = ml.replace_linear( model )

if cfg.optimizations.replace and cfg.optimizations.embedding:
model = ml.replace_embedding( model )

engine = Engine(model=model, optimizer=optimizer)

"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""

_logger.info(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

@torch.inference_mode()
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
engine.eval()

resp_list = model( text_list=text_list, proms_list=prom_list )

for i, batch in enumerate(resp_list):
_ = decode_to_file(batch.to(device=device), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)

unload_model()

def train():
engine.train()
t = trange(steps)
for i in t:
stats = {"step": i}

stats |= engine.traverse(text_list=text_list, proms_list=prom_list, resps_list=resp_list, training=True)
stats |= engine.gather_attribute("stats")
stats |= {"grad_norm": engine.get_global_grad_norm()}

tqdm.write(f"{stats}")

"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""

#sample("init", 5)
train()
sample("final")

if __name__ == "__main__":
example_usage()
@@ -128,7 +128,7 @@ def get_languages():
return list(get_lang_symmap().keys()) + ["auto"]

def get_tasks():
return ["tts", "sr", "nr", "vc"]
return ["tts", "sr", "ns", "vc"]

#@gradio_wrapper(inputs=layout["dataset"]["inputs"].keys())
def load_sample( speaker ):