cleanup, putting some thoughts in comments before I forget about them

mrq 2024-06-05 19:50:06 -05:00
parent 3cfc8a96bb
commit 880b4ecd1b
7 changed files with 111 additions and 64 deletions

View File

@ -8,6 +8,8 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
> **Note** Development on this is very sporadic. Gomen.
> **Note** Compatibility for existing models may break at any time while I feverishly try and work out the best way to crank out a model. Gomen.
## Requirements
* [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
@ -24,12 +26,14 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`.
## Try Me
To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`
To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`.
Each model file has a barebones trainer and inference routine.
A small trainer will overfit a provided utterance to ensure a model configuration works.
## Pre-Trained Model
> **Note** Pre-Trained weights aren't up to par until I finally nail the best training methodologies and model code. Gomen.
My pre-trained weights can be acquired from [here](https://huggingface.co/ecker/vall-e).
A script to set up a proper environment and download the weights can be invoked with `./scripts/setup.sh`.
@ -44,14 +48,14 @@ Training is very dependent on:
### Pre-Processed Dataset
> **Note** The provided dataset needs to be reprocessed to better suit a new training dataset format. Gomen.
A "libre" dataset utilizing EnCodec quantized audio can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.
A script to set up a proper environment and train can be invoked with `./scripts/setup-training.sh`.
### Leverage Your Own Dataset
> **Note** Preparing a dataset is a bit messy.
If you already have a dataset you want to use, for example your own large corpus or data for finetuning, you can prepare it yourself instead.
0. Set up a `venv` with `https://github.com/m-bain/whisperX/`.
@ -79,10 +83,11 @@ If you already have a dataset you want, for example your own large corpus, or fo
Two dataset formats are supported:
* the standard way:
- for Encodec/Vocos audio backends, data is stored under `./training/data/{group}/{speaker}/{id}.phn.txt` and `./training/data/{group}/{speaker}/{id}.qnt.pt`
- for Descript-Audio-Codec audio backend, data is stored under `./training/data/{group}/{speaker}/{id}.json` and `./training/data/{group}/{speaker}/{id}.dac`
- for Encodec/Vocos audio backends, data is stored under `./training/data/{group}/{speaker}/{id}.enc` as a NumPy file.
- for Descript-Audio-Codec audio backend, data is stored under `./training/data/{group}/{speaker}/{id}.dac` as a NumPy file.
- it is *highly* recommended to generate metadata to speed up dataset pre-load with `python3 -m vall_e.data yaml="./training/config.yaml" --action=metadata`
* using an HDF5 dataset:
- you can convert from the standard way with the following command: `python3 -m vall_e.data yaml="./training/config.yaml"`
- you can convert from the standard way with the following command: `python3 -m vall_e.data yaml="./training/config.yaml"` (metadata for dataset pre-load is generated alongside HDF5 creation)
- this will shove everything into a single HDF5 file and store some metadata alongside (for now, the generated symbol map and the text/audio lengths)
- be sure to also define `use_hdf5` in your config YAML (a small sketch for inspecting the resulting file follows this list)
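For a quick sanity check of what the conversion produced, a minimal sketch like the following can walk the resulting HDF5 file (the `./training/data.h5` path is an assumption; substitute whatever path your config YAML actually points at):

```python
# inspect the converted HDF5 dataset; the path below is an assumption, use the one from your config YAML
import h5py

with h5py.File("./training/data.h5", "r") as f:
    def walk(name, obj):
        # prints every group/dataset name, plus the shape for datasets
        print(name, getattr(obj, "shape", ""))
    f.visititems(walk)
```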
@ -98,7 +103,6 @@ You can enter `save` to save the state at any time, or `quit` to save and quit t
The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
### Plotting Metrics
Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot yaml="./training/config.yaml"`
@ -107,35 +111,41 @@ You can specify what X and Y labels you want to plot against by passing `--xs to
### Notices
If you're training under `float16`, it is recommended to use the `local` backend with `amp` enabled. There's something really funky with `deepspeed` as a backend that's causing issues with training.
#### Training Under Windows
As training with `deepspeed` under Windows is not (easily) supported, simply change `trainer.backend` to `local` in your `config.yaml` to use the local training backend.
Keep in mind that creature comforts like distributed training or `float16` training cannot be verified as working at the moment with the local trainer.
Creature comforts like `float16`, `amp`, and multi-GPU training *should* work, but extensive testing still needs to be done to ensure it all functions.
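If you would rather script that `trainer.backend` change than edit the YAML by hand, a minimal sketch (assuming PyYAML and the `./training/config.yaml` path used elsewhere in this README; note that rewriting the file this way drops any comments in it):

```python
# flip the trainer backend to `local`; the path and key names are assumptions based on this README
import yaml

path = "./training/config.yaml"
with open(path) as f:
    config = yaml.safe_load(f) or {}

config.setdefault("trainer", {})["backend"] = "local"

with open(path, "w") as f:
    yaml.safe_dump(config, f)
```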
#### Training Caveats
Unfortunately, efforts to train a *good* foundational model seem entirely predicated on a good dataset. My dataset might be too fouled with:
* too-short utterances: trying to extrapolate longer contexts seems to utterly fall apart from just the `text` being too long.
+ It might help to instead initially train with smaller utterances, train for two epochs, then increase each sample's length (a rough sketch of this idea follows this list).
- This does seem to help speed up the model "learning" better.
* too tightly trimmed utterances: having little to no space at the start and end might harm associating the `<s>` and `</s>` tokens with empty utterances.
* a poorly crafted phoneme mapping: I naively crafted my own phoneme mapping, whereas a HuggingFace tokenizer might supply a better token mapping.
+ This seems remedied with settling for using a HuggingFace tokenizer to handle everything.
* having a unified AR and NAR model might sound too convenient, but each task may lobotomize the other, due to the nature of things.
+ This *might* be remedied with better sequence formatting.
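As a rough illustration of the "start short, then grow" idea above, here is a hypothetical sketch of an epoch-based length cap; it is not wired into this repo's `vall_e.data` loader, and the field names and numbers are made up:

```python
# hypothetical curriculum over utterance length: keep only short samples early on,
# then raise the cap every couple of epochs
def length_cap(epoch: int, base: int = 300, step: int = 150, ceiling: int = 1500) -> int:
    # e.g. epochs 0-1 cap at 300 frames, epochs 2-3 at 450 frames, and so on
    return min(base + (epoch // 2) * step, ceiling)

def filter_by_length(samples: list[dict], epoch: int) -> list[dict]:
    cap = length_cap(epoch)
    return [s for s in samples if s["resp_len"] <= cap]
```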
#### Backend Architectures
As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLM architectures (a dispatch sketch follows this list):
* `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
+ I aim to utilize this for the foundational model, as I get to leverage a bunch of things tailored for LLaMA (and converting to them is rather easy).
* `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
- Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.
* `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
- Its implementation for MoE can also be utilized.
* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
* `retnet-hf`: using [syncdoth/RetNet](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
- has an inference penalty, and MoE is not implemented.
* `mamba`: using [state-spaces/mamba](https://github.com/state-spaces/mamba) (needs to mature)
- ***really hard*** to have a unified AR and NAR model
- the inference penalty makes it a really hard sell, despite the loss already dropping to a low 3 after only a small number of samples processed
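To make the "slotted in" part concrete, below is a stripped-down sketch of dispatching an architecture name to a HuggingFace config, in the spirit of `vall_e/models/experimental.py`; the hyperparameters and constructor arguments here are placeholders, not this repo's defaults:

```python
# hypothetical dispatch from a backend name to an attention-based transformer; values are placeholders
from transformers import LlamaConfig, LlamaForCausalLM, MixtralConfig, MixtralForCausalLM

def build_backend(arch: str, vocab_size: int, d_model: int = 1024, n_layers: int = 12, n_heads: int = 16):
    if arch == "llama":
        return LlamaForCausalLM(LlamaConfig(
            vocab_size=vocab_size, hidden_size=d_model,
            num_hidden_layers=n_layers, num_attention_heads=n_heads,
        ))
    if arch == "mixtral":
        return MixtralForCausalLM(MixtralConfig(
            vocab_size=vocab_size, hidden_size=d_model,
            num_hidden_layers=n_layers, num_attention_heads=n_heads,
        ))
    raise ValueError(f"unsupported backend: {arch}")
```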
For audio backends (an encoding sketch follows this list):
@ -144,7 +154,7 @@ For audio backends:
- encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`
* [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality
- **Note** models using `descript-audio-codec` at 24kHz + 8kbps will NOT converge in any manner.
- **Note** models using `descript-audio-codec` at 44kHz + 8kbps stop improving after a while.
- **Note** models using `descript-audio-codec` at 44kHz + 8kbps seem to have a harder time modeling the codec's "language", but despite the loss being rather high, the output sounds fine.
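For reference, encoding a waveform with the `encodec` backend roughly looks like the following; this goes through the upstream `encodec` package directly rather than this repo's own quantizer wrapper, and the bandwidth choice is only an example:

```python
# quantize a wav with EnCodec directly; the repo's dataset preparation normally handles this for you
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio

model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)  # kbps; pick whatever your config expects

wav, sr = torchaudio.load("utterance.wav")
wav = convert_audio(wav, sr, model.sample_rate, model.channels).unsqueeze(0)

with torch.no_grad():
    frames = model.encode(wav)
codes = torch.cat([code for code, _ in frames], dim=-1)  # [1, n_codebooks, n_frames]
```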
`llama`-based models also support different attention backends (a selection sketch follows this list):
* `math`: torch's SDPA's `math` implementation
@ -155,6 +165,8 @@ For audio backends:
* `sdpa`: integrated `LlamaSdpaAttention` attention model
* `flash_attention_2`: integrated `LlamaFlashAttention2` attention model
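For a rough picture of what these options map to outside of this repo's config, consider the sketch below; the model hyperparameters are placeholders, and it assumes a reasonably recent `transformers` and `torch` (plus the `flash-attn` package if you pick `flash_attention_2`):

```python
# hypothetical: choose the attention implementation when constructing a HF LLaMA model
import torch
from transformers import AutoModelForCausalLM, LlamaConfig

config = LlamaConfig(vocab_size=1024, hidden_size=512, num_hidden_layers=4, num_attention_heads=8)
model = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa")  # or "eager" / "flash_attention_2"

# when routing through torch's SDPA, you can additionally restrict which kernel it dispatches to
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    ...  # forward passes in here will use the math kernel
```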
The wide support for various backends exists solely while I try to figure out which is the "best" for a core foundation model.
## Export
To export the models, run: `python -m vall_e.export yaml=./training/config.yaml`.

View File

@ -5,7 +5,8 @@ def get_model(cfg, training=True):
if not cfg.experimental:
from .ar_nar import AR_NAR
model = AR_NAR(
n_tokens=cfg.tokens,
n_text_tokens=cfg.text_tokens,
n_audio_tokens=cfg.audio_tokens,
d_model=cfg.dim,
n_heads=cfg.heads,
n_layers=cfg.layers,
@ -22,6 +23,9 @@ def get_model(cfg, training=True):
else:
from .experimental import Model as Experimental
model = Experimental(
n_text_tokens=cfg.text_tokens,
n_audio_tokens=cfg.audio_tokens,
d_model=cfg.dim,
n_layers=cfg.layers,
n_heads=cfg.heads,

View File

@ -181,6 +181,11 @@ class AR_NAR(Base):
# is NAR
if max_levels == 0:
max_levels = self.n_resp_levels - 1
# expand if given a raw 1D tensor
for i, resp in enumerate(resps_list):
if resp.dim() == 1:
resps_list[i] = resp.unsqueeze(-1)
prev_list = resps_list
@ -377,7 +382,9 @@ def example_usage():
# retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
kwargs = {
'n_tokens': 1024,
'n_text_tokens': 256,
'n_audio_tokens': 1024,
'd_model': 1024, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
'n_layers': 12, # 32
@ -475,8 +482,7 @@ def example_usage():
else:
resps_list = [ qnt[:, 0].to( device ) ]
if cfg.model.max_levels > 1:
resps_list = [r.unsqueeze(-1) for r in resps_list]
if "nar" in cfg.model.capabilities:
resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
for i, o in enumerate(resps_list):

View File

@ -458,7 +458,7 @@ class Base(nn.Module):
def stop_token(self):
if not self.causal:
raise ValueError("Not using stop token!")
return self.n_tokens
return self.n_audio_tokens
@property
def ignore_index(self):
@ -471,7 +471,10 @@ class Base(nn.Module):
def __init__(
self,
n_tokens: int = 1024,
n_text_tokens: int = 256,
n_audio_tokens: int = 1024,
d_model: int = 512,
n_heads: int = 8,
n_layers: int = 12,
@ -489,7 +492,9 @@ class Base(nn.Module):
self.hyper_config = config
self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True
self.n_tokens = n_tokens
self.n_text_tokens = n_text_tokens
self.n_audio_tokens = n_audio_tokens
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
@ -498,10 +503,10 @@ class Base(nn.Module):
self.l_padding = l_padding
# +1 to include the stop token
n_prom_tokens = n_tokens
n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
n_prom_tokens = n_audio_tokens
n_resp_tokens = n_audio_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
self.text_emb = Embedding(n_tokens, d_model)
self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None
self.tones_emb = None
self.tasks_emb = None
@ -518,24 +523,28 @@ class Base(nn.Module):
levels=self.n_prom_levels if self.version > 3 else None,
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True,
)
# [1025] + [1024] * 8
# [1024 + STOP] + [1024] * 8
self.resps_emb = AudioEmbedding(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
levels=self.n_resp_levels if self.version > 3 else None,
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True
)
# useless since I actually removed using these with the input processing overhaul...
if self.version >= 3:
self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None
# never actually got added... I kept forgetting to classify all my audio for speaker's tone
if self.version >= 4:
self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None
# mamba requires this if a model does both AR and NAR tasks
# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
# this ***might*** let me also unify the proms_emb and resps_embedding
if self.version >= 5:
self.rvq_level_emb = Embedding(self.n_resp_levels, d_model)
# this would be nicer to be a stop token or live inside an embedding
self.sep = nn.Parameter(torch.randn(d_model))
# ick, there has to be a better way
@ -943,7 +952,6 @@ class Base(nn.Module):
target_list = []
for batch_index, batch in enumerate(inputs):
target = []
quant_level = quant_levels[batch_index] if quant_levels is not None else None
for name, input in batch:
if name == "prom":
target.append( torch.full_like(input[..., 0], self.ignore_index) )
@ -960,19 +968,36 @@ class Base(nn.Module):
logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
target_list[i] = target_list[i][..., 1:] # predicts token n + 1
target = torch.cat( target_list )
inputs = torch.cat( logits )
# see comments for the split-loss calc cross_entropy call
if False:
target = torch.cat( target_list )
inputs = torch.cat( logits )
self.loss = dict(
# "nll" was in the original implementation and should actually just be called something else
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
)
self.stats = dict(
acc = self.accuracy_metric( inputs, target ),
# precision = self.precision_metric( inputs, target ),
)
else:
self.loss = dict(
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / len(batch)
)
self.stats = dict(
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / len(batch)
)
self.loss = dict(
# "nll" was in the original implementation and should actually just be called something else
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
)
self.stats = dict(
acc = self.accuracy_metric( inputs, target ),
# precision = self.precision_metric( inputs, target ),
)
return
"""
# considerations:
# * split losses does not maintain the entire sequence
# * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
# + the other way at least should keep it intact this way
# + extra logic might be required to instead offset from the end for the resp, rather than fit snugly
# + this might just be a spook, since the odds of the very first token of the AR mattering are slim (although I swear I hear a very brief audio pop sometimes)
"""
self.loss = dict()
self.stats = dict(acc = dict())
@ -998,6 +1023,7 @@ class Base(nn.Module):
it += seq_len + 1 # +1 to incorporate the separator
# for the AR, shift sequence so that it predicts the next token
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
if quant_level is None or quant_level == 0:
logit = logit[..., :-1, :] # get all but the final logit
input = input[..., 1:] # shift sequence to the right by one
@ -1008,8 +1034,9 @@ class Base(nn.Module):
"logits": [],
}
info[name]["targets"].append( input ) # input.contiguous()
info[name]["logits"].append( logit ) # logit.contiguous()
# modeling_llama.py has some comment about requiring .contiguous() but I feel it's a spook since that incurs a memory allocation
info[name]["targets"].append( input.long() )
info[name]["logits"].append( logit )
for name, batch in info.items():
loss_factor = self.loss_factor(name)
@ -1019,15 +1046,20 @@ class Base(nn.Module):
if loss_factor == 0.0:
continue
targets = torch.cat( batch["targets"] ).long()
inputs = torch.cat( batch["logits"] )
self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
try:
# "faster" if cross_entropy has speedups for processing an entire batch, but torch.cat allocates new tensors
# to-do: set this to a var
if False:
targets = torch.cat( batch["targets"] ).long()
inputs = torch.cat( batch["logits"] )
self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
except Exception as e:
print( name, inputs.shape, targets.shape, e )
pass
# probably consumes less memory, since it avoids the torch.cat allocations
# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
else:
self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / len(batch)
self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / len(batch)
# accuracy sometimes breaks for mamba
# to-do: compute loss per individual batch to scale per RVQ level
"""
@ -1049,9 +1081,8 @@ class Base(nn.Module):
# yes, there's a better way.
training = False
for b_i in range(len(inputs)):
for i in range(len(inputs[b_i])):
name, input = inputs[b_i][i]
for batch_index, batch in enumerate(inputs):
for name, input in batch:
if name == "targ":
training = True
@ -1115,13 +1146,16 @@ class Base(nn.Module):
):
if min_temperature < 0:
min_temperature = temperature
# (NAR) return the entire generated response
# Parallel decoding (NAR) relies on the last N tokens of the logits, because each token predicts the next RVQ level for the same position
if quant_levels is not None:
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
# (AR chunkwise) return the last chunkwise piece
elif self.causal and self.recurrent_chunk_size > 0:
logits = [ logit[-self.recurrent_chunk_size:] for logit in logits ]
# (AR) return just the last code
# Recurrent decoding relies on the last token in the logits, because each token predicts the next token in the sequence (obviously)
else:
logits = [ logit[-1:] for logit in logits ]

View File

@ -95,6 +95,10 @@ else:
class Model(LlmArchClass):
def __init__(
self,
n_text_tokens = 256,
n_audio_tokens = 1024,
d_model=1024,
n_layers=12,
n_heads=16,
@ -107,7 +111,7 @@ class Model(LlmArchClass):
hf_attention = config.attention if config is not None else None
gradient_checkpointing = config.gradient_checkpointing if config is not None else True
# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
vocab_size = 256 + cfg.model.max_levels + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1
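# e.g. with n_text_tokens=256, cfg.model.max_levels=8, n_audio_tokens=1024 (example values, not necessarily this repo's defaults):
#      256 + 8 + (1024 * 8) + (1024 * 8) + 1 = 16649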
if SELECTED_ARCH == "llama":
super().__init__(config=LlamaConfig(

View File

@ -174,7 +174,6 @@ def run_eval(engines, eval_name, dl):
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
if "nar" in engine.hyper_config.capabilities:
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
process( name, batch, resps_list )

View File

@ -164,18 +164,6 @@ def train(
stats = engines.step(batch=batch, feeder=train_feeder)
stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())
"""
stats['batch'] = {
'size': len(batch['text']),
'id': batch['spkr_id'],
'index': [ index for index in batch['index'] ],
'text_len': [ text.shape[0] for text in batch['text'] ],
'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
}
"""
elapsed_time = stats.get("elapsed_time", 0)
try:
metrics = json.dumps(stats)