mrq 2023-08-03 20:36:19 -05:00
parent c85101403f
commit 608c1970eb
5 changed files with 65 additions and 49 deletions

View File

@ -6,9 +6,7 @@
An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), based on the [EnCodec](https://github.com/facebookresearch/encodec) tokenizer.
> **Note** this is highly experimental. While I seem to have audited and tightened things down as much as I can, I'm still trying to produce a decent model out of it. You're free to train your own model if you happen to have the massive compute for it, but it's quite the beast to properly feed.
> **Note** This README won't get much love until I truly nail down a quasi-decent model.
> **Note** this is highly experimental. While I seem to have audited and tightened things down as much as I can, I'm still trying to produce a decent model out of it. You're free to train your own model if you happen to have the massive compute for it, but it's quite the beast to properly feed. This README won't get much love until I truly nail down a quasi-decent model.
> **Note** Distributed training seems broken? I'm not really sure how to test it, as my two 6800XTs have been redistributed for now, and the last time I tried using them for this, things weren't good.
@ -16,7 +14,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
### Requirements
Since the trainer is based on [DeepSpeed](https://github.com/microsoft/DeepSpeed#requirements), you will need a GPU that DeepSpeed has been developed and tested against, as well as a CUDA or ROCm compiler pre-installed, in order to install this package.
If your config YAML has the training backend set to [`deepspeed`](https://github.com/microsoft/DeepSpeed#requirements), you will need a GPU that DeepSpeed has been developed and tested against, as well as a CUDA or ROCm compiler pre-installed, in order to install this package.
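Since the requirement is now gated on the backend, the toggle itself is just a config field. For reference, a one-line sketch (grounded in this commit's `example_usage()`, which sets the non-DeepSpeed value; whether `"local"` fully avoids the DeepSpeed dependency is my assumption):

```python
from vall_e.config import cfg

# "deepspeed" routes training through DeepSpeed (and needs its toolchain);
# "local" is the value example_usage() sets later in this commit.
cfg.trainer.backend = "local"
```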
### Install
@ -31,7 +29,16 @@ git clone --recurse-submodules https://git.ecker.tech/mrq/vall-e.git
```
Note that the code is only tested under `Python 3.10.9`.
* `fairseq`, a pseudo-dependency for `torchscale`, is not compatible with `Python 3.11`.
### Try Me
To quickly try it out, you can choose between the following modes:
* AR only: `python -m vall_e.models.ar yaml="./data/config.yaml"`
* NAR only: `python -m vall_e.models.nar yaml="./data/config.yaml"`
* AR+NAR: `python -m vall_e.models.base yaml="./data/config.yaml"`
Each model file has a barebones trainer and inference routine.
### Train
@ -42,7 +49,7 @@ Training is very dependent on:
#### Leverage Your Own
1. Put your data into a folder, e.g. `./data/custom`. Audio files should be named with the suffix `.wav` and text files with `.normalized.txt`.
1. Put your data into a folder, e.g. `./data/custom`. Audio files should be named with the suffix `.wav` and text files with `.txt`.
2. Quantize the data:
@ -56,7 +63,14 @@ python -m vall_e.emb.qnt ./data/custom
python -m vall_e.emb.g2p ./data/custom
```
4. Customize your configuration by modifying `./data/config.yaml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
4. Customize your configuration and define the dataset by modifying `./data/config.yaml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
If you're interested in creating an HDF5 copy of your dataset, simply invoke:
```
python -m vall_e.data yaml='./data/config.yaml'
```
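To sanity-check what landed in the archive, a minimal sketch using `h5py` (the file path below is an assumption; point it at wherever your config writes the HDF5 output):

```python
import h5py

# Hypothetical path: adjust to the HDF5 file your config produces.
with h5py.File("./data/dataset.h5", "r") as f:
    f.visit(print)  # prints the name of every group/dataset in the archive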
5. Train the AR and NAR models using the following scripts:

View File

@ -365,6 +365,7 @@ class Trainer:
    restart_step_count: bool = False
    aggressive_optimizations: bool = False
    check_for_oom: bool = True
    gc_mode: str | None = None
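The new `gc_mode` field is only declared in this hunk; nothing shown here reads it. A hedged sketch of how a knob like this might be consumed (the event names and trigger point are assumptions; `do_gc()` mimics the helper called in the engines loop later in this commit):

```python
import gc
import torch

def do_gc():
    # Rough stand-in for the do_gc() helper used in the engines loop.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def maybe_gc(gc_mode: str | None, event: str):
    # Hypothetical: collect garbage only when the configured mode matches
    # the current training event (e.g. "step" or "substep" -- assumed names).
    if gc_mode == event:
        do_gc()
```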

View File

@ -59,7 +59,7 @@ def encode(text: str, language="en-us", backend="espeak") -> list[str]:
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("folder", type=Path)
    parser.add_argument("--suffix", type=str, default=".normalized.txt")
    parser.add_argument("--suffix", type=str, default=".txt")
    args = parser.parse_args()
    paths = list(args.folder.rglob(f"*{args.suffix}"))
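Because `--suffix` remains exposed as a flag, a dataset that still uses the old naming can be phonemized without renaming anything, e.g. `python -m vall_e.emb.g2p ./data/custom --suffix ".normalized.txt"`.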

View File

@ -281,36 +281,36 @@ class Engines(dict[str, Engine]):
        if cfg.trainer.aggressive_optimizations:
            batch = to_device(batch, device)

        res = feeder( engine=engine, batch=batch )
        """
        while tries >= 0:
            try:
                res = feeder( engine=engine, batch=batch )
                break
            except RuntimeError as e:
                print("Forward", str(e))

                if "out of memory" not in str(e):
                    self.save_checkpoint()
                    raise e

                # shrink batch size until it's happy
                for k in batch:
                    batch[k] = batch[k][:-1]

                if tries <= 0:
                    # trigger OOM
                    n_ooms += 1
                else:
                    # also do GC
                    do_gc()
                    continue

        all_reduce(n_ooms)
        if n_ooms.item() > 0:
            self.save_checkpoint()
            raise RuntimeError("Out of memory during forward pass!")
        """

        if not cfg.trainer.check_for_oom:
            res = feeder( engine=engine, batch=batch )
        else:
            while tries >= 0:
                try:
                    res = feeder( engine=engine, batch=batch )
                    break
                except RuntimeError as e:
                    print("Forward", str(e))

                    if "out of memory" not in str(e):
                        self.save_checkpoint()
                        raise e

                    # shrink batch size until it's happy
                    for k in batch:
                        batch[k] = batch[k][:-1]

                    if tries <= 0:
                        # trigger OOM
                        n_ooms += 1
                    else:
                        # also do GC
                        do_gc()
                        continue

            all_reduce(n_ooms)
            if n_ooms.item() > 0:
                self.save_checkpoint()
                raise RuntimeError("Out of memory during forward pass!")

        if res is None:
            continue
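Stripped of the diff noise, the new forward path is a shrink-batch-on-OOM retry loop. A minimal, self-contained sketch of the same pattern (names follow the hunk above; the retry budget and the explicit decrement are assumptions, since the hunk never shows `tries` being initialized or decreased):

```python
import torch

def forward_with_oom_retry(feeder, engine, batch: dict, tries: int = 4):
    """Retry the forward pass, dropping one sample from the batch per OOM.

    `feeder`/`engine` follow the names in the hunk above; the retry budget
    and `tries -= 1` are assumptions made to keep this sketch terminating.
    """
    while True:
        try:
            return feeder(engine=engine, batch=batch)
        except RuntimeError as e:
            if "out of memory" not in str(e) or tries <= 0:
                raise  # not an OOM, or out of retries: give up
            tries -= 1
            # shrink batch size until it's happy (same trick as above)
            batch = {k: v[:-1] for k, v in batch.items()}
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # stand-in for the do_gc() call above
```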
@ -323,24 +323,24 @@ class Engines(dict[str, Engine]):
        if cfg.trainer.aggressive_optimizations:
            batch = to_device(batch, 'cpu')

        engine.backward(loss)
        """
        try:
            engine.backward(loss)
        except RuntimeError as e:
            print("Backwards:", str(e))

            if "out of memory" not in str(e):
                self.save_checkpoint()
                raise e

            n_ooms += 1

        all_reduce(n_ooms)
        if n_ooms.item() > 0:
            self.save_checkpoint()
            raise RuntimeError("Out of memory during backwards pass!")
        """

        if not cfg.trainer.check_for_oom:
            engine.backward(loss)
        else:
            try:
                engine.backward(loss)
            except RuntimeError as e:
                print("Backwards:", str(e))

                if "out of memory" not in str(e):
                    self.save_checkpoint()
                    raise e

                n_ooms += 1

            all_reduce(n_ooms)
            if n_ooms.item() > 0:
                self.save_checkpoint()
                raise RuntimeError("Out of memory during backwards pass!")

        engine.step()
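One detail worth calling out: `all_reduce(n_ooms)` is what keeps multi-GPU runs coherent. If only the rank that hit the OOM raised, the surviving ranks would deadlock waiting on its next collective. A minimal sketch of that synchronization using `torch.distributed` directly (the `all_reduce` wrapper above isn't defined in this hunk, so treating it as a sum-reduce is an assumption):

```python
import torch
import torch.distributed as dist

def sync_oom_and_abort(had_oom: bool, device: str = "cpu") -> None:
    # Each rank contributes 1 if it hit an OOM; summing shares the count.
    n_ooms = torch.tensor([1.0 if had_oom else 0.0], device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(n_ooms)  # default reduce op is SUM
    if n_ooms.item() > 0:
        # Every rank raises together, so none is left blocked waiting on a
        # collective from a rank that already died.
        raise RuntimeError("Out of memory during backwards pass!")
```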

View File

@ -372,6 +372,7 @@ class Base(nn.Module):
def example_usage():
    from ..config import cfg
    cfg.trainer.backend = "local"
    cfg.trainer.check_for_oom = False

    from functools import partial