finally got around to removing omegaconf

This commit is contained in:
mrq 2024-06-07 20:23:53 -05:00
parent 4ade2b60ee
commit da8242d086
4 changed files with 99 additions and 73 deletions

View File

@ -12,21 +12,22 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
## Requirements
* [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
- For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
- Linux users can consult their package managers on installing `espeak`/`espeak-ng`.
- Windows users are required to install [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/releases/tag/1.51#Assets).
+ additionally, you may be required to set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
Besides a working PyTorch environment, the only hard requirement is [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
- For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
- Linux users can consult their package managers on installing `espeak`/`espeak-ng`.
- Windows users are required to install [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/releases/tag/1.51#Assets).
+ additionally, you may be required to set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
- In the future, an internal homebrew to replace this *would* be fantastic.
## Install
Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`.
I've tested this repo under Python versions `3.10.9` and `3.11.3`.
I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
## Try Me
To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`.
To quickly try it out, you can run `python -m vall_e.models.ar_nar --yaml="./data/config.yaml"`.
A small trainer will overfit a provided utterance to ensure a model configuration works.
@ -85,19 +86,19 @@ Two dataset formats are supported:
* the standard way:
- for Encodec/Vocos audio backends, data is stored under `./training/data/{group}/{speaker}/{id}.enc` as a NumPy file.
- for Descript-Audio-Codec audio backend, data is stored under `./training/data/{group}/{speaker}/{id}.dac` as a NumPy file.
- it is *highly* recommended to generate metadata to speed up dataset pre-load with `python3 -m vall_e.data yaml="./training/config.yaml" --action=metadata`
- it is *highly* recommended to generate metadata to speed up dataset pre-load with `python3 -m vall_e.data --yaml="./training/config.yaml" --action=metadata`
* using an HDF5 dataset:
- you can convert from the standard way with the following command: `python3 -m vall_e.data yaml="./training/config.yaml"` (metadata for dataset pre-load is generated alongside HDF5 creation)
- you can convert from the standard way with the following command: `python3 -m vall_e.data --yaml="./training/config.yaml"` (metadata for dataset pre-load is generated alongside HDF5 creation)
- this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
- be sure to also define `use_hdf5` in your config YAML.
### Training
For single GPUs, simply running `python3 -m vall_e.train yaml="./training/config.yaml`.
For single GPUs, simply running `python3 -m vall_e.train --yaml="./training/config.yaml`.
For multiple GPUs, or exotic distributed training:
* with `deepspeed` backends, simply running `deepspeed --module vall_e.train yaml="./training/config.yaml"` should handle the gory details.
* with `local` backends, simply run `torchrun --nnodes=1 --nproc-per-node={NUMOFGPUS} -m vall_e.train yaml="./training/config.yaml"`
* with `deepspeed` backends, simply running `deepspeed --module vall_e.train --yaml="./training/config.yaml"` should handle the gory details.
* with `local` backends, simply run `torchrun --nnodes=1 --nproc-per-node={NUMOFGPUS} -m vall_e.train --yaml="./training/config.yaml"`
You can enter `save` to save the state at any time, or `quit` to save and quit training.
@ -105,7 +106,7 @@ The `lr` will also let you adjust the learning rate on the fly. For example: `lr
### Plotting Metrics
Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot yaml="./training/config.yaml"`
Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot --yaml="./training/config.yaml"`
You can specify what X and Y labels you want to plot against by passing `--xs tokens_processed --ys loss stats.acc`
@ -127,7 +128,7 @@ Unfortunately, efforts to train a *good* foundational model seems entirely predi
* a poorly mapped phoneme mapping: I naively crafted my own phoneme mapping, where a HuggingFace tokenizer might supply a better token mapping.
+ This seems remedied with settling for using a HuggingFace tokenizer to handle everything.
* having a unified AR and NAR model might sound too convenient, but each task may lobotomize the other, due to the nature of things.
+ This *might* be remedied with better sequence formatting.
+ This *might* be remedied with better sequence formatting, or separate embeddings for the AR/NAR
#### Backend Architectures
@ -169,13 +170,13 @@ The wide support for various backends is solely while I try and figure out which
## Export
To export the models, run: `python -m vall_e.export yaml=./training/config.yaml`.
To export the models, run: `python -m vall_e.export --yaml=./training/config.yaml`.
This will export the latest checkpoints, for example, under `./training/ckpt/ar+nar-retnet-8/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used, and training stats.
## Synthesis
To synthesize speech, invoke either (if exported the models): `python -m vall_e <text> <ref_path> <out_path> --model-ckpt ./training/ckpt/ar+nar-retnet-8/fp32.pth` or `python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>`
To synthesize speech: `python -m vall_e <text> <ref_path> <out_path> --yaml=<yaml_path>`
Some additional flags you can pass are:
* `--language`: specifies the language for phonemizing the text, and helps guide inferencing when the model is trained against that language.
@ -204,17 +205,22 @@ And some experimental sampling flags you can use too (your mileage will ***defin
## To-Do
* train and release a ***good*** model.
* explore alternative setups, like a NAR-only model
- this would require a audio length predictor, but could help with a lot of things (I believe Meta's Voicebox does this?)
* explore better sampling techniques
- dynamic temperature shows promise despite it being a very early iteration
- mirostat seems to show promise too despite being a half-baked implementation
- penalty incurred from sampling is a bit steep at times...
- the NAR might need to be greedy sampled only
* clean up the README, and document, document, document onto the wiki.
* extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
- training additional tasks needs the SpeechX implementation to be reworked.
- this requires a good foundational model before extending it to transfer tasks onto.
* improve throughput (despite peaking at 120it/s):
- properly utilize RetNet's recurrent forward / chunkwise forward passes (does not seem to want to work no matter how the model is trained).
- utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding N+1, N+2, N+3 AR tokens
+ this requires a properly trained AR, however.
* work around issues with extending context past what's trained (despite RetNet's retention allegedly being able to defeat this):
- "sliding" AR input, such as have the context a fixed length.
+ the model may need to be trained for this with a fancy positional embedding injected OR already trained with a sliding context window in mind. Naively sliding the context window while making use of the RetNet implementation's positional embedding doesn't seem fruitful.
* audio streaming
- this *technically* can work without any additional architecture changes, just clever tricks with sampling-then-decoding-to-audio.
## Notices and Citations
@ -222,7 +228,7 @@ Unless otherwise credited/noted in this README or within the designated Python f
- [EnCodec](https://github.com/facebookresearch/encodec) is licensed under CC-BY-NC 4.0. If you use the code to generate audio quantization or perform decoding, it is important to adhere to the terms of their license.
- This implementation was originally based on [enhuiz/vall-e](https://github.com/enhuiz/vall-e), but has been heavily, heavily modified over time.
- This implementation was originally based on [enhuiz/vall-e](https://github.com/enhuiz/vall-e), but has been heavily, heavily modified over time. Without it I would not have had a good basis to muck around and learn.
```bibtex
@article{wang2023neural,

View File

@ -8,7 +8,6 @@ def shell(*args):
out = subprocess.check_output(args)
return out.decode("ascii").strip()
def write_version(version_core, pre_release=True):
if pre_release:
time = shell("git", "log", "-1", "--format=%cd", "--date=iso")
@ -23,7 +22,6 @@ def write_version(version_core, pre_release=True):
return version
with open("README.md", "r") as f:
long_description = f.read()
@ -37,31 +35,60 @@ setup(
long_description=long_description,
long_description_content_type="text/markdown",
packages=find_packages(),
install_requires=(["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else []) +[
install_requires=(
# training backends
["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else [])
+ [
# logging niceties
"coloredlogs>=15.0.1",
"humanize>=4.4.0",
"matplotlib>=3.6.0",
"pandas>=1.5.0",
# boiler plate niceties
"diskcache>=5.4.0",
"einops>=0.6.0",
"encodec>=0.1.1",
"phonemizer>=2.1.0",
"matplotlib>=3.6.0",
"numpy",
"omegaconf==2.0.6",
"tqdm>=4.64.1",
"humanize>=4.4.0",
"tqdm",
# HF bloat
"tokenizers>4.37.0",
"transformers>4.37.0",
"pandas>=1.5.0",
# training bloat
"auraloss[all]", # [all] is needed for MelSTFTLoss
"h5py",
"prodigyopt @ git+https://github.com/konstmish/prodigy",
# practically the reason to use python
"numpy",
"torch>=1.13.0",
"torchaudio>=0.13.0",
"torchmetrics",
"auraloss[all]",
# core foundations
"phonemizer>=2.1.0",
"encodec>=0.1.1",
"vocos",
"h5py",
"torchscale @ git+https://git.ecker.tech/mrq/torchscale",
"prodigyopt @ git+https://github.com/konstmish/prodigy",
"descript-audio-codec",
# gradio web UI
"gradio"
],
extras_require = {
"all": [
# retnet backend (even though two internal copies exist)
"torchscale @ git+https://git.ecker.tech/mrq/torchscale",
# bitnet
"bitnet",
# mamba
"causal-conv1d",
"mamba-ssm",
# attention helpers
"xformers",
# "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it
]
},
url="https://git.ecker.tech/mrq/vall-e",
)

View File

@ -6,6 +6,8 @@ import os
import subprocess
import sys
import time
import argparse
import yaml
import torch
@ -14,15 +16,13 @@ from dataclasses import asdict, dataclass, field
from functools import cached_property
from pathlib import Path
from omegaconf import OmegaConf
from .utils.distributed import world_size
# Yuck
from transformers import PreTrainedTokenizerFast
@dataclass()
class _Config:
class BaseConfig:
cfg_path: str | None = None
@property
@ -81,39 +81,29 @@ class _Config:
with open(path, "w") as f:
f.write(self.dumps())
@staticmethod
def _is_cfg_argv(s):
return "=" in s and "--" not in s
@classmethod
def from_yaml( cls, yaml_path ):
return cls.from_cli( [f'yaml="{yaml_path}"'] )
return cls.from_cli( [f'--yaml="{yaml_path}"'] )
@classmethod
def from_cli(cls, args=sys.argv):
cli_cfg = OmegaConf.from_cli([s for s in args if cls._is_cfg_argv(s)])
# legacy support for yaml=`` format
for i, arg in enumerate(args):
if arg.startswith("yaml"):
args[i] = f'--{arg}'
# Replace argv to ensure there are no omegaconf options, for compatibility with argparse.
sys.argv = [s for s in sys.argv if not cls._is_cfg_argv(s)]
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
args, unknown = parser.parse_known_args(args=args)
if cli_cfg.get("help"):
print(f"Configurable hyperparameters with their default values:")
print(json.dumps(asdict(cls()), indent=2, default=str))
exit()
state = {}
if args.yaml:
cfg_path = args.yaml
state = yaml.safe_load(open(cfg_path, "r", encoding="utf-8"))
state.setdefault("cfg_path", cfg_path)
if "yaml" in cli_cfg:
yaml_cfg = OmegaConf.load(cli_cfg.yaml)
yaml_path = Path(cli_cfg.yaml).absolute()
cfg_path = Path(*yaml_path.relative_to(Path.cwd()).parts[:-1])
cfg_path = cfg_path.with_suffix("")
cfg_path = f'./{cfg_path}'
yaml_cfg.setdefault("cfg_path", cfg_path)
cli_cfg.pop("yaml")
else:
yaml_cfg = {}
merged = OmegaConf.merge(yaml_cfg, cli_cfg)
return cls(**dict(merged))
return cls(**state)
def __repr__(self):
return str(self)
@ -621,7 +611,7 @@ class Optimizations:
fp8: bool = False # use fp8
@dataclass()
class Config(_Config):
class Config(BaseConfig):
device: str = "cuda"
mode: str = "training" # "inferencing"
experimental: bool = False # So I can stop commenting out things when committing
@ -668,6 +658,7 @@ class Config(_Config):
return diskcache.Cache(self.cache_dir).memoize
return lambda: lambda x: x
# I don't remember why this is needed
def load_yaml( self, config_path ):
tmp = Config.from_yaml( config_path )
self.__dict__.update(tmp.__dict__)
@ -759,6 +750,10 @@ class Config(_Config):
if self.trainer.activation_checkpointing is not None:
self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
# load our HDF5 file if requested here
if self.dataset.use_hdf5:
self.load_hdf5()
# Preserves the old behavior
class NaiveTokenizer:
def get_vocab( self ):
@ -787,15 +782,12 @@ class NaiveTokenizer:
cfg = Config.from_cli()
# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
# some safety for remapping deprecated formats and re-coercing uninitialized properties into actual types
try:
cfg.format()
if cfg.dataset.use_hdf5:
cfg.load_hdf5()
except Exception as e:
cfg.dataset.use_hdf5 = False
print("Error while parsing config YAML:", e)
pass
print("Error while parsing config YAML:")
raise e # throw an error because I'm tired of silent errors messing things up for me
try:
from transformers import PreTrainedTokenizerFast

View File

@ -32,7 +32,8 @@ class TTS():
try:
cfg.format()
except Exception as e:
pass
print("Error while parsing config YAML:")
raise e # throw an error because I'm tired of silent errors messing things up for me
if amp is None:
amp = cfg.inference.amp