cleanup, putting some thoughts in comments before I forget about them
parent 3cfc8a96bb, commit 880b4ecd1b

README.md: 38 lines changed
@@ -8,6 +8,8 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
 
 > **Note** Development on this is very sporadic. Gomen.
 
+> **Note** Compatibility for existing models may break at any time while I feverishly try and work out the best way to crank out a model. Gomen.
+
 ## Requirements
 
 * [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
@@ -24,12 +26,14 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`.
 
 ## Try Me
 
-To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`
+To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`.
 
-Each model file has a barebones trainer and inference routine.
+A small trainer will overfit a provided utterance to ensure a model configuration works.
 
 ## Pre-Trained Model
 
+> **Note** Pre-Trained weights aren't up to par until I finally nail the best training methodologies and model code. Gomen.
+
 My pre-trained weights can be acquired from [here](https://huggingface.co/ecker/vall-e).
 
 A script to setup a proper environment and download the weights can be invoked with `./scripts/setup.sh`
@@ -44,14 +48,14 @@ Training is very dependent on:
 
 ### Pre-Processed Dataset
 
+> **Note** The provided dataset needs to be reprocessed to better suit a new training dataset format. Gomen.
+
 A "libre" dataset utilizing EnCodec quantized audio can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.
 
 A script to setup a proper environment and train can be invoked with `./scripts/setup-training.sh`
 
 ### Leverage Your Own Dataset
 
-> **Note** Preparing a dataset is a bit messy.
-
 If you already have a dataset you want, for example your own large corpus, or for finetuning, you can use your own dataset instead.
 
 0. Set up a `venv` with `https://github.com/m-bain/whisperX/`.
@@ -79,10 +83,11 @@ If you already have a dataset you want, for example your own large corpus, or fo
 
 Two dataset formats are supported:
 * the standard way:
-  - for Encodec/Vocos audio backends, data is stored under `./training/data/{group}/{speaker}/{id}.phn.txt` and `./training/data/{group}/{speaker}/{id}.qnt.pt`
+  - for Encodec/Vocos audio backends, data is stored under `./training/data/{group}/{speaker}/{id}.enc` as a NumPy file.
-  - for Descript-Audio-Codec audio backend, data is stored under `./training/data/{group}/{speaker}/{id}.json` and `./training/data/{group}/{speaker}/{id}.dac`
+  - for Descript-Audio-Codec audio backend, data is stored under `./training/data/{group}/{speaker}/{id}.dac` as a NumPy file.
+  - it is *highly* recommended to generate metadata to speed up dataset pre-load with `python3 -m vall_e.data yaml="./training/config.yaml" --action=metadata`
 * using an HDF5 dataset:
-  - you can convert from the standard way with the following command: `python3 -m vall_e.data yaml="./training/config.yaml"`
+  - you can convert from the standard way with the following command: `python3 -m vall_e.data yaml="./training/config.yaml"` (metadata for dataset pre-load is generated alongside HDF5 creation)
   - this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
   - be sure to also define `use_hdf5` in your config YAML.
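As a rough illustration of the standard layout above, here is a small sketch that walks `./training/data/{group}/{speaker}/` and tallies utterances per speaker. The helper itself is hypothetical and not part of the repo; only the directory convention and the `.enc`/`.dac` suffixes come from the hunk above.

```python
from collections import Counter
from pathlib import Path

def count_utterances(root="./training/data", suffix=".enc"):
    """Hypothetical helper: tally utterances per {group}/{speaker} under the standard layout.

    Use suffix=".dac" when targeting the Descript-Audio-Codec backend.
    """
    counts = Counter()
    for path in Path(root).glob(f"*/*/*{suffix}"):
        group, speaker = path.parts[-3], path.parts[-2]
        counts[f"{group}/{speaker}"] += 1
    return counts

if __name__ == "__main__":
    for speaker, n in count_utterances().most_common():
        print(speaker, n)
```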
@@ -98,7 +103,6 @@ You can enter `save` to save the state at any time, or `quit` to save and quit t
 
 The `lr` will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
 
-
 ### Plotting Metrics
 
 Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot yaml="./training/config.yaml"`
@@ -107,35 +111,41 @@ You can specify what X and Y labels you want to plot against by passing `--xs to
 
 ### Notices
 
-If you're training under `float16`, it is recommended to use the `local` backend with `amp` enabled. There's something really funky with `deepspeed` as a backend that's causing issues with training.
-
 #### Training Under Windows
 
 As training under `deepspeed` and Windows is not (easily) supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend.
 
-Keep in mind that creature comforts like distributed training or `float16` training cannot be verified as working at the moment with the local trainer.
+Creature comforts like `float16`, `amp`, and multi-GPU training *should* work, but extensive testing still needs to be done to ensure it all functions.
 
 #### Training Caveats
 
 Unfortunately, efforts to train a *good* foundational model seem entirely predicated on a good dataset. My dataset might be too fouled with:
 * too short utterances: trying to extrapolate longer contexts seems to utterly fall apart from just the `text` being too long.
   + It might help to, instead, initially train with smaller utterances, train for two epochs, then increase each sample's length.
+    - This does seem to help speed up the model "learning" better.
 * too tightly trimmed utterances: there being little to no space at the start and end might harm associating `<s>` and `</s>` tokens with empty utterances.
 * a poorly mapped phoneme mapping: I naively crafted my own phoneme mapping, where a HuggingFace tokenizer might supply a better token mapping.
+  + This seems remedied by settling for a HuggingFace tokenizer to handle everything.
+* having a unified AR and NAR model might sound too convenient, but each task may lobotomize the other, due to the nature of things.
+  + This *might* be remedied with better sequence formatting.
 
 #### Backend Architectures
 
 As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLM architectures:
 
 * `llama`: using HF transformers' LLaMA implementation for its attention-based transformer, boasting RoPE and other improvements.
+  + I aim to utilize this for the foundational model, as I get to leverage a bunch of things tailored for LLaMA (and converting to them is rather easy).
 * `mixtral`: using HF transformers' Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
 * `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
   - Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.
 * `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
 * `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
   - Its implementation for MoE can also be utilized.
-* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
+* `retnet-hf`: using [syncdoth/RetNet](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
   - has an inference penalty, and MoE is not implemented.
+* `mamba`: using [state-spaces/mamba](https://github.com/state-spaces/mamba) (needs to mature)
+  - ***really hard*** to have a unified AR and NAR model
+  - inference penalty makes it a really hard sell, despite the loss already being a low 3 after a short amount of samples processed
 
 For audio backends:
@@ -144,7 +154,7 @@ For audio backends:
   - encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`
 * [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality
   - **Note** models using `descript-audio-codec` at 24KHz + 8kbps will NOT converge in any manner.
-  - **Note** models using `descript-audio-codec` at 44KHz + 8kbps stops improving after a while.
+  - **Note** models using `descript-audio-codec` at 44KHz + 8kbps seems harder to model its "language", but despite the loss being rather high, it sounds fine.
 
 `llama`-based models also support different attention backends:
 * `math`: torch's SDPA's `math` implementation
@@ -155,6 +165,8 @@ For audio backends:
 * `sdpa`: integrated `LlamaSdpaAttention` attention model
 * `flash_attention_2`: integrated `LlamaFlashAttention2` attention model
 
+The wide support for various backends is solely while I try and figure out which is the "best" for a core foundation model.
+
 ## Export
 
 To export the models, run: `python -m vall_e.export yaml=./training/config.yaml`.
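For the attention backends listed above, a minimal sketch of how a HuggingFace LLaMA model can be instantiated with a specific attention implementation; this is plain `transformers` usage for illustration only, not necessarily how this repo wires its `attention` setting internally, and the toy config sizes are made up.

```python
from transformers import AutoModelForCausalLM, LlamaConfig

# toy-sized config purely for illustration
config = LlamaConfig(
    vocab_size=1024,
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
)

# "sdpa" selects LlamaSdpaAttention, "eager" the plain implementation,
# and "flash_attention_2" LlamaFlashAttention2 (requires the flash-attn package and a supported GPU)
model = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa")
```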
@@ -5,7 +5,8 @@ def get_model(cfg, training=True):
 	if not cfg.experimental:
 		from .ar_nar import AR_NAR
 		model = AR_NAR(
-			n_tokens=cfg.tokens,
+			n_text_tokens=cfg.text_tokens,
+			n_audio_tokens=cfg.audio_tokens,
 			d_model=cfg.dim,
 			n_heads=cfg.heads,
 			n_layers=cfg.layers,
@@ -22,6 +23,9 @@ def get_model(cfg, training=True):
 	else:
 		from .experimental import Model as Experimental
 		model = Experimental(
+			n_text_tokens=cfg.text_tokens,
+			n_audio_tokens=cfg.audio_tokens,
+
 			d_model=cfg.dim,
 			n_layers=cfg.layers,
 			n_heads=cfg.heads,
@@ -181,6 +181,11 @@ class AR_NAR(Base):
 		# is NAR
 		if max_levels == 0:
 			max_levels = self.n_resp_levels - 1
 
+		# expand if given a raw 1D tensor
+		for i, resp in enumerate(resps_list):
+			if resp.dim() == 1:
+				resps_list[i] = resp.unsqueeze(-1)
+
 		prev_list = resps_list
 
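The block added above normalizes raw 1D code sequences into the 2D `[timesteps, levels]` layout the rest of the code expects; a tiny self-contained sketch of what that `unsqueeze(-1)` does, with made-up tensors:

```python
import torch

# one raw single-level sequence (shape [T]) and one already-expanded sequence (shape [T, levels])
resps_list = [torch.randint(0, 1024, (75,)), torch.randint(0, 1024, (120, 8))]

# same normalization as in the hunk above: promote 1D sequences to [T, 1]
for i, resp in enumerate(resps_list):
    if resp.dim() == 1:
        resps_list[i] = resp.unsqueeze(-1)

print([tuple(r.shape) for r in resps_list])  # [(75, 1), (120, 8)]
```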
@@ -377,7 +382,9 @@ def example_usage():
 
 	# retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
 	kwargs = {
-		'n_tokens': 1024,
+		'n_text_tokens': 256,
+		'n_audio_tokens': 1024,
+
 		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
 		'n_layers': 12, # 32
@@ -475,8 +482,7 @@ def example_usage():
 	else:
 		resps_list = [ qnt[:, 0].to( device ) ]
 
-	if cfg.model.max_levels > 1:
-		resps_list = [r.unsqueeze(-1) for r in resps_list]
+	if "nar" in cfg.model.capabilities:
 		resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
 
 	for i, o in enumerate(resps_list):
@@ -458,7 +458,7 @@ class Base(nn.Module):
 	def stop_token(self):
 		if not self.causal:
 			raise ValueError("Not using stop token!")
-		return self.n_tokens
+		return self.n_audio_tokens
 
 	@property
 	def ignore_index(self):
@@ -471,7 +471,10 @@ class Base(nn.Module):
 
 	def __init__(
 		self,
-		n_tokens: int = 1024,
+		n_text_tokens: int = 256,
+		n_audio_tokens: int = 1024,
 
 		d_model: int = 512,
 		n_heads: int = 8,
 		n_layers: int = 12,
@@ -489,7 +492,9 @@ class Base(nn.Module):
 		self.hyper_config = config
 		self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True
 
-		self.n_tokens = n_tokens
+		self.n_text_tokens = n_text_tokens
+		self.n_audio_tokens = n_audio_tokens
+
 		self.d_model = d_model
 		self.n_heads = n_heads
 		self.n_layers = n_layers
@@ -498,10 +503,10 @@ class Base(nn.Module):
 		self.l_padding = l_padding
 
 		# +1 to include the stop token
-		n_prom_tokens = n_tokens
-		n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
+		n_prom_tokens = n_audio_tokens
+		n_resp_tokens = n_audio_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
 
-		self.text_emb = Embedding(n_tokens, d_model)
+		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
 		self.tones_emb = None
 		self.tasks_emb = None
@@ -518,24 +523,28 @@ class Base(nn.Module):
 			levels=self.n_prom_levels if self.version > 3 else None,
 			sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True,
 		)
-		# [1025] + [1024] * 8
+		# [1024 + STOP] + [1024] * 8
 		self.resps_emb = AudioEmbedding(
 			[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
 			levels=self.n_resp_levels if self.version > 3 else None,
 			sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True
 		)
 
+		# useless since I actually removed using these with the input processing overhaul...
 		if self.version >= 3:
 			self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
 			self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None
+		# never actually got added... I kept forgetting to classify all my audio for speaker's tone
 		if self.version >= 4:
 			self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None
 
+		# mamba requires this if a model does both AR and NAR tasks
+		# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
+		# this ***might*** let me also unify the proms_emb and resps_embedding
 		if self.version >= 5:
 			self.rvq_level_emb = Embedding(self.n_resp_levels, d_model)
 
+		# this would be nicer to be a stop token or live inside an embedding
 		self.sep = nn.Parameter(torch.randn(d_model))
 
 		# ick, there has to be a better way
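To make the `# [1024 + STOP] + [1024] * 8` comment concrete: with `n_audio_tokens = 1024` and a causal AR, only the first response level carries the extra stop token. A small sketch of the per-level vocabulary sizes, with the level count assumed to be 8 for illustration:

```python
n_audio_tokens = 1024
causal = True       # the AR needs a stop token
n_resp_levels = 8   # assumed level count for this sketch

n_prom_tokens = n_audio_tokens
n_resp_tokens = n_audio_tokens + (1 if causal else 0)

# per-level vocabulary sizes handed to the responses' AudioEmbedding, as in the hunk above
resp_level_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1)
print(resp_level_tokens)  # [1025, 1024, 1024, 1024, 1024, 1024, 1024, 1024]
```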
@@ -943,7 +952,6 @@ class Base(nn.Module):
 		target_list = []
 		for batch_index, batch in enumerate(inputs):
 			target = []
-			quant_level = quant_levels[batch_index] if quant_levels is not None else None
 			for name, input in batch:
 				if name == "prom":
 					target.append( torch.full_like(input[..., 0], self.ignore_index) )
@@ -960,19 +968,36 @@ class Base(nn.Module):
 				logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
 				target_list[i] = target_list[i][..., 1:] # predicts token n + 1
 
-			target = torch.cat( target_list )
-			inputs = torch.cat( logits )
-			self.loss = dict(
-				# "nll" was in the original implementation and should actually just be called something else
-				nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
-			)
-			self.stats = dict(
-				acc = self.accuracy_metric( inputs, target ),
-				# precision = self.precision_metric( inputs, target ),
-			)
+			# see comments for the split-loss calc cross_entropy call
+			if False:
+				target = torch.cat( target_list )
+				inputs = torch.cat( logits )
+				self.loss = dict(
+					# "nll" was in the original implementation and should actually just be called something else
+					nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
+				)
+				self.stats = dict(
+					acc = self.accuracy_metric( inputs, target ),
+					# precision = self.precision_metric( inputs, target ),
+				)
+			else:
+				self.loss = dict(
+					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / len(batch)
+				)
+				self.stats = dict(
+					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / len(batch)
+				)
 			return
 
+			"""
+			# considerations:
+			# * split losses does not maintain the entire sequence
+			# * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
+			#   + the other way at least should keep it intact this way
+			#   + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
+			#   + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
+			"""
 			self.loss = dict()
 			self.stats = dict(acc = dict())
 
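The `if False:` / `else:` split above trades a single concatenated `cross_entropy` call for a per-sequence sum averaged over the batch. A self-contained sketch of the two styles on dummy tensors, showing that they only differ in how the reduction is grouped (shapes and lengths here are made up):

```python
import torch
import torch.nn.functional as F

ignore_index = -100
vocab = 32
# two "sequences" of different lengths, as (logits, targets) pairs
logits = [torch.randn(10, vocab), torch.randn(7, vocab)]
targets = [torch.randint(0, vocab, (10,)), torch.randint(0, vocab, (7,))]

# style 1: concatenate everything and make one cross_entropy call (allocates the concatenated tensors)
nll_cat = F.cross_entropy(torch.cat(logits), torch.cat(targets), ignore_index=ignore_index)

# style 2: one cross_entropy per sequence, averaged over the batch (no concatenation,
# and each sequence could be weighted individually, e.g. per RVQ level)
nll_split = sum(
    F.cross_entropy(l, t, ignore_index=ignore_index) for l, t in zip(logits, targets)
) / len(logits)

# when sequence lengths differ, style 1 weights every token equally while style 2 weights every sequence equally
print(nll_cat.item(), nll_split.item())
```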
@@ -998,6 +1023,7 @@ class Base(nn.Module):
 				it += seq_len + 1 # +1 to incorporate the separator
 
 				# for the AR, shift sequence so that it predicts the next token
+				# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
 				if quant_level is None or quant_level == 0:
 					logit = logit[..., :-1, :] # get all but the final logit
 					input = input[..., 1:] # shift sequence to the right by one
@@ -1008,8 +1034,9 @@ class Base(nn.Module):
 					"logits": [],
 				}
 
-			info[name]["targets"].append( input ) # input.contiguous()
-			info[name]["logits"].append( logit ) # logit.contiguous()
+			# modeling_llama.py has some comment about requiring .contiguous() but I feel it's a spook since that incurs a memory allocation
+			info[name]["targets"].append( input.long() )
+			info[name]["logits"].append( logit )
 
 		for name, batch in info.items():
 			loss_factor = self.loss_factor(name)
@@ -1019,15 +1046,20 @@ class Base(nn.Module):
 			if loss_factor == 0.0:
 				continue
 
-			targets = torch.cat( batch["targets"] ).long()
-			inputs = torch.cat( batch["logits"] )
-			self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
-			try:
-				self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
-			except Exception as e:
-				print( name, inputs.shape, targets.shape, e )
-				pass
+			# "faster" if cross_entropy has speedups for processing an entire batch, but torch.cat allocates new tensors
+			# to-do: set this to a var
+			if False:
+				targets = torch.cat( batch["targets"] ).long()
+				inputs = torch.cat( batch["logits"] )
+				self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
+				self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
+			# probably consumes less memory due to not having to allocate memory
+			# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
+			else:
+				self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / len(batch)
+				self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / len(batch)
+				# accuracy sometimes breaks for mamba
 
 			# to-do: compute loss per individual batch to scale per RVQ level
 			"""
@@ -1049,9 +1081,8 @@ class Base(nn.Module):
 
 		# yes, there's a better way.
 		training = False
-		for b_i in range(len(inputs)):
-			for i in range(len(inputs[b_i])):
-				name, input = inputs[b_i][i]
+		for batch_index, batch in enumerate(inputs):
+			for name, input in batch:
 				if name == "targ":
 					training = True
@@ -1115,13 +1146,16 @@ class Base(nn.Module):
 	):
 		if min_temperature < 0:
 			min_temperature = temperature
 
 		# (NAR) return the entire generated response
+		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
 		if quant_levels is not None:
 			logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
 		# (AR chunkwise) return the last chunkwise piece
 		elif self.causal and self.recurrent_chunk_size > 0:
 			logits = [ logit[-l:] for logit, l in zip(logits, self.recurrent_chunk_size) ]
 		# (AR) return just the last code
+		# Recurrent decoding relies on the last token in the logits, because each token predicts the next token in the sequence (obviously)
 		else:
 			logits = [ logit[-1:] for logit in logits ]
 
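The two comments added above explain why the logits get sliced differently: parallel (NAR) decoding reads one prediction per response position, while recurrent (AR) decoding only needs the final position. A toy sketch of the slicing itself, with shapes made up:

```python
import torch

seq_len, vocab = 12, 1024
logits = [torch.randn(seq_len, vocab)]            # one batch entry
resps_list = [torch.zeros(5, dtype=torch.long)]   # pretend the response part is 5 tokens long

# NAR / parallel decoding: keep the last len(resp) positions, one prediction per response token
nar_logits = [logit[-l:] for logit, l in zip(logits, map(len, resps_list))]

# AR / recurrent decoding: keep only the final position, which predicts the next token
ar_logits = [logit[-1:] for logit in logits]

print(nar_logits[0].shape, ar_logits[0].shape)  # torch.Size([5, 1024]) torch.Size([1, 1024])
```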
@@ -95,6 +95,10 @@ else:
 class Model(LlmArchClass):
 	def __init__(
 		self,
+
+		n_text_tokens = 256,
+		n_audio_tokens = 1024,
+
 		d_model=1024,
 		n_layers=12,
 		n_heads=16,
@@ -107,7 +111,7 @@ class Model(LlmArchClass):
 		hf_attention = config.attention if config is not None else None
 		gradient_checkpointing = config.gradient_checkpointing if config is not None else True
 		# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
-		vocab_size = 256 + cfg.model.max_levels + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
+		vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1
 
 		if SELECTED_ARCH == "llama":
 			super().__init__(config=LlamaConfig(
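The reworked `vocab_size` line above just parameterizes the previously hard-coded arithmetic. Plugging in the defaults from this hunk (`n_text_tokens = 256`, `n_audio_tokens = 1024`) and an assumed `cfg.model.max_levels` of 8 gives:

```python
n_text_tokens = 256
n_audio_tokens = 1024
max_levels = 8  # assumed value for this sketch; the real value comes from cfg.model.max_levels

# text tokens + one token per RVQ level + prom codes + resp codes + stop
vocab_size = n_text_tokens + max_levels + (n_audio_tokens * max_levels) + (n_audio_tokens * max_levels) + 1
print(vocab_size)  # 16649
```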
@@ -174,7 +174,6 @@ def run_eval(engines, eval_name, dl):
 			resps_list = [ resp[:, 0] for resp in batch["resps"] ]
 
 			if "nar" in engine.hyper_config.capabilities:
-				resps_list = [ r.unsqueeze(-1) for r in resps_list ]
 				resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
 
 			process( name, batch, resps_list )
@@ -164,18 +164,6 @@ def train(
 		stats = engines.step(batch=batch, feeder=train_feeder)
 		stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())
 
-		"""
-		stats['batch'] = {
-			'size': len(batch['text']),
-			'id': batch['spkr_id'],
-			'index': [ index for index in batch['index'] ],
-			'text_len': [ text.shape[0] for text in batch['text'] ],
-			'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
-			'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
-		}
-		"""
-
 		elapsed_time = stats.get("elapsed_time", 0)
 		try:
 			metrics = json.dumps(stats)