distributed training works now (hopefully)

This commit is contained in:
mrq 2023-08-13 22:07:45 -05:00
parent 2af09d0bef
commit d7deaf6def
13 changed files with 73 additions and 58 deletions

View File

@ -8,8 +8,6 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
> **Note** this is highly experimental. While I've seem to have audited and tighened down as much as I can, I'm still trying to produce a decent model out of it. You're free to train your own model if you happen to have the massive compute for it, but it's quite the beast to properly feed. This README won't get much love until I truly nail out a quasi-decent model.
> **Note** Distributed training seems broken? I'm not really sure how to test it, as my two 6800XTs have been redistributed for now, and the last time I tried using them for this, things weren't good.
> **Note** You can follow along with my pseudo-blog in an issue [here](https://git.ecker.tech/mrq/ai-voice-cloning/issues/152). I currently have a dataset clocking in at 3400+ trimmed hours.
### Requirements
@ -49,6 +47,8 @@ Training is very dependent on:
#### Leverage Your Own
> **Note** It is highly recommended to utilize [mrq/ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning) with `--tts-backend="vall-e"` to handle transcription and dataset preparations.
1. Put your data into a folder, e.g. `./data/custom`. Audio files should be named with the suffix `.wav` and text files with `.txt`.
2. Quantize the data:
@ -66,6 +66,8 @@ python -m vall_e.emb.g2p ./data/custom
4. Customize your configuration and define the dataset by modifying `./data/config.yml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
> **Note** Be sure to set `distributd: True` to ensure the `DistributedSampler` is used. In the future, I'll have it automagically detect this.
If you're interested in creating an HDF5 copy of your dataset, simply invoke:
```

View File

@ -1,10 +1,8 @@
dataset:
training: [
]
validation: [
]
speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
@ -15,30 +13,30 @@ dataset:
workers: 8
cache: True
phones_range: [4, 512]
duration_range: [1.0, 24.0]
phones_range: [4, 128]
duration_range: [1.0, 12.0]
random_utterance: 1.0
max_prompts: 6
prompt_duration: 6.0
max_prompts: 3
prompt_duration: 3.0
models:
_models:
- name: "ar"
size: "quarter"
size: "full"
resp_levels: 1
arch_type: "retnet"
- name: "nar"
size: "quarter"
size: "full"
resp_levels: 1
arch_type: "retnet"
prom_levels: 2
hyperparameters:
batch_size: 32
gradient_accumulation_steps: 4
batch_size: 8
gradient_accumulation_steps: 16
gradient_clipping: 100
optimizer: Adamw
@ -90,15 +88,18 @@ trainer:
gc_mode: None # "global_step"
weight_dtype: bfloat16
weight_dtype: bfloat16 # float16, float32
backend: deepspeed
deepspeed:
zero_optimization_level: 0
zero_optimization_level: 2
use_compression_training: True
inference:
use_vocos: True
bitsandbytes:
enabled: false
enabled: false
device: cuda
distributed: False

View File

@ -17,7 +17,7 @@ def main():
args = parser.parse_args()
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
tts.inference( text=args.text, reference=args.reference, out_path=args.out_path, max_ar_samples=args.max_ar_samples, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
tts.inference( text=args.text, reference=args.reference, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
if __name__ == "__main__":
main()

View File

@ -7,6 +7,8 @@ import subprocess
import sys
import time
import torch
from dataclasses import asdict, dataclass
from dataclasses import dataclass, field
@ -379,7 +381,7 @@ class Trainer:
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
if cfg.trainer.weight_dtype == "bfloat16":
if self.weight_dtype == "bfloat16":
return torch.bfloat16
return torch.float32
@ -399,6 +401,7 @@ class BitsAndBytes:
@dataclass()
class Config(_Config):
device: str = "cuda"
distributed: bool = False
dataset: Dataset = field(default_factory=lambda: Dataset)
models: Models = field(default_factory=lambda: Models)
@ -433,21 +436,24 @@ class Config(_Config):
cfg = Config.from_cli()
# OmegaConf doesn't actually coerce the dicts into the @dataclass decorated classes, for some god forsaken reason, so we coerce them ourselves
cfg.dataset = Dataset(**cfg.dataset)
cfg.models = Models(**cfg.models)
cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
cfg.evaluation = Evaluation(**cfg.evaluation)
cfg.trainer = Trainer(**cfg.trainer)
cfg.inference = Inference(**cfg.inference)
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
try:
cfg.dataset = Dataset(**cfg.dataset)
cfg.models = Models(**cfg.models)
cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
cfg.evaluation = Evaluation(**cfg.evaluation)
cfg.trainer = Trainer(**cfg.trainer)
cfg.inference = Inference(**cfg.inference)
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
except Exception as e:
pass
# cached_property stopped working...
if cfg.dataset.use_hdf5:
try:
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a')
except Exception as e:
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
cfg.dataset.use_hdf5 = False
@ -457,4 +463,4 @@ if not cfg.dataset.use_hdf5:
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
if __name__ == "__main__":
print(cfg)
print(cfg)

View File

@ -20,6 +20,7 @@ from typing import Any
from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
# torch.multiprocessing.set_sharing_strategy("file_system")
@ -312,13 +313,14 @@ def _create_dataloader(dataset, training):
return DataLoader(
dataset=dataset,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
shuffle=True, # training
shuffle=False if cfg.distributed else True, # training
drop_last=training,
num_workers=cfg.dataset.workers,
collate_fn=collate_fn,
persistent_workers=True,
pin_memory=False, # True,
worker_init_fn=_seed_worker,
sampler=DistributedSampler(dataset) if cfg.distributed else dataset.sampler
)
def _load_dataset_paths():

View File

@ -169,7 +169,8 @@ def encode_from_file(path, device="cuda"):
if isinstance( path, list ):
return encode_from_files( path, device )
else:
wav, sr = torchaudio.load(str(path), format=path[-3:])
path = str(path)
wav, sr = torchaudio.load(path, format=path[-3:])
if wav.shape[0] == 2:
wav = wav[:1]

View File

@ -275,7 +275,7 @@ class Engines(dict[str, Engine]):
stats.update(flatten_dict({ name.split("-")[0]: stat }))
return stats
def step(self, batch, feeder: TrainFeeder = default_feeder, device=torch.cuda.current_device()):
def step(self, batch, feeder: TrainFeeder = default_feeder):
total_elapsed_time = 0
stats: Any = dict()
@ -283,10 +283,9 @@ class Engines(dict[str, Engine]):
if cfg.trainer.gc_mode == 'step':
do_gc()
batch = to_device(batch, device)
for name, engine in self.items():
#torch.cuda.synchronize()
device = engine.device
if cfg.trainer.gc_mode == 'substep':
do_gc()
@ -294,10 +293,9 @@ class Engines(dict[str, Engine]):
start_time = time.time()
tries = 4
n_ooms = torch.zeros([], device=cfg.device)
n_ooms = torch.zeros([], device=device)
if cfg.trainer.aggressive_optimizations:
batch = to_device(batch, device)
batch = to_device(batch, device)
if not cfg.trainer.check_for_oom:
res = feeder( engine=engine, batch=batch )
@ -336,7 +334,7 @@ class Engines(dict[str, Engine]):
loss, engine_stats = res
engine_stats |= self.gather_attribute("scalar")
n_ooms = torch.zeros([], device=cfg.device)
n_ooms = torch.zeros([], device=device)
if cfg.trainer.aggressive_optimizations:
batch = to_device(batch, 'cpu')

View File

@ -27,7 +27,9 @@ def main():
model = models[name]
outpath = f'{args.path}/{name}.pt'
torch.save(model, outpath)
torch.save({
'module': model.state_dict()
}, outpath)
print(f"Exported {name} to {outpath}")
if __name__ == "__main__":

View File

@ -9,6 +9,8 @@ from .utils import to_device
from .config import cfg
from .export import load_models
from .models import get_models
from .data import get_phone_symmap
class TTS():
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
@ -19,11 +21,21 @@ class TTS():
self.output_sample_rate = 24000
if ar_ckpt and nar_ckpt:
self.load_ar( ar_ckpt )
self.load_nar( nar_ckpt )
self.ar_ckpt = ar_ckpt
self.nar_ckpt = nar_ckpt
models = get_models(cfg.models.get())
for name, model in models.items():
if name.startswith("ar"):
self.ar = model.to(self.device, dtype=torch.float32)
self.ar.load_state_dict(torch.load(self.ar_ckpt)['module'])
elif name.startswith("nar"):
self.nar = model.to(self.device, dtype=torch.float32)
self.nar.load_state_dict(torch.load(self.nar_ckpt)['module'])
else:
self.load_models( config )
self.symmap = get_phone_symmap()
self.ar.eval()
self.nar.eval()
@ -39,24 +51,13 @@ class TTS():
for name in models:
model = models[name]
if name[:2] == "ar":
self.ar = model.to(self.device)
self.ar = model.to(self.device, dtype=torch.float32)
self.symmap = self.ar.phone_symmap
elif name[:3] == "nar":
self.nar = model.to(self.device)
self.nar = model.to(self.device, dtype=torch.float32)
else:
print("Unknown:", name)
def load_ar( self, ckpt ):
self.ar_ckpt = ckpt
self.ar = torch.load(self.ar_ckpt).to(self.device)
self.symmap = self.ar.phone_symmap
def load_nar( self, ckpt ):
self.nar_ckpt = nar_ckpt
self.nar = torch.load(self.nar_ckpt).to(self.device)
def encode_text( self, text, lang_marker="en" ):
text = g2p.encode(text)
phones = [f"<{lang_marker}>"] + [ " " if not p else p for p in text ] + [f"</{lang_marker}>"]
@ -70,7 +71,7 @@ class TTS():
def inference( self, text, reference, mode="both", max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path="./.tmp.wav" ):
prom = self.encode_audio( reference )
phns = self.encode_text(text)
phns = self.encode_text( text )
prom = to_device(prom, self.device).to(torch.int16)
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)

View File

@ -114,7 +114,7 @@ def example_usage():
from ..engines import Engine
from tqdm import tqdm
device = "cpu"
device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=2)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"):

View File

@ -272,6 +272,7 @@ class Base(nn.Module):
"""
batch_size = len(text_list)
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
@ -281,6 +282,7 @@ class Base(nn.Module):
x, m = list_to_tensor(x_list)
if self.arch_type == "transformer":
x = self.sin_emb.add_pe(x)
for block in self.blocks:

View File

@ -19,7 +19,7 @@ from collections import defaultdict
from tqdm import tqdm
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cuda")
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
def center_crop(x, len):
start = (x.shape[-1] - len) // 2
@ -89,10 +89,10 @@ def run_eval(engines, eval_name, dl):
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
ref_audio = ref_audio[..., 0:min_length]
hyp_audio = hyp_audio[..., 0:min_length]
try:
stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
except Exception as e:
stats['loss'].append(0)
print(str(e))
for batch in tqdm(dl):

View File

@ -80,7 +80,7 @@ def load_engines():
if cfg.trainer.load_state_dict:
load_path = cfg.ckpt_dir / name / "fp32.pth"
model.load_state_dict(torch.load(load_path))
model.load_state_dict(torch.load(load_path)['module'])
engines[name] = Engine(
model=model,