commit 608c1970eb
parent c85101403f

    ops

README.md (28 lines changed)
@@ -6,9 +6,7 @@
 An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), based on the [EnCodec](https://github.com/facebookresearch/encodec) tokenizer.
 
-> **Note** this is highly experimental. While I seem to have audited and tightened down as much as I can, I'm still trying to produce a decent model out of it. You're free to train your own model if you happen to have the massive compute for it, but it's quite the beast to properly feed.
-
-> **Note** This README won't get much love until I truly nail down a quasi-decent model.
+> **Note** this is highly experimental. While I seem to have audited and tightened down as much as I can, I'm still trying to produce a decent model out of it. You're free to train your own model if you happen to have the massive compute for it, but it's quite the beast to properly feed. This README won't get much love until I truly nail down a quasi-decent model.
 
+> **Note** Distributed training seems broken? I'm not really sure how to test it, as my two 6800XTs have been redistributed for now, and the last time I tried using them for this, things weren't good.
@@ -16,7 +14,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
 
 ### Requirements
 
-Since the trainer is based on [DeepSpeed](https://github.com/microsoft/DeepSpeed#requirements), you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
+If your config YAML has the training backend set to [`deepspeed`](https://github.com/microsoft/DeepSpeed#requirements), you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
 
 ### Install
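For context, a minimal sketch of the dispatch this backend switch implies, assuming `cfg.trainer.backend` is the attribute the YAML key maps to (the `example_usage()` hunk at the end of this commit sets it to `"local"`); the real wiring lives in vall_e's engine code, and `ds_config` is a hypothetical placeholder:

```python
# Hedged sketch of a deepspeed/local backend switch; only the attribute name
# cfg.trainer.backend is taken from this commit, everything else is illustrative.
def build_engine(cfg, model, ds_config=None):
    if cfg.trainer.backend == "deepspeed":
        import deepspeed  # needs a DeepSpeed-supported GPU and a CUDA/ROCm compiler
        engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)
        return engine
    return model  # "local": plain PyTorch, no DeepSpeed required
```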
@@ -31,7 +29,16 @@ git clone --recurse-submodules https://git.ecker.tech/mrq/vall-e.git
 ```
 
 Note that the code is only tested under `Python 3.10.9`.
+* `fairseq`, a pseudo-dependency of `torchscale`, is not compatible with `Python 3.11`.
+
+### Try Me
+
+To quickly try it out, you can choose between the following modes:
+
+* AR only: `python -m vall_e.models.ar yaml="./data/config.yaml"`
+* NAR only: `python -m vall_e.models.nar yaml="./data/config.yaml"`
+* AR+NAR: `python -m vall_e.models.base yaml="./data/config.yaml"`
+
+Each model file has a barebones trainer and inference routine.
 
 ### Train
@@ -42,7 +49,7 @@ Training is very dependent on:
 
 #### Leverage Your Own
 
-1. Put your data into a folder, e.g. `./data/custom`. Audio files should be named with the suffix `.wav` and text files with `.normalized.txt`.
+1. Put your data into a folder, e.g. `./data/custom`. Audio files should be named with the suffix `.wav` and text files with `.txt`.
 
 2. Quantize the data:
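As a sanity check for step 1's layout, a small hypothetical snippet (not part of vall_e) that flags audio files missing a matching transcript:

```python
# Hypothetical check for the folder layout described in step 1:
# every .wav should have a .txt transcript sitting next to it.
from pathlib import Path

folder = Path("./data/custom")
for wav in sorted(folder.rglob("*.wav")):
    txt = wav.with_suffix(".txt")
    if not txt.exists():
        print(f"missing transcript for {wav}")
```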
@@ -56,7 +63,14 @@ python -m vall_e.emb.qnt ./data/custom
 python -m vall_e.emb.g2p ./data/custom
 ```
 
-4. Customize your configuration by modifying `./data/config.yml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
+4. Customize your configuration and define the dataset by modifying `./data/config.yml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
+
+If you're interested in creating an HDF5 copy of your dataset, simply invoke:
+
+```
+python -m vall_e.data yaml='./data/config.yaml'
+```
 
 5. Train the AR and NAR models using the following scripts:
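If you do build the HDF5 copy, a quick hedged way to peek inside it; the output path is whatever your `config.yaml` specifies, so `./data/data.h5` below is only a guess:

```python
# List every group/dataset the HDF5 copy contains; the path is hypothetical.
import h5py

with h5py.File("./data/data.h5", "r") as f:
    f.visit(print)  # prints one name per line
```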
@@ -365,6 +365,7 @@ class Trainer:
     restart_step_count: bool = False
 
     aggressive_optimizations: bool = False
+    check_for_oom: bool = True
 
     gc_mode: str | None = None
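The new flag follows the existing pattern here: a flat dataclass of trainer knobs that callers flip on the shared config object at runtime, exactly as `example_usage()` does in the last hunk of this commit. A trimmed, self-contained sketch (field names from the diff; the real `Trainer` class has many more fields):

```python
# Minimal sketch of the config pattern this field joins.
from dataclasses import dataclass

@dataclass
class Trainer:
    restart_step_count: bool = False
    aggressive_optimizations: bool = False
    check_for_oom: bool = True  # the field this commit adds
    gc_mode: str | None = None

trainer = Trainer()
trainer.check_for_oom = False  # opt out of the OOM-guarded forward/backward
print(trainer)
```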
@@ -59,7 +59,7 @@ def encode(text: str, language="en-us", backend="espeak") -> list[str]:
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("folder", type=Path)
-    parser.add_argument("--suffix", type=str, default=".normalized.txt")
+    parser.add_argument("--suffix", type=str, default=".txt")
     args = parser.parse_args()
 
     paths = list(args.folder.rglob(f"*{args.suffix}"))
@@ -281,36 +281,36 @@ class Engines(dict[str, Engine]):
             if cfg.trainer.aggressive_optimizations:
                 batch = to_device(batch, device)
 
-            res = feeder( engine=engine, batch=batch )
-            """
-            while tries >= 0:
-                try:
-                    res = feeder( engine=engine, batch=batch )
-                    break
-                except RuntimeError as e:
-                    print("Forward", str(e))
-
-                    if "out of memory" not in str(e):
-                        self.save_checkpoint()
-                        raise e
-
-                    # shrink batch size until it's happy
-                    for k in batch:
-                        batch[k] = batch[k][:-1]
-
-                    if tries <= 0:
-                        # trigger OOM
-                        n_ooms += 1
-                    else:
-                        # also do GC
-                        do_gc()
-                    continue
-
-            all_reduce(n_ooms)
-            if n_ooms.item() > 0:
-                self.save_checkpoint()
-                raise RuntimeError("Out of memory during forward pass!")
-            """
+            if not cfg.trainer.check_for_oom:
+                res = feeder( engine=engine, batch=batch )
+            else:
+                while tries >= 0:
+                    try:
+                        res = feeder( engine=engine, batch=batch )
+                        break
+                    except RuntimeError as e:
+                        print("Forward", str(e))
+
+                        if "out of memory" not in str(e):
+                            self.save_checkpoint()
+                            raise e
+
+                        # shrink batch size until it's happy
+                        for k in batch:
+                            batch[k] = batch[k][:-1]
+
+                        if tries <= 0:
+                            # trigger OOM
+                            n_ooms += 1
+                        else:
+                            # also do GC
+                            do_gc()
+                        continue
+
+                all_reduce(n_ooms)
+                if n_ooms.item() > 0:
+                    self.save_checkpoint()
+                    raise RuntimeError("Out of memory during forward pass!")
 
             if res is None:
                 continue
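Distilled out of the hunk above, a standalone hedged sketch of the batch-shrinking retry that the `check_for_oom` path performs; `run_forward` and the tensor batch are stand-ins for `feeder(engine=..., batch=...)`, and the explicit `tries -= 1` is an assumption (the decrement isn't visible in this hunk):

```python
# Standalone sketch of the OOM-guarded forward: on CUDA OOM, drop the last
# sample from every tensor in the batch, GC, and retry a few times.
import gc
import torch

def guarded_forward(run_forward, batch: dict[str, torch.Tensor], tries: int = 4):
    while tries >= 0:
        try:
            return run_forward(batch)
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise  # a real error, not an OOM; surface it
            # shrink batch size until it's happy
            batch = {k: v[:-1] for k, v in batch.items()}
            tries -= 1  # assumed; the decrement isn't shown in the hunk
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    raise RuntimeError("Out of memory during forward pass!")
```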
@@ -323,24 +323,24 @@ class Engines(dict[str, Engine]):
             if cfg.trainer.aggressive_optimizations:
                 batch = to_device(batch, 'cpu')
 
-            engine.backward(loss)
-            """
-            try:
-                engine.backward(loss)
-            except RuntimeError as e:
-                print("Backwards:", str(e))
-
-                if "out of memory" not in str(e):
-                    self.save_checkpoint()
-                    raise e
-
-                n_ooms += 1
-
-            all_reduce(n_ooms)
-            if n_ooms.item() > 0:
-                self.save_checkpoint()
-                raise RuntimeError("Out of memory during backwards pass!")
-            """
+            if not cfg.trainer.check_for_oom:
+                engine.backward(loss)
+            else:
+                try:
+                    engine.backward(loss)
+                except RuntimeError as e:
+                    print("Backwards:", str(e))
+
+                    if "out of memory" not in str(e):
+                        self.save_checkpoint()
+                        raise e
+
+                    n_ooms += 1
+
+                all_reduce(n_ooms)
+                if n_ooms.item() > 0:
+                    self.save_checkpoint()
+                    raise RuntimeError("Out of memory during backwards pass!")
 
             engine.step()
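One detail worth calling out from both hunks: the OOM count is `all_reduce`d before aborting, so every rank in a distributed run agrees to stop together instead of leaving the others hanging in the next collective. A hedged standalone sketch of that consensus step, using `torch.distributed` in place of vall_e's own `all_reduce` helper:

```python
# Sketch: agree across ranks that someone OOM'd before raising, so all
# processes abort together. Degrades to a local check when not distributed.
import torch
import torch.distributed as dist

def sync_oom_and_maybe_abort(local_oom: bool, save_checkpoint) -> None:
    n_ooms = torch.tensor([1 if local_oom else 0])
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(n_ooms)  # sums the OOM flags from every rank
    if n_ooms.item() > 0:
        save_checkpoint()  # salvage progress before dying
        raise RuntimeError("Out of memory during backwards pass!")
```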
@@ -372,6 +372,7 @@ class Base(nn.Module):
 def example_usage():
     from ..config import cfg
     cfg.trainer.backend = "local"
+    cfg.trainer.check_for_oom = False
 
     from functools import partial