fixed training stats not loading from exported weights, a bit of a readme cleanup, updated example training yaml
This commit is contained in:
parent
9384900ce6
commit
4abd6564d1
39
README.md
39
README.md
|
@ -4,7 +4,11 @@
|
|||
|
||||
# VALL'E
|
||||
|
||||
An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), based on the [EnCodec](https://github.com/facebookresearch/encodec) tokenizer.
|
||||
An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder.
|
||||
|
||||
[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/) | [HuggingFace Space](https://huggingface.co/spaces/ecker/vall-e)
|
||||
|
||||
> **Note** This README is still quite a disorganized mess.
|
||||
|
||||
## Requirements
|
||||
|
||||
|
@ -32,11 +36,7 @@ A HuggingFace space hosting the code and models can be found [here](https://hugg
|
|||
|
||||
### Local
|
||||
|
||||
To quickly try it out, you can choose between the following modes:
|
||||
|
||||
* AR only: `python -m vall_e.models.ar yaml="./data/config.yaml"`
|
||||
* NAR only: `python -m vall_e.models.nar yaml="./data/config.yaml"`
|
||||
* AR+NAR: `python -m vall_e.models.base yaml="./data/config.yaml"`
|
||||
To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`
|
||||
|
||||
Each model file has a barebones trainer and inference routine.
|
||||
|
||||
|
@ -77,6 +77,7 @@ A "libre" dataset can be found [here](https://huggingface.co/ecker/vall-e/blob/m
|
|||
If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --action='hdf5' yaml='./data/config.yaml'`
|
||||
|
||||
5. Train the AR and NAR models using the following scripts: `python -m vall_e.train yaml=./data/config.yaml`
|
||||
* If distributing your training (for example, multi-GPU), use `deepspeed --module vall_e.train yaml="./data/config.yaml"`
|
||||
|
||||
You may quit your training any time by just entering `quit` in your CLI. The latest checkpoint will be automatically saved.
|
||||
|
||||
|
@ -98,15 +99,11 @@ You can specify what X and Y labels you want to plot against by passing `--xs to
|
|||
|
||||
### Notices
|
||||
|
||||
#### Modifying `prom_levels`, `resp_levels`, Or `tasks` For A Model
|
||||
|
||||
If you're wanting to increase the `prom_levels` for a given model, or increase the `tasks` levels a model accepts, you will need to export your weights and set `train.load_state_dict` to `True` in your configuration YAML.
|
||||
|
||||
#### Training Under Windows
|
||||
|
||||
As training under `deepspeed` is not supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend.
|
||||
As training under `deepspeed` and Windows is not supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend.
|
||||
|
||||
Keep in mind that creature comforts like distributed training cannot be verified as working at the moment.
|
||||
Keep in mind that creature comforts like distributed training or `float16` training cannot be verified as working at the moment.
|
||||
|
||||
#### Training on Low-VRAM Cards
|
||||
|
||||
|
@ -116,13 +113,11 @@ VRAM use is also predicated on your dataset; a mix of large and small utterances
|
|||
|
||||
Additionally, under Windows, I managed to finetune the AR on my 2060 (6GiB VRAM) with a batch size of 8 (although, with the card as a secondary GPU).
|
||||
|
||||
If you need to, you are free to train only one model at a time. Just remove the definition for one model in your `config.yaml`'s `models._model` list.
|
||||
|
||||
## Export
|
||||
|
||||
Both trained models *can* be exported, but is only required if loading them on systems without DeepSpeed for inferencing (Windows systems). To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.
|
||||
To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.
|
||||
|
||||
This will export the latest checkpoints, for example, under `./data/ckpt/ar-retnet-2/fp32.pth` and `./data/ckpt/nar-retnet-2/fp32.pth`, to be loaded on any system with PyTorch.
|
||||
This will export the latest checkpoints, for example, under `./data/ckpt/ar-retnet-2/fp32.pth` and `./data/ckpt/nar-retnet-2/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used, and training stats.
|
||||
|
||||
## Synthesis
|
||||
|
||||
|
@ -149,17 +144,17 @@ And some experimental sampling flags you can use too (your mileage will ***defin
|
|||
|
||||
## To-Do
|
||||
|
||||
* reduce load time for creating / preparing dataloaders (hint: remove use of `Path.glob` and `Path.rglob`).
|
||||
* train and release a ***good*** model.
|
||||
* extend to multiple languages (VALL-E X) and ~~extend to~~ train SpeechX features.
|
||||
+ This can easily be done with adding in additional embeddings + tokens, rather than cramming into the input prompt embedding.
|
||||
## Notice
|
||||
* clean up the README, and document, document, document onto the wiki.
|
||||
* extend to multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
|
||||
|
||||
- [EnCodec](https://github.com/facebookresearch/encodec) is licensed under CC-BY-NC 4.0. If you use the code to generate audio quantization or perform decoding, it is important to adhere to the terms of their license.
|
||||
## Notices and Citations
|
||||
|
||||
Unless otherwise credited/noted, this repository is [licensed](LICENSE) under AGPLv3.
|
||||
|
||||
## Citations
|
||||
- [EnCodec](https://github.com/facebookresearch/encodec) is licensed under CC-BY-NC 4.0. If you use the code to generate audio quantization or perform decoding, it is important to adhere to the terms of their license.
|
||||
|
||||
- This implementation was originally based on [enhuiz/vall-e](https://github.com/enhuiz/vall-e), but has been heavily, heavily modified over time.
|
||||
|
||||
```bibtex
|
||||
@article{wang2023neural,
|
||||
|
|
|
@ -1,60 +1,57 @@
|
|||
dataset:
|
||||
training: [
|
||||
# "./training/valle/data/LibriTTS/994/",
|
||||
]
|
||||
|
||||
validation: [
|
||||
# "./training/valle/data/Validation/1188/",
|
||||
]
|
||||
noise: [
|
||||
# "./training/valle/data/Other/noise/",
|
||||
]
|
||||
training: []
|
||||
validation: []
|
||||
noise: []
|
||||
|
||||
speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
|
||||
|
||||
use_hdf5: True
|
||||
use_metadata: True
|
||||
hdf5_flag: r
|
||||
validate: True
|
||||
|
||||
workers: 4
|
||||
workers: 2
|
||||
cache: True
|
||||
|
||||
phones_range: [4, 512]
|
||||
phones_range: [4, 256]
|
||||
duration_range: [1.0, 16.0]
|
||||
min_utterances: 32
|
||||
|
||||
random_utterance: 1.0
|
||||
max_prompts: 3
|
||||
prompt_duration: 3.0
|
||||
max_prompts: 6
|
||||
prompt_duration: 6.0
|
||||
|
||||
sample_type: speaker
|
||||
|
||||
tasks_list: ["tts"] # ["tts", "ns", "sr", "tse", "cse", "nse"]
|
||||
tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"]
|
||||
|
||||
models:
|
||||
_prom_levels: 8
|
||||
_max_levels: 8
|
||||
_models:
|
||||
- name: "ar"
|
||||
size: "full"
|
||||
resp_levels: 1
|
||||
prom_levels: 2
|
||||
tasks: 8
|
||||
arch_type: "retnet"
|
||||
|
||||
- name: "nar"
|
||||
size: "full"
|
||||
resp_levels: 1
|
||||
prom_levels: 2
|
||||
_models:
|
||||
- name: "ar+nar"
|
||||
size: "double"
|
||||
resp_levels: 8
|
||||
prom_levels: 8
|
||||
tasks: 8
|
||||
arch_type: "retnet"
|
||||
training: True
|
||||
version: 2
|
||||
|
||||
|
||||
hyperparameters:
|
||||
batch_size: 16
|
||||
gradient_accumulation_steps: 4
|
||||
batch_size: 8
|
||||
gradient_accumulation_steps: 16
|
||||
gradient_clipping: 100
|
||||
|
||||
# prodigyopt is nicer, but requires even more VRAM
|
||||
#optimizer: Prodigy
|
||||
#learning_rate: 1.0 # e-4
|
||||
|
||||
optimizer: AdamW
|
||||
learning_rate: 1.0e-4
|
||||
torch_optimizer: True
|
||||
|
||||
scheduler_type: ""
|
||||
#scheduler_type: OneCycle
|
||||
|
@ -77,12 +74,13 @@ hyperparameters:
|
|||
|
||||
evaluation:
|
||||
batch_size: 16
|
||||
frequency: 500
|
||||
frequency: 250
|
||||
size: 16
|
||||
|
||||
steps: 300
|
||||
steps: 450
|
||||
ar_temperature: 0.95
|
||||
nar_temperature: 0.25
|
||||
load_disabled_engines: True
|
||||
|
||||
trainer:
|
||||
iterations: 1_000_000
|
||||
|
@ -90,11 +88,13 @@ trainer:
|
|||
save_tag: step
|
||||
save_on_oom: True
|
||||
save_on_quit: True
|
||||
save_frequency: 1000
|
||||
save_frequency: 100
|
||||
export_on_save: True
|
||||
|
||||
keep_last_checkpoints: 4
|
||||
|
||||
aggressive_optimizations: False
|
||||
load_disabled_engines: False
|
||||
|
||||
#load_state_dict: True
|
||||
#strict_loading: False
|
||||
|
@ -105,18 +105,21 @@ trainer:
|
|||
gc_mode: None # "global_step"
|
||||
|
||||
weight_dtype: bfloat16
|
||||
amp: False
|
||||
|
||||
backend: deepspeed
|
||||
deepspeed:
|
||||
zero_optimization_level: 0
|
||||
use_compression_training: True
|
||||
|
||||
activation_checkpointing: True
|
||||
|
||||
inference:
|
||||
use_vocos: True
|
||||
normalize: False # do NOT change this unless you know exactly what you are doing.
|
||||
normalize: False
|
||||
|
||||
bitsandbytes:
|
||||
enabled: False
|
||||
injects: True
|
||||
linear: True
|
||||
embedding: True
|
||||
injects: False
|
||||
linear: False
|
||||
embedding: False
|
||||
|
|
|
@ -40,8 +40,8 @@ class Engine(DeepSpeedEngine):
|
|||
kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
|
||||
|
||||
stats = {
|
||||
"global_steps": 0,
|
||||
"micro_steps": 0,
|
||||
"global_step": 0,
|
||||
"micro_step": 0,
|
||||
"global_samples": 0,
|
||||
"tokens_processed": 0,
|
||||
}
|
||||
|
@ -54,8 +54,8 @@ class Engine(DeepSpeedEngine):
|
|||
super().__init__(None, *args, **kwargs)
|
||||
self._frozen_params = set()
|
||||
|
||||
self.global_steps = stats["global_steps"]
|
||||
self.micro_steps = stats["micro_steps"]
|
||||
self.global_steps = stats["global_step"]
|
||||
self.micro_steps = stats["micro_step"]
|
||||
self.global_samples = stats["global_samples"]
|
||||
self.tokens_processed = stats["tokens_processed"]
|
||||
|
||||
|
|
|
@ -83,7 +83,7 @@ def load_engines():
|
|||
|
||||
# state dict is not just the module, extract the extra trainer details
|
||||
if "stats" in state:
|
||||
additionals = state["stats"]
|
||||
stats = state["stats"]
|
||||
|
||||
if "module" in state:
|
||||
state = state["module"]
|
||||
|
|
Loading…
Reference in New Issue
Block a user