From 880b4ecd1b96f54f1aaa32dfa58c0f9447c13234 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 5 Jun 2024 19:50:06 -0500
Subject: [PATCH] cleanup, putting some thoughts in comments before I forget about them

---
 README.md                     |  38 ++++++++-----
 vall_e/models/__init__.py     |   6 +-
 vall_e/models/ar_nar.py       |  12 +++-
 vall_e/models/base.py         | 100 +++++++++++++++++++++++-----------
 vall_e/models/experimental.py |   6 +-
 vall_e/train.py               |   1 -
 vall_e/utils/trainer.py       |  12 ----
 7 files changed, 111 insertions(+), 64 deletions(-)

diff --git a/README.md b/README.md
index 277a245..44d9493 100755
--- a/README.md
+++ b/README.md
@@ -8,6 +8,8 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
 
 > **Note** Development on this is very sporadic. Gomen.
 
+> **Note** Compatibility for existing models may break at any time while I feverishly try and work out the best way to crank out a model. Gomen.
+
 ## Requirements
 
 * [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
@@ -24,12 +26,14 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`.
 
 ## Try Me
 
-To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`
+To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`.
 
-Each model file has a barebones trainer and inference routine.
+A small trainer will overfit a provided utterance to ensure a model configuration works.
 
 ## Pre-Trained Model
 
+> **Note** Pre-trained weights won't be up to par until I finally nail the best training methodologies and model code. Gomen.
+
 My pre-trained weights can be acquired from [here](https://huggingface.co/ecker/vall-e).
 
 A script to set up a proper environment and download the weights can be invoked with `./scripts/setup.sh`
@@ -44,14 +48,14 @@ Training is very dependent on:
 
 ### Pre-Processed Dataset
 
+> **Note** The provided dataset needs to be reprocessed to better suit a new training dataset format. Gomen.
+
 A "libre" dataset utilizing EnCodec quantized audio can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.
 
 A script to set up a proper environment and train can be invoked with `./scripts/setup-training.sh`
 
 ### Leverage Your Own Dataset
 
-> **Note** Preparing a dataset is a bit messy.
-
 If you already have a dataset, for example your own large corpus or one for finetuning, you can use it instead.
 
 0. Set up a `venv` with `https://github.com/m-bain/whisperX/`.
@@ -79,10 +83,11 @@
 Two dataset formats are supported (a loading sketch follows this list):
 * the standard way:
-  - for Encodec/Vocos audio backends, data is stored under `./training/data/{group}/{speaker}/{id}.phn.txt` and `./training/data/{group}/{speaker}/{id}.qnt.pt`
-  - for Descript-Audio-Codec audio backend, data is stored under `./training/data/{group}/{speaker}/{id}.json` and `./training/data/{group}/{speaker}/{id}.dac`
+  - for Encodec/Vocos audio backends, data is stored under `./training/data/{group}/{speaker}/{id}.enc` as a NumPy file.
+  - for the Descript-Audio-Codec audio backend, data is stored under `./training/data/{group}/{speaker}/{id}.dac` as a NumPy file.
+  - it is *highly* recommended to generate metadata to speed up dataset pre-load with `python3 -m vall_e.data yaml="./training/config.yaml" --action=metadata`
 * using an HDF5 dataset:
-  - you can convert from the standard way with the following command: `python3 -m vall_e.data yaml="./training/config.yaml"`
+  - you can convert from the standard way with the following command: `python3 -m vall_e.data yaml="./training/config.yaml"` (metadata for dataset pre-load is generated alongside HDF5 creation)
   - this will shove everything into a single HDF5 file and store some metadata alongside (for now, the generated symbol map and text/audio lengths)
   - be sure to also define `use_hdf5` in your config YAML.
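As a rough sketch of consuming the standard layout above (an editor's illustration, not repo code; the `allow_pickle` flag and the array's shape/dtype are assumptions):

```python
# enumerate ./training/data/{group}/{speaker}/{id}.enc files produced for the
# EnCodec/Vocos backends; swap the glob to *.dac for Descript-Audio-Codec
from pathlib import Path
import numpy as np

root = Path("./training/data")
for path in sorted(root.glob("*/*/*.enc")):
    group, speaker, utterance_id = path.parts[-3], path.parts[-2], path.stem
    codes = np.load(path, allow_pickle=True)  # quantized audio codes; exact contents depend on the backend
    print(group, speaker, utterance_id)
```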
@@ -98,7 +103,6 @@
 You can enter `save` to save the state at any time, or `quit` to save and quit training.
 
 The `lr` will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
 
-
 ### Plotting Metrics
 
 Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot yaml="./training/config.yaml"`
 
 You can specify what X and Y labels you want to plot against by passing `--xs` and `--ys`.
@@ -107,35 +111,41 @@
 ### Notices
 
-If you're training under `float16`, it is recommended to use the `local` backend with `amp` enabled. There's something really funky with `deepspeed` as a backend that's causing issues with training.
-
 #### Training Under Windows
 
 As training under `deepspeed` and Windows is not (easily) supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend.
 
-Keep in mind that creature comforts like distributed training or `float16` training cannot be verified as working at the moment with the local trainer.
+Creature comforts like `float16`, `amp`, and multi-GPU training *should* work, but extensive testing still needs to be done to ensure it all functions.
 
 #### Training Caveats
 
 Unfortunately, efforts to train a *good* foundational model seem entirely predicated on a good dataset. My dataset might be too fouled with:
 * too short utterances: trying to extrapolate longer contexts seems to utterly fall apart from just the `text` being too long.
+  It might help to instead initially train with smaller utterances, train for two epochs, then increase each sample's length.
+  - This does seem to help speed up the model "learning" better.
 * too tightly trimmed utterances: there being little to no space at the start and end might harm associating `<s>` and `</s>` tokens with empty utterances.
 * a poorly mapped phoneme mapping: I naively crafted my own phoneme mapping, where a HuggingFace tokenizer might supply a better token mapping.
+  + This seems remedied by settling on a HuggingFace tokenizer to handle everything.
+* having a unified AR and NAR model might sound too convenient, but each task may lobotomize the other, due to the nature of things.
+  + This *might* be remedied with better sequence formatting.
 
 #### Backend Architectures
 
 As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in.
 
 Currently supported LLM architectures (a sketch of slotting one in follows this list):
 
 * `llama`: using HF transformers' LLaMA implementation for its attention-based transformer, boasting RoPE and other improvements.
+  + I aim to utilize this for the foundational model, as I get to leverage a bunch of things tailored for LLaMA (and converting to them is rather easy).
 * `mixtral`: using HF transformers' Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
 * `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
   - Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.
 * `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
 * `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
   - Its implementation for MoE can also be utilized.
-* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
+* `retnet-hf`: using [syncdoth/RetNet](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
   - has an inference penalty, and MoE is not implemented.
+* `mamba`: using [state-spaces/mamba](https://github.com/state-spaces/mamba) (needs to mature)
+  - ***really hard*** to have a unified AR and NAR model
+  - the inference penalty makes it a really hard sell, despite the loss already sitting at a low ~3 after only a short number of samples
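To make "slotted in" concrete, here is a hedged sketch of what the `llama` option amounts to, using HF transformers directly; the dimensions are illustrative, and the real wiring lives in `vall_e/models/base.py`:

```python
from transformers import LlamaConfig, LlamaModel

config = LlamaConfig(
    vocab_size=1,            # token embeddings are computed externally and passed via inputs_embeds
    hidden_size=1024,        # d_model
    num_attention_heads=16,  # n_heads
    num_hidden_layers=12,    # n_layers
)
backbone = LlamaModel(config)
# x: [batch, seq_len, hidden_size] summed text/prompt/response embeddings
# hidden = backbone(inputs_embeds=x).last_hidden_state
```

Swapping in `mixtral`, `retnet`, or the others is then mostly a matter of constructing a different backbone behind the same embedding and classification layers.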
 
 For audio backends:
 
@@ -144,7 +154,7 @@ For audio backends:
   - encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`
 * [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality
   - **Note** models using `descript-audio-codec` at 24kHz + 8kbps will NOT converge in any manner.
-  - **Note** models using `descript-audio-codec` at 44KHz + 8kbps stops improving after a while.
+  - **Note** models using `descript-audio-codec` at 44kHz + 8kbps seem to have a harder time modeling the codec's "language", but despite the loss staying rather high, the output sounds fine.
 
 `llama`-based models also support different attention backends (see the sketch after this list):
 * `math`: torch's SDPA's `math` implementation
 * `sdpa`: integrated `LlamaSdpaAttention` attention model
 * `flash_attention_2`: integrated `LlamaFlashAttention2` attention model
 
+The wide support for various backends exists solely while I try to figure out which is the "best" fit for a core foundation model.
+
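For reference, a minimal sketch of the two selection mechanisms these attention backends boil down to; this is plain PyTorch / HF transformers usage, not a `vall_e` API:

```python
import torch
import torch.nn.functional as F

# torch level: constrain which kernel services scaled_dot_product_attention
q = k = v = torch.randn(1, 16, 128, 64)
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

# HF level: pick the integrated attention class at load time, e.g.
# model = AutoModelForCausalLM.from_pretrained(path, attn_implementation="flash_attention_2")
```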
 ## Export
 
 To export the models, run: `python -m vall_e.export yaml=./training/config.yaml`.

diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
index 085bfbf..66d0564 100755
--- a/vall_e/models/__init__.py
+++ b/vall_e/models/__init__.py
@@ -5,7 +5,8 @@ def get_model(cfg, training=True):
 	if not cfg.experimental:
 		from .ar_nar import AR_NAR
 		model = AR_NAR(
-			n_tokens=cfg.tokens,
+			n_text_tokens=cfg.text_tokens,
+			n_audio_tokens=cfg.audio_tokens,
 			d_model=cfg.dim,
 			n_heads=cfg.heads,
 			n_layers=cfg.layers,
@@ -22,6 +23,9 @@ def get_model(cfg, training=True):
 	else:
 		from .experimental import Model as Experimental
 		model = Experimental(
+			n_text_tokens=cfg.text_tokens,
+			n_audio_tokens=cfg.audio_tokens,
+
 			d_model=cfg.dim,
 			n_layers=cfg.layers,
 			n_heads=cfg.heads,
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index ac3ad15..a9c3fc3 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -181,6 +181,11 @@ class AR_NAR(Base):
 		# is NAR
 		if max_levels == 0:
 			max_levels = self.n_resp_levels - 1
+
+		# expand if given a raw 1D tensor
+		for i, resp in enumerate(resps_list):
+			if resp.dim() == 1:
+				resps_list[i] = resp.unsqueeze(-1)
 
 		prev_list = resps_list
@@ -377,7 +382,9 @@ def example_usage():
 	# retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
 	kwargs = {
-		'n_tokens': 1024,
+		'n_text_tokens': 256,
+		'n_audio_tokens': 1024,
+
 		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
 		'n_layers': 12, # 32
@@ -475,8 +482,7 @@ def example_usage():
 	else:
 		resps_list = [ qnt[:, 0].to( device ) ]
 
-		if cfg.model.max_levels > 1:
-			resps_list = [r.unsqueeze(-1) for r in resps_list]
+		if "nar" in cfg.model.capabilities:
 			resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
 			for i, o in enumerate(resps_list):
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 933dc37..d171c48 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -458,7 +458,7 @@ class Base(nn.Module):
 	def stop_token(self):
 		if not self.causal:
 			raise ValueError("Not using stop token!")
-		return self.n_tokens
+		return self.n_audio_tokens
 
 	@property
 	def ignore_index(self):
@@ -471,7 +471,10 @@ class Base(nn.Module):
 	def __init__(
 		self,
-		n_tokens: int = 1024,
+
+		n_text_tokens: int = 256,
+		n_audio_tokens: int = 1024,
+
 		d_model: int = 512,
 		n_heads: int = 8,
 		n_layers: int = 12,
@@ -489,7 +492,9 @@ class Base(nn.Module):
 		self.hyper_config = config
 		self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True
 
-		self.n_tokens = n_tokens
+		self.n_text_tokens = n_text_tokens
+		self.n_audio_tokens = n_audio_tokens
+
 		self.d_model = d_model
 		self.n_heads = n_heads
 		self.n_layers = n_layers
@@ -498,10 +503,10 @@ class Base(nn.Module):
 		self.l_padding = l_padding
 
 		# +1 to include the stop token
-		n_prom_tokens = n_tokens
-		n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
+		n_prom_tokens = n_audio_tokens
+		n_resp_tokens = n_audio_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
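# [editor's note, not part of the patch] the arithmetic above, spelled out with
# the defaults: with n_audio_tokens = 1024 and a causal (AR-capable) model,
#   n_prom_tokens = 1024       # prompt codes occupy ids 0..1023
#   n_resp_tokens = 1024 + 1   # ids 0..1023 are codes, id 1024 is the stop token,
#                              # which is why stop_token above returns n_audio_tokens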
 
-		self.text_emb = Embedding(n_tokens, d_model)
+		self.text_emb = Embedding(n_text_tokens, d_model)
 
 		self.langs_emb = None
 		self.tones_emb = None
 		self.tasks_emb = None
@@ -518,24 +523,28 @@
 			levels=self.n_prom_levels if self.version > 3 else None,
 			sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True,
 		)
-		# [1025] + [1024] * 8
+		# [1024 + STOP] + [1024] * 8
 		self.resps_emb = AudioEmbedding(
 			[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
 			levels=self.n_resp_levels if self.version > 3 else None,
 			sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True
 		)
-
+		# useless since I actually removed using these with the input processing overhaul...
 		if self.version >= 3:
 			self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
 			self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None
-
+		# never actually got added... I kept forgetting to classify all my audio for each speaker's tone
 		if self.version >= 4:
 			self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None
 
+		# mamba requires this if a model does both AR and NAR tasks
+		# this *might* help for AR and NAR tasks, since we explicitly specify the current RVQ level for a sequence rather than having it "encoded" in the embeddings
+		# this ***might*** also let me unify proms_emb and resps_emb
 		if self.version >= 5:
 			self.rvq_level_emb = Embedding(self.n_resp_levels, d_model)
 
+		# it would be nicer for this to be a stop token or to live inside an embedding
 		self.sep = nn.Parameter(torch.randn(d_model))
 
 		# ick, there has to be a better way
@@ -943,7 +952,6 @@ class Base(nn.Module):
 		target_list = []
 		for batch_index, batch in enumerate(inputs):
 			target = []
-			quant_level = quant_levels[batch_index] if quant_levels is not None else None
 			for name, input in batch:
 				if name == "prom":
 					target.append( torch.full_like(input[..., 0], self.ignore_index) )
@@ -960,19 +968,36 @@ class Base(nn.Module):
 				logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
 				target_list[i] = target_list[i][..., 1:] # predicts token n + 1
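# [editor's sketch, not part of the patch] the AR shift above in isolation:
# the logit at position t is scored against the token at position t + 1
import torch
T, V = 8, 32
logit  = torch.randn(T, V)
target = torch.randint(0, V, (T,))
logit  = logit[:-1, :]  # drop the final position; there is no token after it to predict
target = target[1:]     # drop the first token; it is always given, never predicted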
 
-		target = torch.cat( target_list )
-		inputs = torch.cat( logits )
+			# see the comments on the split-loss cross_entropy call below
+			if False:
+				target = torch.cat( target_list )
+				inputs = torch.cat( logits )
+				self.loss = dict(
+					# "nll" was in the original implementation and should actually just be called something else
+					nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
+				)
+				self.stats = dict(
+					acc = self.accuracy_metric( inputs, target ),
+					# precision = self.precision_metric( inputs, target ),
+				)
+			else:
+				self.loss = dict(
+					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / len(logits)
+				)
+				self.stats = dict(
+					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / len(logits)
+				)
 
-		self.loss = dict(
-			# "nll" was in the original implementation and should actually just be called something else
-			nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
-		)
-		self.stats = dict(
-			acc = self.accuracy_metric( inputs, target ),
-			# precision = self.precision_metric( inputs, target ),
-		)
 			return
 
+		"""
+		# considerations:
+		#  * split losses do not maintain the entire sequence
+		#  * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
+		#    + the other way at least keeps the sequence intact
+		#    + extra logic might be required to instead offset from the end for the resp, rather than fit snugly
+		#    + this might just be a spook, since the odds of the very first token of the AR mattering are slim (although I swear I hear a very brief audio pop sometimes)
+		"""
 		self.loss = dict()
 		self.stats = dict(acc = dict())
@@ -998,6 +1023,7 @@ class Base(nn.Module):
 				it += seq_len + 1 # +1 to incorporate the separator
 
 				# for the AR, shift the sequence so that it predicts the next token
+				# (the NAR predicts the next token in place, so no modifications are necessary for it)
 				if quant_level is None or quant_level == 0:
 					logit = logit[..., :-1, :] # get all but the final logit
 					input = input[..., 1:] # shift sequence to the right by one
@@ -1008,8 +1034,9 @@ class Base(nn.Module):
 					"logits": [],
 				}
 
-				info[name]["targets"].append( input ) # input.contiguous()
-				info[name]["logits"].append( logit ) # logit.contiguous()
+				# modeling_llama.py has some comment about requiring .contiguous(), but I feel it's a spook since that incurs a memory allocation
+				info[name]["targets"].append( input.long() )
+				info[name]["logits"].append( logit )
 
 		for name, batch in info.items():
 			loss_factor = self.loss_factor(name)
@@ -1019,15 +1046,20 @@ class Base(nn.Module):
 			if loss_factor == 0.0:
 				continue
 
-			targets = torch.cat( batch["targets"] ).long()
-			inputs = torch.cat( batch["logits"] )
-
-			self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
-			try:
+			# "faster" if cross_entropy has speedups for processing an entire batch, but torch.cat allocates new tensors
+			# to-do: set this to a var
+			if False:
+				targets = torch.cat( batch["targets"] ).long()
+				inputs = torch.cat( batch["logits"] )
+				self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
 				self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
-			except Exception as e:
-				print( name, inputs.shape, targets.shape, e )
-				pass
+			# probably consumes less memory due to not having to allocate new tensors
+			# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
+			else:
+				self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / len(batch["logits"])
+				self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / len(batch["logits"])
+
+				# accuracy sometimes breaks for mamba
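# [editor's sketch, not part of the patch] the two loss formulations side by side;
# shapes are illustrative (per-sequence lengths T_i, vocab size V)
import torch
import torch.nn.functional as F

logits  = [torch.randn(7, 32), torch.randn(5, 32)]
targets = [torch.randint(0, 32, (7,)), torch.randint(0, 32, (5,))]

# concatenated: a single cross_entropy call, but torch.cat allocates fresh tensors
cat_loss = F.cross_entropy(torch.cat(logits), torch.cat(targets))

# split: no large allocation, and each sequence can be weighted individually
# (e.g. per RVQ level); note it averages per-sequence means, so it only equals
# the concatenated loss when all sequences share the same length
split_loss = sum(F.cross_entropy(l, t) for l, t in zip(logits, targets)) / len(logits)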
 
 		# to-do: compute loss per individual batch to scale per RVQ level
 		"""
@@ -1049,9 +1081,8 @@ class Base(nn.Module):
 		# yes, there's a better way.
 		training = False
-		for b_i in range(len(inputs)):
-			for i in range(len(inputs[b_i])):
-				name, input = inputs[b_i][i]
+		for batch_index, batch in enumerate(inputs):
+			for name, input in batch:
 				if name == "targ":
 					training = True
@@ -1115,13 +1146,16 @@ class Base(nn.Module):
 	):
 		if min_temperature < 0:
 			min_temperature = temperature
+
 		# (NAR) return the entire generated response
+		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ level's code at the same position
 		if quant_levels is not None:
 			logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
 		# (AR chunkwise) return the last chunkwise piece
 		elif self.causal and self.recurrent_chunk_size > 0:
 			logits = [ logit[-l:] for logit, l in zip(logits, self.recurrent_chunk_size) ]
 		# (AR) return just the last code
+		# Recurrent decoding relies on the last token in the logits, because each token predicts the next token in the sequence (obviously)
 		else:
 			logits = [ logit[-1:] for logit in logits ]
diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py
index 471236c..2efb232 100644
--- a/vall_e/models/experimental.py
+++ b/vall_e/models/experimental.py
@@ -95,6 +95,10 @@ else:
 class Model(LlmArchClass):
 	def __init__(
 		self,
+
+		n_text_tokens = 256,
+		n_audio_tokens = 1024,
+
 		d_model=1024,
 		n_layers=12,
 		n_heads=16,
@@ -107,7 +111,7 @@ class Model(LlmArchClass):
 		hf_attention = config.attention if config is not None else None
 		gradient_checkpointing = config.gradient_checkpointing if config is not None else True
 		# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
-		vocab_size = 256 + cfg.model.max_levels + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
+		vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1
 
 		if SELECTED_ARCH == "llama":
 			super().__init__(config=LlamaConfig(
diff --git a/vall_e/train.py b/vall_e/train.py
index d8bcb48..535eb3a 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -174,7 +174,6 @@ def run_eval(engines, eval_name, dl):
 			resps_list = [ resp[:, 0] for resp in batch["resps"] ]
 
 			if "nar" in engine.hyper_config.capabilities:
-				resps_list = [ r.unsqueeze(-1) for r in resps_list ]
 				resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
 
 			process( name, batch, resps_list )
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 6c3d0b9..9d4e64b 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -164,18 +164,6 @@ def train(
 			stats = engines.step(batch=batch, feeder=train_feeder)
 			stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())
 
-			"""
-			stats['batch'] = {
-				'size': len(batch['text']),
-				'id': batch['spkr_id'],
-				'index': [ index for index in batch['index'] ],
-				'text_len': [ text.shape[0] for text in batch['text'] ],
-				'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
-				'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
-			}
-			"""
-
-
 			elapsed_time = stats.get("elapsed_time", 0)
 			try:
 				metrics = json.dumps(stats)