mrq 2024-12-26 21:42:17 -06:00
parent 25a02f2c3f
commit 9b0d2ccbe1
13 changed files with 396 additions and 937 deletions

4
.gitignore vendored
View File

@ -10,6 +10,6 @@ __pycache__
/.nltk
/vall_e.cpp/data
/vall_e.cpp/include
/vall_e.cpp/libs
/vall_e.cpp/lib
/vall_e.cpp/*.o
/vall_e.cpp/vall_e
/vall_e.cpp/vall_e

View File

@ -121,6 +121,9 @@ With attention-based transformers, most embeddings can serve as a token itself a
Other solutions, such as TorToiSe, make use of additional embeddings/classifiers for each portion of the sequence as well.
Other solutions will rely on conditioning latents or extracted features as the input. This *technically* isn't necessary since portions of the model seem to be allocated as an encoder anyways from the embeddings to some arbitrary depth, and as a decoder from some arbitrary depth to the output heads.
* This might also mean it makes more sense to increase the model's size in-post by injecting new layers in the middle, outside these pseudo-encoder/decoder layers, where doing so won't make any difference.
### Classifiers
Classifiers are the final output head / projection layer that processes the last hidden states of a model into a probability distribution for each token.
@ -152,7 +155,7 @@ In reality, this seems to help govern the accent / general mannerisms associated
* Consequently, since this does tie to accents more, ***extreme*** attention should be paid to the dialects being trained against, instead of naively grouping, say, all of Spanish under one language code.
* Unfortunately, this does mean that audio annotated as English is dialect/accent-agnostic, per the dataset.
This embedding probably helps the model perform cross-lingual outputs, but I did not run any experiments on a model without it, as the reference `ar+nar-llama-8` was trained with it from the beginning alongside the small amount of Japanese in my dataset anyhow (and maybe the `ar+nar-retnet-8` experiment).
Some checkpoints of the model need this for cross-lingual output, but the current checkpoints don't seem to, as the attention heads derive the language/accent from the phoneme sequences themselves rather than the language token (due to a careless oversight).
#### Tone Embedding
@ -162,6 +165,8 @@ Should, since I do not actually make use of this anywhere, and the model is not
This should most definitely help the model identify tone strongly even without needing to annotate for it, but it does an adequate job already with maintaining tone from a given input prompt.
I imagine, like language/accent, this gets derived from the phoneme sequence itself rather than a guidance token.
### Audio Embeddings
However, due to the nature of the encoded audio, embedding the audio tokens requires the dark arts, as we use audio both as an input prompt (`prom`) for guidance, and as an output response (`resp`).
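To make this concrete: the same audio codes are embedded differently depending on whether they appear as a `prom` or a `resp`, and for the prompt the per-RVQ-level embeddings of each frame are typically summed into one vector. Below is a rough sketch of that summing, assuming `codes[l][t]` is the token at level `l`, frame `t`, with one flat embedding table per level; it mirrors the role of `sum_embeddings` in `vall_e.cpp`, not its exact routine:

```cpp
#include <cstdint>
#include <vector>

// sketch: sum per-RVQ-level embeddings into one vector per audio frame
std::vector<std::vector<float>> sum_audio_embeddings(
	const std::vector<std::vector<int32_t>>& codes, // [n_levels][n_frames]
	int n_embd,
	const std::vector<const float*>& embds // one flat [n_vocab x n_embd] table per level
) {
	size_t n_frames = codes[0].size();
	std::vector<std::vector<float>> out( n_frames, std::vector<float>( n_embd, 0.0f ) );
	for ( size_t l = 0; l < codes.size(); ++l ) {
		for ( size_t t = 0; t < n_frames; ++t ) {
			const float* row = embds[l] + codes[l][t] * n_embd; // embedding row for this token
			for ( int d = 0; d < n_embd; ++d ) out[t][d] += row[d];
		}
	}
	return out;
}
```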
@ -230,12 +235,16 @@ In practice, this task is already implemented by providing the input audio to de
I imagine training for this task will better help the model understand what is noise and what isn't, and let it more strongly map utterances from the input audio prompt for use in the output, delivering better prompt adherence.
* This might also help the model identify effects applied to an utterance and maintain them in normal `tts` tasks, such as reverb or the audio quality itself (the "acoustic environment").
This task can be briefly trained for decent results in-post.
##### Speech Removal
This task `sr` aims to remove speech from a given audio, effectively serving as the reverse of denoising.
As stated above, this should help the model better identify what is noise and what isn't.
This task can be briefly trained for decent results in-post.
##### Target Speech Extraction
This task `tse` aims to "extract" an utterance from audio containing other speakers, effectively diarizing an utterance.
@ -258,6 +267,8 @@ The length predictor `len` task is required for a pure NAR model.
This task will naively output a zero, then the length in base-10, followed by a stop token.
This works because the model can already derive the length of a sequence when decoding autoregressively, through the probability of emitting a `<stop>` token.
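For illustration, decoding the emitted digits back into a frame count is just positional accumulation (the leading zero contributes nothing); a minimal sketch in C++, mirroring what `vall_e.cpp` does after its `len` inference (`decode_len` and the sample digits are hypothetical):

```cpp
#include <cstdio>
#include <vector>

// sketch: fold a most-significant-first digit sequence into a length
int decode_len( const std::vector<int>& digits ) {
	int len = 0;
	for ( int d : digits ) len = len * 10 + d;
	return len;
}

int main() {
	// hypothetical `len` output for a 120-frame response: leading zero, then "1 2 0"
	printf( "len: %i\n", decode_len( {0, 1, 2, 0} ) ); // prints "len: 120"
	return 0;
}
```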
#### Speech-to-Text
The speech-to-text `stt` task transcribes a given piece of audio by taking the encoded audio as input and outputting the text transcription.
@ -274,11 +285,13 @@ This task will follow a reverse sequence of `<audio><language><RVQ level><output
The model can be prompted in creative ways to yield some interesting behaviors:
* prompting without an input audio prompt will have the model generate a random voice ~~at the "cost" of some unintelligible utterance at the beginning of the output response (despite doing no promptless training)~~.
* classifier-free-guidance-aware training does fix this, but this property emerges without it.
* the AR is much better with this property, as the `NAR-len` gets crusty sometimes.
* the AR is much better with this property, as the `NAR-len` sometimes gets crusty, since it will keep demasking on crust.
* prompting with an input text prompt that is the transcription of the input audio prompt will have the response follow the input prompt very closely (despite not doing input=output training).
* this should allow for easy transcription editing without much fuss.
* the `NAR-len` greatly exhibits this property, although it sometimes keeps any noise in the background.
* extra care is required when doing this, as some checkpoints of the model will degrade completely the moment the prompt can't be directly referenced.
* training without a language token will have the model derive the target language/accent from the phoneme sequence itself (it is a language model, after all).
* voice conversion is *possible* through demasking with the source prompt as the mask, but the current inferencing mechanism yields crust at the end of the output.
# `models/*`
@ -288,7 +301,7 @@ This folder contains scripts relating to models and code for VALL-E use, from th
This script implements Low-Ranking Adapters, to allow for cheaper and easier finetuning of existing modules.
At the moment, two approaches are offered, through replacing `nn.Linear` outright, or parameterizing a `nn.Liner`. The latter is used by default(?).
At the moment, two approaches are offered, through replacing `nn.Linear` outright, or parameterizing a `nn.Linear`. The latter is used by default(?).
## `models/base.py`
@ -303,6 +316,8 @@ This script implements the core underlying model for VALL-E. This handles:
This script aims to implement everything as required per VALL-E agnostically, to allow for different implementations to contain little extra code.
A very naive implementation of using the model can be found under the `__main__` invocation.
## `models/ar_nar.py`
This script implements VALL-E as a unified autoregressive and non-autoregressive model, where RVQ-level 0 is inferenced autoregressively and the remaining levels are inferenced non-autoregressively, if requested.
@ -312,14 +327,6 @@ For training, this model handles preparing the batch provided through the datalo
For inferencing, this will dynamically inference depending on the arguments provided.
## `models/experimental.py`
This script implements VALL-E as a mostly-HuggingFace compatible model, where it handles processing tokens as a uniform sequence of IDs.
This mostly serves as an experiment to see what is required to do so, for possible future implementations requiring just `llama.cpp` and `encodec.cpp`, and to provide a pure HF-compatible implementation.
Use of this is governed through `cfg.model.experimental.hf = True`
## `models/arch/*`
This folder contains scripts, either written by myself or properly attributed, that provide or modify existing modules of a given model.
@ -394,7 +401,7 @@ If I remember right, it simply provides gradient checkpointing.
### `models/arch/mixtral.py`
Like `llama.py`, this provides modifications to Mixtral through `transformers`.
Like `llama.py`, this provides modifications to Mixtral through `transformers`. However, most of the niceties from `llama.py` are not available here as it's not the core backend.
Primarily, this is to address a bug with batch sizes > 1, and to use a different attention mechanism.
* to-do: this is out of date from `llama.py`'s modified attention class.

View File

@ -1,15 +1,22 @@
ifeq ($(PREFIX),)
PREFIX := /usr/local
endif
CXX = g++
INCS += -I./include
LIBS += -L./libs
LIBS += -L./lib
LINKS += -lggml -lggml-base -lllama -lencodec -lespeak-ng
FLAGS += -march=native -O3
FLAGS += -march=native -O3 -DVALL_E_EXPORTS -fPIC
SRCS := $(shell find ./ -name "*.cpp")
OBJS += $(patsubst %.cpp,%.o,$(SRCS))
TARGET = vall_e
TARGET_LIB = lib$(TARGET).so
TARGET_HEADER = $(TARGET).h
%.o: %.cpp
$(CXX) $(FLAGS) $(INCS) -c $< -o $@
@ -17,6 +24,19 @@ TARGET = vall_e
$(TARGET): $(OBJS)
$(CXX) $(FLAGS) $(OBJS) $(LIBS) $(INCS) $(LINKS) -o $(TARGET)
$(TARGET_LIB): $(OBJS)
$(CXX) $(FLAGS) -shared $(OBJS) $(LIBS) $(INCS) $(LINKS) -o $(TARGET_LIB)
all: $(TARGET_LIB) $(TARGET)
lib: $(TARGET_LIB)
install:
cp $(TARGET) $(PREFIX)/bin/$(TARGET)
-cp $(TARGET_LIB) $(PREFIX)/lib/$(TARGET_LIB)
cp $(TARGET_HEADER) $(PREFIX)/include/$(TARGET_HEADER)
clean:
@-rm -f $(OBJS)
@-rm -f $(TARGET)
@-rm -f $(TARGET)
@-rm -f $(TARGET_LIB)

View File

@ -2,15 +2,15 @@
This is an implementation that makes use of [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [encodec.cpp](https://github.com/PABannier/encodec.cpp).
At the moment it's ***very much*** a work in progress.
Model weights can be found at [`ecker/vall-e@gguf`](https://huggingface.co/ecker/vall-e/tree/gguf).
Model weights can:
* be found at [`ecker/vall-e@gguf`](https://huggingface.co/ecker/vall-e/tree/gguf)
* be converted with `vall_e.export --yaml=./model_path/config.yaml --hf`, followed by running `python3 /path/to/your/llama.cpp/convert_hf_to_gguf ./model_path/hf/`
## Build
Populate `./include/` with the `ggml`, `llama.cpp`, and `encodec.cpp` headers.
Populate `./libs/` with the compiled libraries of `llama.cpp`, `encodec.cpp`, and `espeak-ng`.
Populate `./lib/` with the compiled libraries of `llama.cpp`, `encodec.cpp`, and `espeak-ng` (if not already in your `LD_LIBRARY_PATH`).
Run `make`.
@ -23,7 +23,7 @@ Run `make`.
## To-Do
* [x] converted model to GGUF
* [ ] convert it without modifying any of the existing code, as the tokenizer requires some care
* [x] convert it without modifying any of the existing code, as the tokenizer requires some care
* [x] basic framework
* [x] load the quantized model
* [x] orchestrate the required embeddings
@ -45,6 +45,11 @@ Run `make`.
* [x] a functional CLI
* [x] actually make it work
* [x] clean up to make the code usable elsewhere
* [x] configured to allow for being used as a lib
* (I do need to validate this in my engine project, but that's in MSYS2)
* [ ] feature parity with the PyTorch version
* [ ] vocos
* [ ] additional tasks (`stt`, `ns`, `sr`, samplers)
* [ ] additional tasks
* [ ] `stt`
* [x] `ns` / `sr`
* [ ] samplers

93
vall_e.cpp/vall_e-impl.h Normal file
View File

@ -0,0 +1,93 @@
#pragma once
// stores all the backend stuff
// external deps
#include <llama.h>
#include <encodec.h>
#include <dr_wav.h>
#include <espeak-ng/speak_lib.h>
#define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
#define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch (or using LLAMA arch)
#if !LLAMA_CPP_EXTENDED
#include "llama_hack.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
#endif
// to-do: clean up spaghetti enums
const int EMBEDDING_MODE_PROM = 0;
const int EMBEDDING_MODE_RESP_AR_NAR = 1;
const int EMBEDDING_MODE_RESP_NAR_LEN = 2;
const int INFERENCE_MODE_LEN = 0;
const int INFERENCE_MODE_AR = 1;
const int INFERENCE_MODE_NAR_DEMASK = 2;
const int INFERENCE_MODE_NAR = 3;
// stores metadata for inputs/outputs
struct io_t {
std::string name;
uint32_t start;
uint32_t end;
int32_t head_idx = -1;
int32_t n_embd = 0;
int32_t n_vocab = 0;
std::vector<float> embds = {};
ggml_tensor* head = NULL;
};
// stores the mappings between tokens, input embeddings, and output heads
struct io_map_t {
// model's original params
int32_t n_embd = 0;
int32_t n_vocab = 0;
// mapping
std::unordered_map<std::string, io_t> io = {};
// context to store slices
ggml_context* ctx = NULL;
};
// used for top-k (mainly for demasking)
struct score_t {
int32_t idx;
float value;
bool operator<( const score_t& that ) const { return this->value < that.value; }
};
// handles storing metadata for token merges
struct merge_entry_t {
std::u32string pre;
std::u32string post;
std::u32string resolved;
token_t pre_token;
token_t post_token;
token_t resolved_token;
};
// helper tensor functions
std::vector<float> read_2d_tensor( struct ggml_tensor* tensor );
//ggml_tensor* view_2d_tensor( ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 ); // cringe method to keep in my pocket
ggml_tensor* view_2d_tensor( ggml_context* ctx, ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 );
void print_tokens( const std::vector<token_t>& tokens, const std::string& prefix = "Tokens: " );
std::vector<std::vector<float>> map_embeddings( const std::vector<token_t>& tokens, int n_embd, const float* embds );
std::vector<std::vector<float>> sum_embeddings( const vall_e_audio_codes_t& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM );
std::vector<float> soft_max( int n_logits, const float* logits );
// batch and inferencing
void batch_add( llama_batch& batch, token_t id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} );
void fill_batch( llama_batch& batch, vall_e_inputs_t& input, io_map_t& inputs_map, int mode );
std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& input, int max_tokens, int mode, bool verbose = true );
// (handles text)
std::vector<token_t> phonemize( vall_e_context_t* ctx, const std::string& text, const std::string& language = "auto" );
// model-accessing helpers
const io_t& vall_e_inputs_map_get_embeddings( io_map_t& inputs_map, const std::string& name );
const float* vall_e_inputs_map_get_embeddings_p( io_map_t& inputs_map, const std::string& name );
int32_t vall_e_inputs_map_get_classifier_idx( io_map_t& inputs_map, const std::string& name );
void vall_e_inputs_map_init( io_map_t&, llama_model* model );
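As a usage note for `score_t`: its `operator<` orders by `value`, so a top-k pass (e.g. choosing which positions to remask during iterative demasking) can lean on the standard library. A hypothetical sketch, not the project's actual demasking routine:

```cpp
#include <algorithm>
#include <vector>

// sketch: given per-position scores, return the indices of the k lowest-scoring
// positions (relies on score_t::operator< ordering ascending by value)
std::vector<int32_t> lowest_k_positions( std::vector<score_t> scores, size_t k ) {
	if ( k > scores.size() ) k = scores.size();
	std::partial_sort( scores.begin(), scores.begin() + k, scores.end() );
	std::vector<int32_t> indices;
	indices.reserve( k );
	for ( size_t i = 0; i < k; ++i ) indices.push_back( scores[i].idx );
	return indices;
}
```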

View File

@ -1,5 +1,6 @@
#define DR_WAV_IMPLEMENTATION
#include "vall_e.h"
#include "vall_e-impl.h" // stores everything that isn't necessary for exporting
#include <cmath>
#include <cstdio>
@ -39,19 +40,43 @@ io_t io_ranges[] = {
{ "resps|NAR:0:0", 16677, 17702, 8 },
};
// stored here because I tokenize the merges
// I can't be assed to figure out the tokenizer right now
// lang map
std::unordered_map<std::string, token_t> lang_map = {
{ "en", 0 },
{ "ja", 1 },
{ "de", 2 },
{ "fr", 3 },
{ "zh", 4 },
{ "ko", 5 },
};
std::unordered_map<std::string, token_t> task_map = {
{ "tts", 0 },
{ "tts-c", 1 },
{ "ns", 2 },
{ "sr", 3 },
{ "tse", 4 },
{ "soe", 5 },
{ "mask", 6 },
{ "eoe", 7 },
{ "stt", 8 },
{ "len", 0 },
{ "nse", 6 },
{ "cse", 6 },
};
// u32string because encoding agony
std::unordered_map<std::u32string, token_t> vocab = {
{U"<unk>",0},{U"<bos>",1},{U"</eos>",2},{U"<mask>",3},{U" ",4},{U"",4},{U"!",5},{U"\"",6},{U"(",7},{U"{",7},{U"[",7},{U")",8},{U"}",8},{U"]",8},{U",",9},{U"-",10},{U".",11},{U"1",211},{U"",10},{U"",6},{U"",81},{U"ˇ",6},{U"ˉ",12},{U"ˊ",79},{U"ˋ",80},{U"_",81},{U":",13},{U";",14},{U"?",15},{U"a",16},{U"ä",16},{U"ɒ",16},{U"b",17},{U"c",18},{U"d",19},{U"e",20},{U"f",21},{U"h",22},{U"i",23},{U"ĩ",23},{U"j",24},{U"k",25},{U"l",26},{U"m",27},{U"n",28},{U"ɴ",28},{U"ɲ",28},{U"o",29},{U"̞",29},{U"p",30},{U"ɸ",30},{U"q",31},{U"r",32},{U"ɽ",32},{U"ʁ",32},{U"s",33},{U"t",34},{U"u",35},{U"ø",35},{U"œ",35},{U"y",35},{U"ɣ",35},{U"ũ",35},{U"v",36},{U"w",37},{U"ʍ",37},{U"x",38},{U"z",39},{U"¡",40},{U"«",41},{U"»",42},{U"¿",43},{U"æ",44},{U"ç",45},{U"ð",46},{U"ŋ",47},{U"ɐ",48},{U"ɑ",49},{U"ɔ",50},{U"ɕ",51},{U"ə",52},{U"ɚ",53},{U"ɛ",54},{U"ɜ",55},{U"ɟ",56},{U"ɡ",57},{U"ɪ",58},{U"ɬ",59},{U"ɯ",60},{U"ɹ",61},{U"ɾ",62},{U"ʃ",63},{U"ʈ",64},{U"ʊ",65},{U"ʋ",66},{U"ʌ",67},{U"ʑ",68},{U"ʒ",69},{U"ʔ",70},{U"ʲ",71},{U"ˈ",72},{U"ˌ",73},{U"ː",74},{U"̃",75},{U"̩",76},{U"θ",77},{U"",78},{U"",82},{U"ˈɛ",83},{U"iː",84},{U"aɪ",85},{U"nd",86},{U"ˈɪ",87},{U"eɪ",88},{U"ˈæ",89},{U"ðə",90},{U"",91},{U"ɑː",92},{U"ˈeɪ",93},{U"ən",94},{U"uː",95},{U"ˈʌ",96},{U"ˈaɪ",97},{U"st",98},{U"ˈɔ",99},{U"ˈ",100},{U"ˈiː",101},{U"ˈɑː",102},{U"ænd",103},{U"ːɹ",104},{U"ɪŋ",105},{U"ɜː",106},{U"ɪn",107},{U"",108},{U"ʌv",109},{U"",110},{U"əl",111},{U"ˈuː",112},{U"",113},{U"ɪz",114},{U"ˈɜː",115},{U"ˌʌ",116},{U"æt",117},{U"",118},{U"ˈɔː",119},{U"ɪt",120},{U"ˈ",121},{U"ɚɹ",122},{U"ˈɛn",123},{U"",124},{U"li",125},{U"hiː",126},{U"ˌɛ",127},{U"wɪ",128},{U"wʌz",129},{U"ðæt",130},{U"juː",131},{U"oːɹ",132},{U"ðɪ",133},{U"sˈɛ",134},{U"ˌɪ",135},{U"ˈɑːɹ",136},{U"nt",137},{U"ˈʊ",138},{U"ənt",139},{U"hɪz",140},{U"ˌɑː",141},{U"",142},{U"ɔːɹ",143},{U"ˈɛɹ",144},{U"wɪð",145},{U"ᵻd",146},{U"ˈoːɹ",147},{U"",148},{U"ˈɔːl",149},{U"",150},{U"ʃən",151},{U"kt",152},{U"ˌoʊ",153},{U"ˈɔːɹ",154},{U"",155},{U"æz",156},{U"ˌʌt",157},{U"ʃiː",158},{U"ˈɛl",159},{U"ˌaʊ",160},{U"ˈʌn",161},{U"əs",162},{U"ː",163},{U"lˈaɪ",164},{U"ˈæn",165},{U"ˈɪɹ",166},{U"ʊd",167},{U"ɹᵻ",168},{U"ld",169},{U"bˌʌt",170},{U"ks",171},{U"nˈ",172},{U"hæd",173},{U"ɾɚ",174},{U"ɛɹ",175},{U"ˈɪŋ",176},{U"ɡɹ",177},{U"ɑː",178},{U"ɔn",179},{U"",180},{U"maɪ",181},{U"ːɹ",182},{U"ðɚ",183},{U"",184},{U"ðɛɹ",185},{U"ɑːt",186},{U"ˈʌm",187},{U"",188},{U"sˈiː",189},{U"ʌvðə",190},{U"mˈɪ",191},{U"hˈæ",192},{U"ˌɪm",193},{U"lˈeɪ",194},{U"ɪk",195},{U"sp",196},{U"ɪm",197},{U"ɐn",198},{U"ðeɪ",199},{U"lˈɪ",200},{U"ɾi",201},{U"lˈɛ",202},{U"",203},{U"",204},{U"lˈæ",205},{U"ˈɪl",206},{U"jˈuː",207},{U"ʌm",208},{U"mˌiː",209},{U"bᵻ",210},{U"wˈʌn",211},{U"ˌɪn",212},{U"ˈɪn",213},{U"ˈoʊn",214},{U"sˈɛd",215},{U"biː",216},{U"ˈɛd",217},{U"ˈaɪt",218},{U"baɪ",219},{U"fɹʌm",220},{U"ɪs",221},{U"ɚz",222},{U"ðɪs",223},{U"əns",224},{U"bəl",225},{U"ɪf",226},{U"ɪnðə",227},{U"əm",228},{U"ᵻz",229},{U"ˌuː",230},{U"wˈeɪ",231},{U"ft",232},{U"wiː",233},{U"stɹ",234},{U"lˈiː",235},{U"iːz",236},{U"pt",237},{U"",238},{U"ɚd",239},{U"ˌaɪ",240},{U"kw",241},{U"ˌɔn",242},{U"ˈaɪd",243},{U"ɪm",244},{U"ˈʌst",245},{U"ˈoʊld",246},{U"ts",247},{U"ˌɪ",248},{U"sˌoʊ",249},{U"dˈɪ",250},{U"ɑːɹ",251},{U"",252},{U"sˈeɪ",253},{U"ɾᵻd",254},{U"ɪ",255},
};
// cringe list of merges to later process and fill out the map for referencing merges
std::vector<merge_entry_t> vocab_merges = {
{U"ˈ", U"ɛ"},{U"i", U"ː"},{U"a", U"ɪ"},{U"n", U"d"},{U"ˈ", U"ɪ"},{U"e", U"ɪ"},{U"ˈ", U"æ"},{U"ð", U"ə"},{U"o", U"ʊ"},{U"ɑ", U"ː"},{U"ˈ", U"eɪ"},{U"ə", U"n"},{U"u", U"ː"},{U"ˈ", U"ʌ"},{U"ˈ", U"aɪ"},{U"s", U"t"},{U"ˈ", U"ɔ"},{U"ˈ", U""},{U"ˈ", U"iː"},{U"ˈ", U"ɑː"},{U"æ", U"nd"},{U"ː", U"ɹ"},{U"ɪ", U"ŋ"},{U"ɜ", U"ː"},{U"ɪ", U"n"},{U"t", U"ə"},{U"ʌ", U"v"},{U"a", U"ʊ"},{U"ə", U"l"},{U"ˈ", U"uː"},{U"t", U"ʃ"},{U"ɪ", U"z"},{U"ˈ", U"ɜː"},{U"ˌ", U"ʌ"},{U"æ", U"t"},{U"d", U"ʒ"},{U"ˈɔ", U"ː"},{U"ɪ", U"t"},{U"ˈ", U""},{U"ɚ", U"ɹ"},{U"ˈɛ", U"n"},{U"w", U"ʌ"},{U"l", U"i"},{U"h", U"iː"},{U"ˌ", U"ɛ"},{U"w", U"ɪ"},{U"", U"z"},{U"ð", U"æt"},{U"j", U"uː"},{U"o", U"ːɹ"},{U"ð", U"ɪ"},{U"s", U"ˈɛ"},{U"ˌ", U"ɪ"},{U"ˈɑː", U"ɹ"},{U"n", U"t"},{U"ˈ", U"ʊ"},{U"ən", U"t"},{U"h", U"ɪz"},{U"ˌ", U"ɑː"},{U"h", U"æ"},{U"ɔ", U"ːɹ"},{U"ˈɛ", U"ɹ"},{U"wɪ", U"ð"},{U"", U"d"},{U"ˈ", U"oːɹ"},{U"p", U"ɹ"},{U"ˈɔː", U"l"},{U"m", U"ˌ"},{U"ʃ", U"ən"},{U"k", U"t"},{U"ˌ", U""},{U"ˈɔ", U"ːɹ"},{U"f", U"ɹ"},{U"æ", U"z"},{U"ˌʌ", U"t"},{U"ʃ", U"iː"},{U"ˈɛ", U"l"},{U"ˌ", U""},{U"ˈʌ", U"n"},{U"ə", U"s"},{U"h", U"ɜː"},{U"l", U"ˈaɪ"},{U"ˈæ", U"n"},{U"ˈɪ", U"ɹ"},{U"ʊ", U"d"},{U"ɹ", U""},{U"l", U"d"},{U"b", U"ˌʌt"},{U"k", U"s"},{U"n", U"ˈ"},{U"", U"d"},{U"ɾ", U"ɚ"},{U"ɛ", U"ɹ"},{U"ˈɪ", U"ŋ"},{U"ɡ", U"ɹ"},{U"n", U"ˌɑː"},{U"ɔ", U"n"},{U"v", U"ɚ"},{U"m", U"aɪ"},{U"f", U"ɔːɹ"},{U"ð", U"ɚ"},{U"t", U"ʊ"},{U"ð", U"ɛɹ"},{U"ɑː", U"t"},{U"ˈʌ", U"m"},{U"t", U"ɹ"},{U"s", U"ˈiː"},{U"ʌv", U"ðə"},{U"m", U"ˈɪ"},{U"h", U"ˈæ"},{U"ˌɪ", U"m"},{U"l", U"ˈeɪ"},{U"ɪ", U"k"},{U"s", U"p"},{U"h", U"ˌɪm"},{U"ɐ", U"n"},{U"ð", U"eɪ"},{U"l", U"ˈɪ"},{U"ɾ", U"i"},{U"l", U"ˈɛ"},{U"b", U"ɹ"},{U"k", U"ɹ"},{U"l", U"ˈæ"},{U"ˈɪ", U"l"},{U"j", U"ˈuː"},{U"ʌ", U"m"},{U"", U"iː"},{U"b", U""},{U"w", U"ˈʌn"},{U"ˌ", U"ɪn"},{U"ˈɪ", U"n"},{U"ˈ", U"n"},{U"sˈɛ", U"d"},{U"b", U"iː"},{U"ˈɛ", U"d"},{U"ˈaɪ", U"t"},{U"b", U"aɪ"},{U"", U"ʌm"},{U"ɪ", U"s"},{U"ɚ", U"z"},{U"ðɪ", U"s"},{U"ən", U"s"},{U"b", U"əl"},{U"ɪ", U"f"},{U"ɪn", U"ðə"},{U"ə", U"m"},{U"", U"z"},{U"ˌ", U"uː"},{U"w", U"ˈeɪ"},{U"f", U"t"},{U"w", U"iː"},{U"st", U"ɹ"},{U"l", U"ˈiː"},{U"iː", U"z"},{U"p", U"t"},{U"j", U"ʊ"},{U"ɚ", U"d"},{U"ˌ", U"aɪ"},{U"k", U"w"},{U"ˌ", U"ɔn"},{U"ˈaɪ", U"d"},{U"ɪ", U"m"},{U"ˈʌ", U"st"},{U"ˈ", U"ld"},{U"t", U"s"},{U"ˌɪ", U""},{U"s", U"ˌoʊ"},{U"d", U"ˈɪ"},{U"ɑː", U"ɹ"},{U"h", U"ɐ"},{U"s", U"ˈeɪ"},{U"ɾ", U"ᵻd"},{U"w", U"ˌɪ"},
};
// merge map to reference when tokenizing text
std::unordered_map<std::string, merge_entry_t> vocab_merge_map = {};
std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor ) {
std::vector<float> read_2d_tensor( struct ggml_tensor* tensor ) {
size_t size = tensor->ne[0] * tensor->ne[1];
std::vector<float> res( size );
@ -65,7 +90,7 @@ std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor ) {
return res;
}
/*
ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim ) {
ggml_tensor* view_2d_tensor( struct ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim ) {
// to-do: implement other dim
if ( start < 0 ) start = tensor->ne[1] + start;
if ( end < 0 ) end = tensor->ne[1] + end;
@ -86,7 +111,7 @@ ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_tensor* tensor, int32_t star
return res;
}
*/
ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_context* ctx, struct ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim ) {
ggml_tensor* view_2d_tensor( struct ggml_context* ctx, struct ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim ) {
// to-do: implement other dim
if ( start < 0 ) start = tensor->ne[1] + start;
if ( end < 0 ) end = tensor->ne[1] + end;
@ -96,7 +121,7 @@ ggml_tensor* VALL_E_API view_2d_tensor( struct ggml_context* ctx, struct ggml_te
return res;
}
void VALL_E_API print_tokens( const std::vector<token_t>& tokens, const std::string& prefix ) {
void print_tokens( const std::vector<token_t>& tokens, const std::string& prefix ) {
printf("%s[", prefix.c_str());
for ( auto i = 0; i < tokens.size(); ++i ) {
printf("%i%s", tokens[i], i + 1 < tokens.size() ? ", " : "");
@ -104,18 +129,18 @@ void VALL_E_API print_tokens( const std::vector<token_t>& tokens, const std::str
printf("]\n");
}
const io_t& VALL_E_API vall_e_inputs_map_get( io_map_t& io_map, const std::string& name ) {
const io_t& vall_e_inputs_map_get( io_map_t& io_map, const std::string& name ) {
return io_map.io[name];
}
const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( io_map_t& io_map, const std::string& name ) {
const float* vall_e_inputs_map_get_embeddings_p( io_map_t& io_map, const std::string& name ) {
return io_map.io[name].embds.data();
}
int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( io_map_t& io_map, const std::string& name ) {
int32_t vall_e_inputs_map_get_classifier_idx( io_map_t& io_map, const std::string& name ) {
return io_map.io[name].head_idx;
}
void VALL_E_API vall_e_inputs_map_init( io_map_t& io_map, llama_model* model ) {
void vall_e_inputs_map_init( io_map_t& io_map, llama_model* model ) {
auto n_embd = llama_n_embd( model );
auto n_vocab = llama_n_vocab( model );
@ -187,7 +212,7 @@ void VALL_E_API vall_e_inputs_map_init( io_map_t& io_map, llama_model* model ) {
}
// maps embeddings easily
std::vector<std::vector<float>> VALL_E_API map_embeddings( const std::vector<token_t>& tokens, int n_embd, const float* embds ) {
std::vector<std::vector<float>> map_embeddings( const std::vector<token_t>& tokens, int n_embd, const float* embds ) {
std::vector<std::vector<float>> embedded( tokens.size() );
for ( auto i = 0; i < tokens.size(); ++i ) {
embedded[i].insert( embedded[i].end(), embds + (tokens[i] * n_embd), embds + ((tokens[i]+1) * n_embd) );
@ -197,7 +222,7 @@ std::vector<std::vector<float>> VALL_E_API map_embeddings( const std::vector<tok
// handles adding either a token OR the embedding of that token into the batch
// this really, really helps avoid needing to abuse the tokenizer
void VALL_E_API batch_add( llama_batch& batch, token_t id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids ) {
void batch_add( llama_batch& batch, token_t id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids ) {
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
// insert raw embedding instead
@ -219,7 +244,7 @@ void VALL_E_API batch_add( llama_batch& batch, token_t id, int n_embd, const flo
batch.n_tokens++;
}
// reads a waveform from disk
std::vector<float> VALL_E_API read_audio_from_disk( const std::string& path ) {
std::vector<float> read_audio_from_disk( const std::string& path ) {
std::vector<float> res;
uint32_t channels;
@ -248,7 +273,7 @@ std::vector<float> VALL_E_API read_audio_from_disk( const std::string& path ) {
return res;
}
// writes a waveform to disk
void VALL_E_API write_audio_to_disk( const std::vector<float>& wavform, const std::string& path ) {
void write_audio_to_disk( const std::vector<float>& wavform, const std::string& path ) {
drwav_data_format format;
format.bitsPerSample = 32;
format.sampleRate = 24000;
@ -264,7 +289,7 @@ void VALL_E_API write_audio_to_disk( const std::vector<float>& wavform, const st
fprintf(stderr, "%s: Number of frames written = %lld.\n", __func__, frames);
}
// reads a waveform from disk then encodes it
std::vector<std::vector<int32_t>> VALL_E_API encode_audio( struct encodec_context* ectx, const std::vector<float>& wavform ) {
std::vector<std::vector<int32_t>> encode_audio( struct encodec_context* ectx, const std::vector<float>& wavform ) {
// compress audio
if (!encodec_compress_audio(ectx, wavform.data(), wavform.size(), 1)) {
fprintf(stderr, "%s: error during compression \n", __func__);
@ -285,7 +310,7 @@ std::vector<std::vector<int32_t>> VALL_E_API encode_audio( struct encodec_contex
return res;
}
// decodes a 2D codebook into a waveform
std::vector<float> VALL_E_API decode_audio( struct encodec_context* ectx, const std::vector<std::vector<int32_t>>& codes ) {
std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector<std::vector<int32_t>>& codes ) {
int n_codebooks = codes.size();
int n_frames = codes[0].size();
@ -310,7 +335,7 @@ std::vector<float> VALL_E_API decode_audio( struct encodec_context* ectx, const
}
// sums embeddings over a 2D "tensor"
std::vector<std::vector<float>> VALL_E_API sum_embeddings( const std::vector<std::vector<token_t>>& inputs, int n_embd, int rvq_l, const float** embds, int mode ) {
std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<token_t>>& inputs, int n_embd, int rvq_l, const float** embds, int mode ) {
auto n_tokens = inputs[0].size();
std::vector<std::vector<float>> res( n_tokens, std::vector<float>( n_embd, 0.0 ) );
@ -336,7 +361,7 @@ std::vector<std::vector<float>> VALL_E_API sum_embeddings( const std::vector<std
return res;
}
std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits ) {
std::vector<float> soft_max( int n_logits, const float* logits ) {
std::vector<float> res( n_logits, 0.0f );
std::vector<float> expd( n_logits, 0.0f );
float denom = 0.0f;
@ -353,7 +378,7 @@ std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits ) {
return res;
}
std::vector<float> VALL_E_API log_soft_max( int n_logits, const float* logits ) {
std::vector<float> log_soft_max( int n_logits, const float* logits ) {
std::vector<float> res( n_logits, 0.0f );
float denom = 0.0f;
@ -368,7 +393,7 @@ std::vector<float> VALL_E_API log_soft_max( int n_logits, const float* logits )
return res;
}
void VALL_E_API fill_batch( llama_batch& batch, vall_e_inputs_t& inputs, io_map_t& io_map, int mode ) {
void fill_batch( llama_batch& batch, vall_e_inputs_t& inputs, io_map_t& io_map, int mode ) {
// keeps track of the position for each sequence
size_t pos = 0;
auto n_embd = io_map.n_embd;
@ -402,18 +427,25 @@ void VALL_E_API fill_batch( llama_batch& batch, vall_e_inputs_t& inputs, io_map_
vall_e_inputs_map_get_embeddings_p(io_map, "resps|NAR:0:0"),
};
token_t lang_token = lang_map[inputs.lang];
token_t task_token = task_map[inputs.task];
// insert text tokens
for ( auto& id : inputs.phn ) batch_add( batch, id, n_embd, text_embds, pos++, false );
batch_add( batch, 0, n_embd, sep_embds, pos++, false );
pos = 0;
// insert lang token
batch_add( batch, inputs.lang, n_embd, lang_embds, pos++, false );
batch_add( batch, lang_token, n_embd, lang_embds, pos++, false );
batch_add( batch, 0, n_embd, sep_embds, pos++, false );
pos = 0;
// insert rvq level token
batch_add( batch, inputs.rvq_l, n_embd, rvq_l_embds, pos++, false );
batch_add( batch, 0, n_embd, sep_embds, pos++, false );
pos = 0;
// insert task token if needed
if ( task_token > 0 ) {
batch_add( batch, task_token, n_embd, task_embds, pos++, false );
}
// insert prom tokens
auto summed_proms_embds = sum_embeddings( inputs.prom, n_embd, inputs.rvq_l, prom_embds );
for ( auto i = 0; i < summed_proms_embds.size(); ++i ) {
@ -439,13 +471,13 @@ void VALL_E_API fill_batch( llama_batch& batch, vall_e_inputs_t& inputs, io_map_
}
// generation code, should handle all modalities easily
std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int max_tokens, int mode, bool verbose ) {
std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int max_tokens, int mode, bool verbose ) {
bool causal = true; // sample autoregressively or not
int n_outputs = 0; // number of output tokens to expect
// create batch (targeting embeddings instead of tokens)
llama_batch batch = llama_batch_init( ctx->params.ctx_size, ctx->io_map.n_embd, ctx->params.ctx_size );
fill_batch( batch, inputs, ctx->io_map, mode );
llama_batch batch = llama_batch_init( ctx->params.ctx_size, ctx->io_map->n_embd, ctx->params.ctx_size );
fill_batch( batch, inputs, *ctx->io_map, mode );
// determine how many outputs we need
for ( auto i = 0; i < batch.n_tokens; ++i ) {
@ -485,7 +517,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
embd_name = "resps|NAR:0:0";
}
auto& io = vall_e_inputs_map_get(ctx->io_map, embd_name);
auto& io = vall_e_inputs_map_get(*ctx->io_map, embd_name);
const float* embds = io.embds.data();
int32_t n_embd = io.n_embd;
@ -535,7 +567,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
// store token
output_tokens.emplace_back(t);
// update batch with token
batch_add( batch, t, ctx->io_map.n_embd, embds, output_tokens.size(), true );
batch_add( batch, t, ctx->io_map->n_embd, embds, output_tokens.size(), true );
if ( verbose ) print_tokens( output_tokens );
}
@ -560,7 +592,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
null_input.phn = {1, 2}; // <bos></eos>
null_input.resp.resize(1);
llama_batch null_batch = llama_batch_init( ctx->params.ctx_size, ctx->io_map.n_embd, ctx->params.ctx_size );
llama_batch null_batch = llama_batch_init( ctx->params.ctx_size, ctx->io_map->n_embd, ctx->params.ctx_size );
// token scores to reference for masking
std::vector<float> scores(n_outputs, 1.0);
@ -607,11 +639,11 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
// to-do: only update the embeddings instead
batch.n_tokens = 0;
inputs.resp[0] = output_tokens;
fill_batch( batch, inputs, ctx->io_map, mode );
fill_batch( batch, inputs, *ctx->io_map, mode );
// update null batch
null_input.resp[0] = output_tokens;
null_batch.n_tokens = 0;
fill_batch( null_batch, inputs, ctx->io_map, mode );
fill_batch( null_batch, inputs, *ctx->io_map, mode );
// cfg decode
if ( llama_decode(ctx->llama.ctx, null_batch) ) {
@ -717,7 +749,7 @@ std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t
return output_tokens;
}
std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::string& text, const std::string& language ) {
std::vector<token_t> phonemize( vall_e_context_t* ctx, const std::string& text, const std::string& language ) {
std::vector<token_t> tokens;
// phonemize text
@ -767,6 +799,7 @@ std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::str
}
tokens.emplace_back(2);
if ( ctx->params.verbose ) print_tokens( tokens, "Phonemes: " );
/*
// to-do: fix terminate called after throwing an instance of 'std::out_of_range'
@ -782,7 +815,7 @@ std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::str
return tokens;
}
void VALL_E_API vall_e_print_usage( char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
void vall_e_print_usage( char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
@ -803,6 +836,8 @@ void VALL_E_API vall_e_print_usage( char** argv, vall_e_context_params_t& params
fprintf(stderr, " Input text prompt (default: %s)\n", args.text.c_str());
fprintf(stderr, " -l TEXT, --language TEXT\n");
fprintf(stderr, " Language for input text / output response (default: %s)\n", args.language.c_str());
fprintf(stderr, " -ts TASK, --task TASK\n");
fprintf(stderr, " Inferencing task (default: %s, accepts ['tts', 'stt', 'ns', 'sr'])\n", args.task);
fprintf(stderr, " -mode MODE, --modality MODE\n");
fprintf(stderr, " Modality for inferencing (default: %s, accepts ['ar+nar', 'nar-len'])\n", args.modality == MODALITY_NAR_LEN ? "nar-len" : "ar+nar");
fprintf(stderr, " -ms N, --max-steps N\n");
@ -815,7 +850,7 @@ void VALL_E_API vall_e_print_usage( char** argv, vall_e_context_params_t& params
fprintf(stderr, " Output audio wav (default: %s)\n", args.output_path.c_str());
fprintf(stderr, "\n");
}
bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
bool vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args ) {
for ( int i = 1; i < argc; i++ ) {
std::string arg = argv[i];
@ -835,6 +870,8 @@ bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_
args.text = argv[++i];
} else if (arg == "-l" || arg == "--language") {
args.language = argv[++i];
} else if (arg == "-ts" || arg == "--task") {
args.task = argv[++i];
} else if (arg == "-mode" || arg == "--modality") {
args.modality = std::string(argv[++i]) == "ar+nar" ? MODALITY_AR_NAR : MODALITY_NAR_LEN;
} else if (arg == "-ms" || arg == "--max-steps") {
@ -859,8 +896,9 @@ bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_
return true;
}
vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params ) {
vall_e_context_t* vall_e_load( const vall_e_context_params_t& params ) {
vall_e_context_t* ctx = new vall_e_context_t();
ctx->io_map = new io_map_t();
ctx->params = params;
// setup ggml
@ -905,7 +943,7 @@ vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params
espeak_Initialize(AUDIO_OUTPUT_SYNCHRONOUS, 0, NULL, 0);
// setup vall_e.cpp
vall_e_inputs_map_init( ctx->io_map, ctx->llama.model );
vall_e_inputs_map_init( *ctx->io_map, ctx->llama.model );
// setup vocab things
for ( auto& entry : vocab_merges ) {
@ -921,17 +959,15 @@ vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params
return ctx;
}
vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& language ) {
vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& language, const std::string& task ) {
// to-do: set members in initializer rather than in post
vall_e_inputs_t inputs;
inputs.task = task;
inputs.rvq_l = 0;
inputs.phn = phonemize( ctx, text, language );
inputs.prom = encode_audio( ctx->encodec.ctx, read_audio_from_disk( prompt_path ) );
if ( language == "en" ) inputs.lang = 0;
else if ( language == "ja" ) inputs.lang = 1;
else if ( language == "de" ) inputs.lang = 2;
else if ( language == "fr" ) inputs.lang = 3;
else if ( language == "zh" ) inputs.lang = 4;
else if ( language == "ko" ) inputs.lang = 5;
inputs.lang = language;
return inputs;
}
@ -943,17 +979,20 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
// inference len
int len = 0;
if ( !len ) {
auto task = inputs.task;
inputs.task = "len";
output_tokens = generate( ctx, inputs, 5, INFERENCE_MODE_LEN, ctx->params.verbose );
{
// to-do: one liner this
int digit = 1;
for (auto it = output_tokens.rbegin(); it < output_tokens.rend(); ++it) {
len += (*it) * digit;
digit *= 10;
}
}
// cap for now
// cap duration
if ( len <= 0 || len > max_duration ) len = max_duration;
inputs.task = task;
}
// fill with mask tokens
inputs.resp.resize(1);
@ -962,7 +1001,6 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
}
// inference NAR-len 0
inputs.task = "tts";
for ( auto l = 0; l < 8; ++l ) {
inputs.rvq_l = l;
output_tokens = generate( ctx, inputs, max_steps, l == 0 ? INFERENCE_MODE_NAR_DEMASK : INFERENCE_MODE_NAR, ctx->params.verbose );
@ -971,7 +1009,6 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
}
// AR+NAR
} else if ( modality == MODALITY_AR_NAR ){
inputs.task = "tts";
for ( auto l = 0; l < 8; ++l ) {
inputs.rvq_l = l;
output_tokens = generate( ctx, inputs, l == 0 ? max_duration : 1, l == 0 ? INFERENCE_MODE_AR : INFERENCE_MODE_NAR, ctx->params.verbose );
@ -981,18 +1018,17 @@ vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& in
return inputs.resp;
}
void VALL_E_API vall_e_free( vall_e_context_t* ctx ) {
void vall_e_free( vall_e_context_t* ctx ) {
espeak_Terminate();
encodec_free(ctx->encodec.ctx);
llama_free(ctx->llama.ctx);
llama_free_model(ctx->llama.model);
ggml_free(ctx->io_map.ctx);
ggml_free(ctx->io_map->ctx);
delete ctx->io_map;
delete ctx;
}
int main( int argc, char** argv ) {
// to-do: parse CLI args
vall_e_context_params_t params;
vall_e_args_t args;

View File

@ -5,34 +5,100 @@
#include <vector>
#include <unordered_map>
// external deps
#include <llama.h>
#include <encodec.h>
#include <dr_wav.h>
#include <espeak-ng/speak_lib.h>
// to-do: copy over import/export stuff from engine project (because I don't remember how I set it up in <uf/config.h>)
#define VALL_E_API
// handles defining platform specific macros and import/export decorators (copied from my engine's uf/config.h)
#if defined(_WIN32) || defined(__WIN32__) || defined(__CYGWIN__)
// Windows
#define VALL_E_ENV "Windows"
#define VALL_E_ENV_WINDOWS 1
#define VALL_E_ENV_HEADER "windows.h"
#if defined(__CYGWIN__)
#define to_string(var) string(var)
#endif
#ifndef _WIN32_WINNT
#define _WIN32_WINNT 0x0600
#endif
#ifndef WINVER
#define WINVER 0x0600
#endif
#define VALL_E_IO_ROOT "./data/"
#elif defined(linux) || defined(__linux)
// Linux
#define VALL_E_ENV "Linux"
#define VALL_E_ENV_LINUX 1
#define VALL_E_ENV_HEADER "linux.h"
#define VALL_E_IO_ROOT "./data/"
#elif defined(__APPLE__) || defined(MACOSX) || defined(macintosh) || defined(Macintosh)
// MacOS
#define VALL_E_ENV "OSX"
#define VALL_E_ENV_OSX 1
#define VALL_E_ENV_HEADER "osx.h"
#define VALL_E_IO_ROOT "./data/"
#elif defined(__FreeBSD__) || defined(__FreeBSD_kernel__)
// FreeBSD
#define VALL_E_ENV "FreeBSD"
#define VALL_E_ENV_FREEBSD 1
#define VALL_E_ENV_HEADER "freebsd.h"
#define VALL_E_IO_ROOT "./data/"
#elif defined(__sh__)
// Dreamcast
#define VALL_E_ENV "Dreamcast"
#define VALL_E_ENV_DREAMCAST 1
#define VALL_E_ENV_HEADER "dreamcast.h"
#include VALL_E_ENV_HEADER
#define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
#define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch (or using LLAMA arch)
#define _arch_dreamcast
#if !LLAMA_CPP_EXTENDED
#include "llama_hack.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
#define VALL_E_IO_ROOT "/cd/"
#else
// Unsupported system
#define VALL_E_ENV "Unknown"
#define VALL_E_ENV_UNKNOWN 1
#define VALL_E_ENV_HEADER "unknown.h"
#warning Using "unknown"
#error No support
#endif
// to-do: clean up spaghetti enums
const int EMBEDDING_MODE_PROM = 0;
const int EMBEDDING_MODE_RESP_AR_NAR = 1;
const int EMBEDDING_MODE_RESP_NAR_LEN = 2;
#if !defined(VALL_E_STATIC)
#if defined(VALL_E_ENV_WINDOWS)
// Windows compilers need specific (and different) keywords for export and import
#define VALL_E_API_EXPORT __declspec(dllexport)
#define VALL_E_API_IMPORT __declspec(dllimport)
// For Visual C++ compilers, we also need to turn off this annoying C4251 warning
#ifdef _MSC_VER
#pragma warning(disable : 4251)
#endif
#else // Linux, FreeBSD, Mac OS X
#if __GNUC__ >= 4
// GCC 4 has special keywords for showing/hiding symbols,
// the same keyword is used for both importing and exporting
#define VALL_E_API_EXPORT __attribute__ ((__visibility__ ("default")))
#define VALL_E_API_IMPORT __attribute__ ((__visibility__ ("default")))
#else
// GCC < 4 has no mechanism to explicitely hide symbols, everything's exported
#define VALL_E_API_EXPORT
#define VALL_E_API_IMPORT
#endif
#endif
#else
// Static build doesn't need import/export macros
#define VALL_E_API_EXPORT
#define VALL_E_API_IMPORT
#endif
const int INFERENCE_MODE_LEN = 0;
const int INFERENCE_MODE_AR = 1;
const int INFERENCE_MODE_NAR_DEMASK = 2;
const int INFERENCE_MODE_NAR = 3;
#ifdef VALL_E_EXPORTS
#define VALL_E_API VALL_E_API_EXPORT
#else
#define VALL_E_API VALL_E_API_IMPORT
#endif
const int MODALITY_AR_NAR = 0;
const int MODALITY_NAR_LEN = 1;
typedef llama_token token_t;
typedef std::vector<std::vector<token_t>> vall_e_audio_codes_t;
const int ENCODEC_FRAMES_PER_SECOND = 75;
const int MAX_DURATION = ENCODEC_FRAMES_PER_SECOND * 12;
@ -40,52 +106,16 @@ const int CTX_SIZE = 2048;
const int N_THREADS = 8;
const int N_GPU_LAYERS = 99;
typedef llama_token token_t;
typedef std::vector<std::vector<token_t>> vall_e_audio_codes_t;
const int MODALITY_AR_NAR = 0;
const int MODALITY_NAR_LEN = 1;
// stores embeddings + metadata for an embedding range
struct io_t {
std::string name;
uint32_t start;
uint32_t end;
int32_t head_idx = -1;
int32_t n_embd = 0;
int32_t n_vocab = 0;
std::vector<float> embds = {};
ggml_tensor* head = NULL;
};
// stores the mappings between tokens, input embeddings, and output heads
struct io_map_t {
// model's original params
int32_t n_embd = 0;
int32_t n_vocab = 0;
// mapping
std::unordered_map<std::string, io_t> io = {};
// context to store slices
ggml_context* ctx = NULL;
};
struct score_t {
int32_t idx;
float value;
bool operator<( const score_t& that ) const { return this->value < that.value; }
};
struct merge_entry_t {
std::u32string pre;
std::u32string post;
std::u32string resolved;
token_t pre_token;
token_t post_token;
token_t resolved_token;
};
// forward declarations
struct io_map_t;
struct llama_model;
struct llama_context;
struct encodec_context;
// model-specific parameters
struct vall_e_context_params_t {
std::string model_path = "./data/vall_e.gguf";
std::string encodec_path = "./data/encodec.bin";
@ -94,20 +124,22 @@ struct vall_e_context_params_t {
int32_t ctx_size = CTX_SIZE;
bool verbose = false;
};
// inference-specific arguments
struct vall_e_args_t {
std::string text = "Hello world.";
std::string prompt_path = "./data/prom.wav";
std::string output_path = "./data/resp.wav";
std::string language = "en";
std::string task = "tts";
int modality = MODALITY_NAR_LEN;
int max_steps = 30;
int max_duration = MAX_DURATION;
};
// stores everything needed for vall_e.cpp
// stores everything needed for vall_e.cpp at runtime
struct vall_e_context_t {
vall_e_context_params_t params;
io_map_t io_map;
io_map_t* io_map = NULL; // pointer, since io_map_t is defined in vall_e-impl.h and only forward-declared here
struct {
llama_model* model = NULL;
@ -121,49 +153,26 @@ struct vall_e_context_t {
// stores the raw inputs to be fed
struct vall_e_inputs_t {
std::string task = "tts";
std::string lang = "en";
token_t rvq_l = 0;
std::vector<token_t> phn = {};
token_t lang = 0;
token_t rvq_l = 0;
vall_e_audio_codes_t prom = {};
vall_e_audio_codes_t resp = {};
};
// helper tensor functions
std::vector<float> VALL_E_API read_2d_tensor( struct ggml_tensor* tensor );
//ggml_tensor* VALL_E_API view_2d_tensor( ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 ); // cringe method to keep in my pocket
ggml_tensor* VALL_E_API view_2d_tensor( ggml_context* ctx, ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 );
void VALL_E_API print_tokens( const std::vector<token_t>& tokens, const std::string& prefix = "Tokens: " );
std::vector<std::vector<float>> VALL_E_API map_embeddings( const std::vector<token_t>& tokens, int n_embd, const float* embds );
std::vector<std::vector<float>> VALL_E_API sum_embeddings( const vall_e_audio_codes_t& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM );
std::vector<float> VALL_E_API soft_max( int n_logits, const float* logits );
// batch and inferencing
void VALL_E_API batch_add( llama_batch& batch, token_t id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} );
void VALL_E_API fill_batch( llama_batch& batch, vall_e_inputs_t& input, io_map_t& inputs_map, int mode );
std::vector<token_t> VALL_E_API generate( vall_e_context_t* ctx, vall_e_inputs_t& input, int max_tokens, int mode, bool verbose = true );
//
std::vector<token_t> VALL_E_API phonemize( vall_e_context_t* ctx, const std::string& text, const std::string& language = "auto" );
// encodec helpers
std::vector<float> VALL_E_API read_audio_from_disk( const std::string& path );
void VALL_E_API write_audio_to_disk( const std::vector<float>& waveform, const std::string& path );
VALL_E_API std::vector<float> read_audio_from_disk( const std::string& path );
VALL_E_API void write_audio_to_disk( const std::vector<float>& waveform, const std::string& path );
std::vector<std::vector<int32_t>> VALL_E_API encode_audio( struct encodec_context* ectx, const std::vector<float>& waveform );
std::vector<float> VALL_E_API decode_audio( struct encodec_context* ectx, const std::vector<std::vector<int32_t>>& codes_2d );
// model-accessing helpers
const io_t& VALL_E_API vall_e_inputs_map_get_embeddings( io_map_t& inputs_map, const std::string& name );
const float* VALL_E_API vall_e_inputs_map_get_embeddings_p( io_map_t& inputs_map, const std::string& name );
int32_t VALL_E_API vall_e_inputs_map_get_classifier_idx( io_map_t& inputs_map, const std::string& name );
void VALL_E_API vall_e_inputs_map_init( io_map_t&, llama_model* model );
VALL_E_API std::vector<std::vector<int32_t>> encode_audio( struct encodec_context* ectx, const std::vector<float>& waveform );
VALL_E_API std::vector<float> decode_audio( struct encodec_context* ectx, const vall_e_audio_codes_t& codes_2d );
// context management
void VALL_E_API vall_e_print_usage( char** argv, const vall_e_context_params_t& params, const vall_e_args_t& args );
bool VALL_E_API vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args );
vall_e_context_t* VALL_E_API vall_e_load( const vall_e_context_params_t& params );
vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& lang );
vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int modality = MODALITY_NAR_LEN );
void VALL_E_API vall_e_free( vall_e_context_t* ctx );
VALL_E_API void vall_e_print_usage( char** argv, const vall_e_context_params_t& params, const vall_e_args_t& args );
VALL_E_API bool vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args );
VALL_E_API vall_e_context_t* vall_e_load( const vall_e_context_params_t& params );
VALL_E_API vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& lang = "auto", const std::string& task = "tts" );
VALL_E_API vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int modality = MODALITY_NAR_LEN );
VALL_E_API void vall_e_free( vall_e_context_t* ctx );
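Since `vall_e.h` is now the library's public surface (see the `lib` target in the Makefile above), a minimal consumer might look like the following. This is only a sketch based on the declarations above; the paths are the header's defaults and `my_tts.cpp` is a hypothetical file name. Compile without `VALL_E_EXPORTS` defined so `VALL_E_API` resolves to the import decorator, and link against `libvall_e.so`:

```cpp
// my_tts.cpp -- hypothetical example consumer of libvall_e
#include "vall_e.h"

int main() {
	// model paths and runtime settings (fields of vall_e_context_params_t)
	vall_e_context_params_t params;
	params.model_path = "./data/vall_e.gguf";
	params.encodec_path = "./data/encodec.bin";
	params.verbose = true;

	// loads the GGUF weights, encodec.bin, and initializes espeak-ng
	vall_e_context_t* ctx = vall_e_load( params );

	// phonemizes the text and encodes the audio prompt from disk
	vall_e_inputs_t inputs = vall_e_prepare_inputs( ctx, "Hello world.", "./data/prom.wav", "en", "tts" );

	// runs inference (NAR-len modality here), then decodes the codes to a waveform
	vall_e_audio_codes_t codes = vall_e_generate( ctx, inputs, MODALITY_NAR_LEN );
	write_audio_to_disk( decode_audio( ctx->encodec.ctx, codes ), "./data/resp.wav" );

	vall_e_free( ctx );
	return 0;
}
```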

View File

@ -309,7 +309,7 @@ def fold_inputs(
return 0
if isinstance(prom, str):
task = get_task_symmap()[f'<{input}>']
task = get_task_symmap()[input]
seq = torch.tensor([task_start + task], device=device, dtype=dtype)
input_ids[i].append( seq )
@ -664,19 +664,19 @@ def get_tone_symmap():
def get_task_symmap():
return {
"<tts>": 0,
"<tts-c>": 1,
"<ns>": 2,
"<sr>": 3,
"<tse>": 4,
"<soe>": 5,
"<mask>": 6,
"<eoe>": 7,
"<stt>": 8,
"tts": 0,
"tts-c": 1,
"ns": 2,
"sr": 3,
"tse": 4,
"soe": 5,
"mask": 6,
"eoe": 7,
"stt": 8,
"<len>": 0, # fake
"<nse>": 6, # fake
"<cse>": 6, # fake
"len": 0, # fake
"nse": 6, # fake
"cse": 6, # fake
}
def _replace_file_extension(path, suffix):
@ -1330,7 +1330,7 @@ class Dataset(_Dataset):
task = random.choice(self.tasks)
if f'<{task}>' not in self.task_symmap:
if task not in self.task_symmap:
raise Exception(f'Task not defined: {task}')
# Base TTS (<text><prompt> => <resp>)

View File

@ -40,6 +40,20 @@ def convert_to_hf_llama( state_dict, config = None, save_path = None ):
}
}
# cleanup duplicate IDs because convert_hf_to_gguf.py does not like this
# get unique tokens
for k, v in tokenizer["model"]["vocab"].items():
if k not in tokenizer_vocab:
tokenizer_vocab[k] = v
# override if it's in a merge
for k, v in tokenizer["model"]["vocab"].items():
for m in tokenizer["model"]["merges"]:
if k in m:
tokenizer_vocab[k] = v
break
tokenizer["model"]["vocab"] = {}
lang_map = [
"en",
"ja",
@ -165,7 +179,7 @@ def convert_to_hf_llama( state_dict, config = None, save_path = None ):
model_dict['lm_head.weight'] = classifier_dict['weight']
if classifier_bias:
model_dict['lm_head.bias'] = classifier_dict['bias']
# write files in an HF compatible way
out_dir = cfg.rel_path / "hf"
out_dir.mkdir(parents=True, exist_ok=True)
@ -217,131 +231,6 @@ def convert_to_hf_llama( state_dict, config = None, save_path = None ):
return state_dict
# stitches embeddings into one embedding & classifier => lm_head, for use in a HF compatible weight
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
@torch.no_grad()
def convert_to_hf_custom( state_dict, config = None, save_path = None ):
n_text_tokens, model_dim = state_dict['module']['text_emb.weight'].shape
n_audio_tokens = state_dict['module']['proms_emb.embeddings.0.weight'].shape[0]
n_resp_levels = state_dict['module']['rvq_l_emb.weight'].shape[0]
n_len_tokens = 11
n_lang_tokens = state_dict['module']['langs_emb.weight'].shape[0]
n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]
classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias
split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers
# the new tokenizer to use
tokenizer = {}
tokenizer_vocab = {}
tokenizer_path = cfg.rel_path / cfg.tokenizer_path
if not tokenizer_path.exists():
tokenizer_path = Path("./data/") / cfg.tokenizer_path
if tokenizer_path.exists():
tokenizer = json_read( tokenizer_path )
else:
tokenizer = {
"model": {
"vocab": get_phone_symmap()
}
}
lang_map = [
"en",
"ja",
"de",
"fr",
"zh",
"ko",
]
task_map = [
"tts",
"tts-c",
"ns",
"sr",
"tse",
"soe",
"mask",
"eoe",
"stt",
]
model_dict = {}
# filter out the underlying model weights and extract them
for k in state_dict['module'].keys():
if not k.startswith('model.'):
continue
model_dict[k] = state_dict['module'][k].clone()
# cringe
for l in range(11):
model_dict[f'classifiers.{l}.weight'] = state_dict['module'][f'classifiers.proj.{l}.weight']
for l in range(8):
model_dict[f"embeddings.proms.{l}.weight"] = state_dict['module'][f"proms_emb.embeddings.{l}.weight"]
for l in range(9):
model_dict[f"embeddings.resps.{l}.weight"] = state_dict['module'][f"resps_emb.embeddings.{l}.weight"]
model_dict["embeddings.aux.0.weight"] = state_dict['module']["text_emb.weight"]
model_dict["embeddings.aux.1.weight"] = state_dict['module']["rvq_l_emb.weight"]
model_dict["embeddings.aux.2.weight"] = state_dict['module']["langs_emb.weight"]
model_dict["embeddings.aux.3.weight"] = state_dict['module']["tasks_emb.weight"]
model_dict["embeddings.aux.4.weight"] = state_dict['module']["len_emb.weight"]
model_dict["embeddings.aux.5.weight"] = state_dict['module']["tones_emb.weight"]
model_dict["embeddings.aux.6.weight"] = state_dict['module']["sep"].unsqueeze(0)
# write files in an HF compatible way
out_dir = cfg.rel_path / "hf"
out_dir.mkdir(parents=True, exist_ok=True)
# write weights
torch_save( { "module": model_dict, "format": "pt" }, out_dir / "model.safetensors" )
# write tokenizer.json
tokenizer['model']['vocab'] |= tokenizer_vocab
json_write(tokenizer, out_dir / "tokenizer.json", pretty=True)
# write tokenizer_config.json
json_write({
"added_tokens": tokenizer['added_tokens'],
"bos_token": "<bos>",
"eos_token": "</eos>",
"clean_up_tokenization_spaces": True,
"model_input_names": [
"input_ids",
"attention_mask"
],
"tokenizer_class": "PreTrainedTokenizerFast"
}, out_dir / "tokenizer_config.json", pretty=True)
# write config.json
json_write({
"architectures": [
"ValleLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": model_dim,
"initializer_range": 0.02,
"intermediate_size": model_dim * 4,
"max_position_embeddings": 75 * 60 * 5,
"model_type": "llama",
"num_attention_heads": 16,
"num_hidden_layers": 12,
"num_key_value_heads": 16,
"pretraining_tp": 1,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 10000.0,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.40.0",
"use_cache": False,
"vocab_size": 256
}, out_dir / "config.json", pretty=True )
return state_dict
# yanks a LoRA from the training checkpoint
def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
if dtype is None:
@ -412,8 +301,7 @@ def moe_ify( state_dict, config = cfg.model, save_path = None, dtype = None ):
def main():
parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--module-only", action='store_true')
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
parser.add_argument("--hf-llama", action='store_true', default=None) # convert to HF-style llama model
parser.add_argument("--hf", action='store_true', default=None) # convert to HF LLaMA
parser.add_argument("--export-lora", action='store_true', default=None) # exports LoRA
parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
parser.add_argument("--moe-ify", action='store_true', default=None) # splits classifier heads
@ -441,9 +329,7 @@ def main():
engines = load_engines(training=False) # to ignore loading optimizer state
callback = None
if args.hf_llama:
callback = convert_to_hf_llama
elif args.hf:
if args.hf:
callback = convert_to_hf_custom
elif args.export_lora:
callback = extract_lora

View File

@ -59,40 +59,24 @@ def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ):
def get_model(config, training=True, **model_kwargs):
from .ar_nar import AR_NAR # import here because reasons
name = config.name
if config.experimental.hf:
from .experimental import Model as Experimental
model = Experimental(
n_text_tokens=config.text_tokens,
n_audio_tokens=config.audio_tokens,
d_model=config.dim,
n_layers=config.layers,
n_heads=config.heads,
p_dropout=config.dropout,
config = config,
**model_kwargs
)
else:
from .ar_nar import AR_NAR
model = AR_NAR(
n_text_tokens=config.text_tokens,
n_audio_tokens=config.audio_tokens,
d_model=config.dim,
n_heads=config.heads,
n_layers=config.layers,
n_experts=config.experts,
p_dropout=config.dropout,
l_padding = config.input_alignment,
training = training,
config = config,
**model_kwargs
)
model = AR_NAR(
n_text_tokens=config.text_tokens,
n_audio_tokens=config.audio_tokens,
d_model=config.dim,
n_heads=config.heads,
n_layers=config.layers,
n_experts=config.experts,
p_dropout=config.dropout,
l_padding = config.input_alignment,
training = training,
config = config,
**model_kwargs
)
_logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

View File

@ -694,7 +694,6 @@ class Base(nn.Module):
attn_implementation=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
)
print( config )
self.model = LlamaClass(config)
# replace with desired attention
@ -951,7 +950,7 @@ class Base(nn.Module):
# Base-line TTS task
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
# prom /may/ include <task> tokens inside to help guide things, per SpeechX
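# task_type is looked up in get_task_symmap() directly; its keys are presumably the bare task names (no angle-bracket wrapping)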
if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
if task_type in get_task_symmap() and task_type not in special_tasks:
# insert the text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
@ -1092,7 +1091,7 @@ class Base(nn.Module):
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_embedding( input, quant_level ):
if isinstance(input, str):
return self.tasks_emb( torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16) )
return self.tasks_emb( torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) )
# get RVQ level 0, or up to targetted RVQ level inference
if self.version <= 4:
@ -1348,14 +1347,8 @@ class Base(nn.Module):
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
"""
if isinstance(input, str):
return torch.tensor( [ self.ignore_index ], device=device, dtype=torch.int16)
return torch.tensor( [ self.ignore_index ] * input.shape[0], device=device, dtype=torch.int16)
"""
if isinstance(input, str):
return torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16)
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):

View File

@ -1,574 +0,0 @@
"""
This is an experiment to:
* entertain a thought to try and abide by HF's transformers API (to benefit from caching better)
* conform to a single embedding (instead of a bunch of them) by folding/unfolding inputs
* stop trying to make a mixed AR+NAR model work since it seems lobotomized if I keep trying to enforce both recurrent and parallel inferencing (despite a penalty cost)
+ I will not cave and go with codebook patterns, not yet.
"""
from ..config import cfg
from ..data import fold_inputs, unfold_outputs
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
import random
import math
import logging
_logger = logging.getLogger(__name__)
from einops import rearrange
from tqdm import trange
from .arch import *
if cfg.model.arch_type not in AVAILABLE_ARCHES:
raise ValueError(f"Requesting arch `{cfg.model.arch_type}` but not available")
if cfg.model.arch_type in ["mamba","mamba2"]:
LlmArchClass = MambaLMHeadModel
elif cfg.model.arch_type == "llama":
LlmArchClass = LlamaForCausalLM
elif cfg.model.arch_type == "retnet":
LlmArchClass = RetNetForCausalLM
else:
raise ValueError(f"Requesting arch `{cfg.model.arch_type}` but not available")
class Model(LlmArchClass):
def __init__(
self,
n_text_tokens = 256,
n_audio_tokens = 1024,
d_model=1024,
n_layers=12,
n_heads=16,
p_dropout=0.1,
config = cfg.model,
):
self.hyper_config = config
hf_attention = config.attention if config is not None else None
gradient_checkpointing = config.gradient_checkpointing if config is not None else True
# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
# vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1
if hf_attention == "auto":
if AVAILABLE_ATTENTIONS:
hf_attention = AVAILABLE_ATTENTIONS[0]
else:
hf_attention = "eager"
if hf_attention == "xformers":
hf_attention = "mem_efficient"
text_start = 0
text_end = text_start + config.text_tokens
lang_start = text_end
lang_end = lang_start + config.langs
rvq_start = lang_end
rvq_end = rvq_start + config.resp_levels
prom_start = rvq_end
prom_end = prom_start + config.audio_tokens * config.resp_levels
task_start = prom_end
task_end = task_start + config.tasks
tone_start = task_end
tone_end = tone_start + config.tones
resp_start = tone_end
resp_end = resp_start + config.audio_tokens * config.resp_levels
vocab_size = resp_end
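# the folded vocabulary is laid out as contiguous ranges: [ text | lang | rvq level | prom codes | task | tone | resp codes ];
# the *_start/*_end offsets above mark where each range begins within the single shared token space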
if config.arch_type == "llama":
super().__init__(config=LlamaConfig(
vocab_size=vocab_size,
hidden_size=d_model,
max_position_embeddings=cfg.dataset.frames_per_second * config.max_levels * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout,
num_key_value_heads=n_heads,
sliding_window=cfg.dataset.frames_per_second * config.max_levels * 12,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=hf_attention,
))
if gradient_checkpointing:
self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
elif config.arch_type == "retnet":
super().__init__(config=RetNetConfig(
vocab_size=vocab_size,
decoder_embed_dim=d_model,
decoder_value_embed_dim=d_model * 2,
decoder_retention_heads=n_heads,
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout,
checkpoint_activations=gradient_checkpointing,
activation_fn="gelu",
use_layernorm=False,
use_biases=False,
use_glu=True,
#chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
#recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
#no_output_layer=True,
#rotary_embedding_base=self.rotary_embedding_base, # 10000
decoder_normalize_before=True,
))
elif config.arch_type in ["mamba","mamba2"]:
super().__init__(config=MambaConfig(
vocab_size=vocab_size,
d_model=d_model,
n_layer=n_layers*2,
d_intermediate=0, # d_model*4,
ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True} if config.arch_type == "mamba2" else {},
rms_norm=True,
fused_add_norm=True,
residual_in_fp32=False,
))
self.backbone.gradient_checkpointing = gradient_checkpointing
self.accuracy_metric = None if True else MulticlassAccuracy(
vocab_size,
top_k=10,
average="micro",
multidim_average="global",
ignore_index=-100,
)
def generate(
self,
*args,
**kwargs
):
if cfg.model.arch_type in ["mamba","mamba2"]:
kwargs["cg"] = True
if "attention_mask" in kwargs:
kwargs.pop("attention_mask")
if "do_sample" in kwargs:
kwargs.pop("do_sample")
if "min_length" in kwargs:
kwargs.pop("min_length")
"""
if "position_ids" in kwargs:
kwargs.pop("position_ids")
if "max_new_tokens" in kwargs:
kwargs.pop("max_new_tokens")
if "max_length" not in kwargs:
kwargs["max_length"] = 500 * (self.hyper_config.resp_levels if self.hyper_config.experimental.interleave else 1)
if "num_last_tokens" not in kwargs:
kwargs["num_last_tokens"] = self.hyper_config.experimental.causal_size
"""
input_ids = kwargs.pop("input_ids")
attention_mask = kwargs.pop("attention_mask", None)
position_ids = kwargs.pop("position_ids", None)
stop_token = kwargs.pop("eos_token_id", 3)
max_steps = kwargs.pop("max_new_tokens", 500)
device = input_ids.device
batch_size = input_ids.shape[0]
sequence_list = [ inputs for inputs in input_ids ]
position_list = [ positions for positions in position_ids ]
start_positions = [ inputs.shape[0] for inputs in input_ids ]
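# remember each prompt's length so only the newly generated suffix is returned at the end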
stopped = torch.zeros(batch_size, device=device).bool()
config = self.hyper_config
state = None
disable_tqdm = False
causal_size = config.experimental.causal_size
# get next in sequence
for n in trange(max_steps // max(1, causal_size), desc="AR", disable=disable_tqdm):
output = super().forward(
input_ids=torch.stack(sequence_list),
#attention_mask=attention_mask,
#past_key_values=state,
#position_ids=torch.stack(position_list),
#use_cache=False,
#return_dict=False
)
logits = output[0]
# state = output[1]
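# greedily take the argmax over the last causal_size positions for each batch item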
r = [ logit[-causal_size:].argmax(dim=1) for logit in logits ]
# append tokens
for i, ri in enumerate(r):
if stop_token in ri:
stopped[i] = True
last_position_id = position_list[i][-1].item() + 1
sequence_list[i] = torch.cat([ sequence_list[i], ri.to(device) ], dim=0)
#position_list[i] = torch.cat([ position_list[i], torch.tensor([ last_position_id + _ for _ in range( ri.shape[0] ) ], device=device, dtype=torch.int32) ])
# stop token found
stopped |= r == stop_token
if stopped.all().item():
break
def _prune(l: Tensor, stop = stop_token):
indices = (l == stop).nonzero()
if len(indices) == 0:
return l
return l[: indices.min().item()]
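# strip the prompt prefix from each sequence and truncate at the first stop token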
sequence_list = [ _prune(seq[start_positions[i]:], stop_token) for i, seq in enumerate(sequence_list) ]
return torch.stack(sequence_list)
"""
return super().generate(*args, **kwargs)
"""
def forward(
self,
*args,
**kwargs,
):
config = self.hyper_config
if "text_list" in kwargs:
text_list = kwargs.pop("text_list", None)
proms_list = kwargs.pop("proms_list", None)
resps_list = kwargs.pop("resps_list", None)
lang_list = kwargs.pop("lang_list", None)
tone_list = kwargs.pop("tone_list", None)
training = kwargs.pop("training", False)
steps = kwargs.pop("steps", 500)
batch_size = len(text_list)
if training:
quant_levels = None if config.experimental.interleave else [ random.randint( 0 if "ar" in config.capabilities else 1, config.max_levels - 1) for _ in range(batch_size) ]
input_ids, attention_mask, position_ids = fold_inputs(
text_list=text_list,
prom_list=proms_list,
resp_list=resps_list,
targ_list=resps_list,
quant_levels=quant_levels,
)
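# fold a second time with ignore_index so non-target positions are presumably masked out of the loss targets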
target_ids, target_attention_mask, target_position_ids = fold_inputs(
text_list=text_list,
prom_list=proms_list,
resp_list=resps_list,
targ_list=resps_list,
quant_levels=quant_levels,
ignore_index=-100
)
return self.forward(
input_ids=input_ids,
labels=target_ids,
position_ids=position_ids,
quant_levels=quant_levels,
)
if config.experimental.interleave:
input_ids, attention_mask, position_ids = fold_inputs( text_list=text_list, prom_list=proms_list )
output = self.generate(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
eos_token_id=3,
do_sample=True,
max_new_tokens=steps*config.max_levels,
)
return unfold_outputs( output )["resp_list"]
resps_list = [ [] for _ in range(batch_size) ]
for l in range(config.max_levels):
quant_levels = [ l for _ in range(batch_size) ]
input_ids, attention_mask, position_ids = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
min_length = 1
for batch in input_ids:
min_length = max( min_length, batch.shape[0] + 1 )
# to-do: figure out a way to do one forward pass but sample N tokens to replicate the NAR sample pass
output = self.generate(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
eos_token_id=3,
do_sample=True,
max_new_tokens=steps,
)
unfolded = unfold_outputs( output, quant_levels=quant_levels )
if l == 0:
steps = 0
for batch, resp in enumerate(unfolded["resp_list"]):
length = resp.shape[-1]
# store length
if l == 0:
steps = max( steps, length )
# pad
else:
resp = resp[:steps]
if length < steps:
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
resps_list[batch].append( resp )
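# stack each batch item's per-level responses and transpose to [timesteps, levels]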
for i, resp in enumerate( resps_list ):
resps_list[i] = torch.stack( resp ).t()
return resps_list
if config.arch_type in ["mamba","mamba2"]:
kwargs.pop("attention_mask", None)
labels = kwargs.pop("labels", None)
quant_levels = kwargs.pop("quant_levels", None)
output = super().forward(*args, **kwargs)
logits = output.logits
# i HATE the correct way
if labels is not None:
if quant_levels is None:
quant_levels = [0 for _ in range(labels.shape[0])]
# predict the next token for AR, else predict in place
loss = sum([ F.cross_entropy(
logit[:-config.experimental.causal_size, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
label[config.experimental.causal_size:] if quant_level == 0 or "nar" not in config.capabilities else label,
ignore_index=-100
) for logit, label, quant_level in zip( logits, labels, quant_levels ) ])
self.loss = dict(
nll = loss,
)
if self.accuracy_metric is not None:
self.stats = dict(
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
)
"""
if config.loss_factors:
sep = 3
# determine specific sections to focus on
indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
text_index = 0
resp_index = 1 # index 1 covers everything after the text (non-text); -3 would cover just pre_resp + resp (ignoring the prom; probably better to include the prom here)
labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
logits_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( logits ) ]
logits_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( logits ) ]
loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) / len(logits_text) * self.hyper_config.loss_factor("text")
loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) / len(logits_resp) * self.hyper_config.loss_factor("resp")
self.loss = dict(
text = loss_text,
resp = loss_resp,
)
if self.accuracy_metric is not None:
self.stats = dict(
acc = dict(
text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
)
)
"""
return output
def example_usage():
cfg.trainer.backend = "local"
cfg.hyperparameters.gradient_accumulation_steps = 1
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_100
from functools import partial
from einops import repeat
from tqdm import tqdm
from ..emb.qnt import decode_to_file, unload_model
from ..engines import Engine
from ..utils import wrapper as ml
import numpy as np
import re
device = "cuda"
def tokenize(content):
return torch.tensor( cfg.tokenizer.encode(content) )
def _load_quants(path) -> Tensor:
qnt = np.load(path, allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.max_levels, :].t().to(torch.int16)
qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
text_list = [
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
]
prom_list = [
qnt[:cfg.dataset.frames_per_second, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
resp_list = [
qnt[:, :].to(device),
#qnt[cfg.dataset.frames_per_second:, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
text_list = text_list[:1]
prom_list = prom_list[:1]
resp_list = resp_list[:1]
kwargs = {}
model = Model(**kwargs).to(device)
steps = 50 # 100 if cfg.model.experimental.interleave else 300
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None
if cfg.optimizations.dadaptation:
# do not combine the two
if scheduler == "schedulefree":
scheduler = ""
learning_rate = 1.0
if optimizer == "prodigy":
if learning_rate is None:
learning_rate = 1.0
optimizer = ml.Prodigy
elif optimizer == "adagrad":
if learning_rate is None:
learning_rate = 1.0e-2
optimizer = ml.Adagrad
elif optimizer == "adamw":
if learning_rate is None:
learning_rate = 1.0e-4
optimizer = ml.AdamW
elif optimizer == "sdg":
if learning_rate is None:
learning_rate = 1.0e-4
optimizer = ml.SGD
else:
raise ValueError(f"Unrecognized optimizer: {optimizer}")
_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
optimizer = optimizer(model.parameters(), lr=learning_rate)
if scheduler == "schedulefree":
if isinstance(optimizer, ml.AdamW):
scheduler = ml.schedulefree.AdamWScheduleFree
elif isinstance(optimizer, ml.SGD):
scheduler = ml.schedulefree.SGDScheduleFree
else:
scheduler = None
if scheduler is not None:
_logger.info(f"Scheduler: {scheduler}")
optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear:
model = ml.replace_linear( model )
if cfg.optimizations.replace and cfg.optimizations.embedding:
model = ml.replace_embedding( model )
engine = Engine(model=model, optimizer=optimizer)
"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""
_logger.info(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode()
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
engine.eval()
resp_list = model( text_list=text_list, proms_list=prom_list )
for i, batch in enumerate(resp_list):
_ = decode_to_file(batch.to(device=device), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
unload_model()
def train():
engine.train()
t = trange(steps)
for i in t:
stats = {"step": i}
stats |= engine.traverse(text_list=text_list, proms_list=prom_list, resps_list=resp_list, training=True)
stats |= engine.gather_attribute("stats")
stats |= {"grad_norm": engine.get_global_grad_norm()}
tqdm.write(f"{stats}")
"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""
#sample("init", 5)
train()
sample("final")
if __name__ == "__main__":
example_usage()

View File

@ -128,7 +128,7 @@ def get_languages():
return list(get_lang_symmap().keys()) + ["auto"]
def get_tasks():
return ["tts", "sr", "nr", "vc"]
return ["tts", "sr", "ns", "vc"]
#@gradio_wrapper(inputs=layout["dataset"]["inputs"].keys())
def load_sample( speaker ):