distributed training works now (hopefully)
parent 2af09d0bef
commit d7deaf6def
@@ -8,8 +8,6 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),

> **Note** this is highly experimental. While I seem to have audited and tightened things down as much as I can, I'm still trying to produce a decent model out of it. You're free to train your own model if you happen to have the massive compute for it, but it's quite the beast to properly feed. This README won't get much love until I truly nail down a quasi-decent model.

> **Note** Distributed training seems broken? I'm not really sure how to test it, as my two 6800XTs have been redistributed for now, and the last time I tried using them for this, things weren't good.

> **Note** You can follow along with my pseudo-blog in an issue [here](https://git.ecker.tech/mrq/ai-voice-cloning/issues/152). I currently have a dataset clocking in at 3400+ trimmed hours.

### Requirements
@@ -49,6 +47,8 @@ Training is very dependent on:

#### Leverage Your Own

> **Note** It is highly recommended to utilize [mrq/ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning) with `--tts-backend="vall-e"` to handle transcription and dataset preparation.

1. Put your data into a folder, e.g. `./data/custom`. Audio files should be named with the suffix `.wav` and text files with `.txt`.

2. Quantize the data:
@@ -66,6 +66,8 @@ python -m vall_e.emb.g2p ./data/custom

4. Customize your configuration and define the dataset by modifying `./data/config.yml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.

> **Note** Be sure to set `distributed: True` to ensure the `DistributedSampler` is used. In the future, I'll have it automagically detect this.
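For context on what that flag changes, here is a standalone sketch (not this repo's code) of what a `DistributedSampler` buys you: each rank gets a disjoint shard of the dataset, so N GPUs each see roughly 1/N of the data per epoch instead of all of it.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Minimal illustration of how DistributedSampler partitions a dataset.
dataset = TensorDataset(torch.arange(10))
for rank in range(2):  # pretend world_size == 2
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(iter(sampler)))
# rank 0 -> [0, 2, 4, 6, 8], rank 1 -> [1, 3, 5, 7, 9]
```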
If you're interested in creating an HDF5 copy of your dataset, simply invoke:

```
@@ -1,10 +1,8 @@
dataset:
training: [

]

validation: [

]

speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
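The `speaker_name_getter` above is stored as a string and, as far as I can tell, evaluated by the dataset loader before being applied to each audio path to derive a speaker label. A minimal illustration for a path under `./data/custom` (the filename is hypothetical):

```python
from pathlib import Path

# Evaluate the same lambda string the YAML carries and apply it to a sample path.
getter = eval("lambda p: f'{p.parts[-3]}_{p.parts[-2]}'")
print(getter(Path("./data/custom/speaker_a/utt_0001.qnt.pt")))
# -> "custom_speaker_a"
```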
@@ -15,30 +13,30 @@ dataset:
workers: 8
cache: True

phones_range: [4, 512]
duration_range: [1.0, 24.0]
phones_range: [4, 128]
duration_range: [1.0, 12.0]

random_utterance: 1.0
max_prompts: 6
prompt_duration: 6.0
max_prompts: 3
prompt_duration: 3.0

models:
_models:
- name: "ar"
size: "quarter"
size: "full"
resp_levels: 1
arch_type: "retnet"

- name: "nar"
size: "quarter"
size: "full"
resp_levels: 1
arch_type: "retnet"

prom_levels: 2

hyperparameters:
batch_size: 32
gradient_accumulation_steps: 4
batch_size: 8
gradient_accumulation_steps: 16
gradient_clipping: 100

optimizer: Adamw
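Worth noting on the hyperparameter change: the per-device effective batch size is unchanged (32 × 4 = 8 × 16 = 128); under data-parallel training the global batch additionally scales with the number of ranks. A quick sanity check (a world size of 2 is just an example value):

```python
# Effective batch size check for the hyperparameter change above.
old = 32 * 4    # batch_size * gradient_accumulation_steps = 128 per device
new = 8 * 16    # = 128 per device, unchanged
world_size = 2  # example assumption: two data-parallel GPUs
print(old, new, new * world_size)  # 128 128 256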
@@ -90,15 +88,18 @@ trainer:

gc_mode: None # "global_step"

weight_dtype: bfloat16
weight_dtype: bfloat16 # float16, float32

backend: deepspeed
deepspeed:
zero_optimization_level: 0
zero_optimization_level: 2
use_compression_training: True

inference:
use_vocos: True

bitsandbytes:
enabled: false
enabled: false

device: cuda
distributed: False
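For readers unfamiliar with the backend, the trainer options above map onto a DeepSpeed config roughly like the sketch below. This is illustrative only; the exact config this repo builds from the YAML may differ.

```python
# Hedged sketch: a hand-written DeepSpeed config dict carrying roughly the same
# settings as the YAML above (micro-batch, accumulation, clipping, bf16 weights,
# ZeRO stage 2). Not the config the repo actually generates.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 16,
    "gradient_clipping": 100,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}
# Typically handed to deepspeed.initialize() alongside the model and its parameters.
```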
@@ -17,7 +17,7 @@ def main():
args = parser.parse_args()

tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
tts.inference( text=args.text, reference=args.reference, out_path=args.out_path, max_ar_samples=args.max_ar_samples, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
tts.inference( text=args.text, reference=args.reference, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )

if __name__ == "__main__":
main()
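The only behavioural change here is the `max_ar_samples` → `max_ar_steps` rename. For reference, driving the same API programmatically would look roughly like this; the paths are placeholders and the import path is assumed from the `inference.py` changes further down.

```python
from vall_e.inference import TTS  # module path assumed from the TTS class changes below

# Hypothetical checkpoint/config paths; mirrors the CLI wiring in main() above.
tts = TTS(config="./data/config.yml", ar_ckpt="./models/ar.pt", nar_ckpt="./models/nar.pt", device="cuda")
tts.inference(
    text="Hello world.",
    reference="./reference.wav",
    out_path="./output.wav",
    max_ar_steps=6 * 75,  # default taken from the inference() signature later in this commit
    ar_temp=1.0,
    nar_temp=1.0,
)
```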
@@ -7,6 +7,8 @@ import subprocess
import sys
import time

import torch

from dataclasses import asdict, dataclass
from dataclasses import dataclass, field
@@ -379,7 +381,7 @@ class Trainer:
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
if cfg.trainer.weight_dtype == "bfloat16":
if self.weight_dtype == "bfloat16":
return torch.bfloat16
return torch.float32
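The fix above makes the property read the instance's own `weight_dtype` instead of reaching back into the global `cfg`. A self-contained sketch of the corrected pattern (the dataclass here is a hypothetical stand-in, not the repo's `Trainer`):

```python
from dataclasses import dataclass

import torch

@dataclass
class TrainerSketch:  # hypothetical stand-in for the repo's Trainer dataclass
    weight_dtype: str = "float32"

    @property
    def dtype(self):
        # Consults the instance's own field, so it works on any Trainer object
        # regardless of when (or whether) the global cfg has been constructed.
        if self.weight_dtype == "float16":
            return torch.float16
        if self.weight_dtype == "bfloat16":
            return torch.bfloat16
        return torch.float32

assert TrainerSketch(weight_dtype="bfloat16").dtype == torch.bfloat16
```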
@@ -399,6 +401,7 @@ class BitsAndBytes:
@dataclass()
class Config(_Config):
device: str = "cuda"
distributed: bool = False

dataset: Dataset = field(default_factory=lambda: Dataset)
models: Models = field(default_factory=lambda: Models)
@@ -433,21 +436,24 @@ class Config(_Config):

cfg = Config.from_cli()

# OmegaConf doesn't actually coerce the dicts into the @dataclass decorated classes, for some god forsaken reason, so we coerce them ourselves
cfg.dataset = Dataset(**cfg.dataset)
cfg.models = Models(**cfg.models)
cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
cfg.evaluation = Evaluation(**cfg.evaluation)
cfg.trainer = Trainer(**cfg.trainer)
cfg.inference = Inference(**cfg.inference)
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
try:
cfg.dataset = Dataset(**cfg.dataset)
cfg.models = Models(**cfg.models)
cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
cfg.evaluation = Evaluation(**cfg.evaluation)
cfg.trainer = Trainer(**cfg.trainer)
cfg.inference = Inference(**cfg.inference)
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)

cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
except Exception as e:
pass
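On the comment above: OmegaConf can do this coercion itself when the schema is registered as a structured config, which might be an alternative to the manual `Dataset(**cfg.dataset)` calls and the blanket `try`/`except`. A sketch under the assumption that the dataclasses all carry defaults (the class below is a stand-in, not the repo's `Dataset`):

```python
from dataclasses import dataclass

from omegaconf import OmegaConf

@dataclass
class DatasetSketch:  # hypothetical stand-in for the repo's Dataset dataclass
    workers: int = 8
    cache: bool = True

# Merge a plain YAML-derived dict into the dataclass schema, then materialise it.
loaded = OmegaConf.create({"workers": 4})          # e.g. parsed from config.yml
merged = OmegaConf.merge(OmegaConf.structured(DatasetSketch), loaded)
dataset = OmegaConf.to_object(merged)              # -> DatasetSketch(workers=4, cache=True)
```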
# cached_property stopped working...
if cfg.dataset.use_hdf5:
try:
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a')
except Exception as e:
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
cfg.dataset.use_hdf5 = False
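The mode switch above is the distributed-safety piece: several ranks appending (`'a'`) to the same HDF5 file is asking for corruption, so under `distributed: True` every process takes a read-only handle. A standalone sketch, with the path and flag as placeholders:

```python
import h5py

distributed = True              # placeholder for cfg.distributed
hdf5_path = "./data/data.h5"    # hypothetical dataset path

try:
    # Readers can share 'r'; 'a' (create/append) is reserved for the
    # single-process case that may still need to build or extend the file.
    hdf5 = h5py.File(hdf5_path, "r" if distributed else "a")
except OSError as e:
    print("Error while opening HDF5 file:", hdf5_path, e)
```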
@@ -457,4 +463,4 @@ if not cfg.dataset.use_hdf5:
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]

if __name__ == "__main__":
print(cfg)
print(cfg)
@@ -20,6 +20,7 @@ from typing import Any

from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm

# torch.multiprocessing.set_sharing_strategy("file_system")
@@ -312,13 +313,14 @@ def _create_dataloader(dataset, training):
return DataLoader(
dataset=dataset,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
shuffle=True, # training
shuffle=False if cfg.distributed else True, # training
drop_last=training,
num_workers=cfg.dataset.workers,
collate_fn=collate_fn,
persistent_workers=True,
pin_memory=False, # True,
worker_init_fn=_seed_worker,
sampler=DistributedSampler(dataset) if cfg.distributed else dataset.sampler
)

def _load_dataset_paths():
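Two details of the loader above are worth spelling out: PyTorch's `DataLoader` refuses `shuffle=True` together with an explicit `sampler` (hence the `shuffle=False if cfg.distributed else True` guard), and with a `DistributedSampler` the shuffling is done by the sampler itself, which needs `set_epoch()` called each epoch to reshuffle the shards. A standalone sketch of that pattern, not this repo's training loop:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

# shuffle must be False (or omitted) whenever a sampler is supplied.
loader = DataLoader(dataset, batch_size=8, shuffle=False, sampler=sampler, drop_last=True)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffles this rank's shard each epoch
    for batch in loader:
        pass
```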
@@ -169,7 +169,8 @@ def encode_from_file(path, device="cuda"):
if isinstance( path, list ):
return encode_from_files( path, device )
else:
wav, sr = torchaudio.load(str(path), format=path[-3:])
path = str(path)
wav, sr = torchaudio.load(path, format=path[-3:])

if wav.shape[0] == 2:
wav = wav[:1]
@@ -275,7 +275,7 @@ class Engines(dict[str, Engine]):
stats.update(flatten_dict({ name.split("-")[0]: stat }))
return stats

def step(self, batch, feeder: TrainFeeder = default_feeder, device=torch.cuda.current_device()):
def step(self, batch, feeder: TrainFeeder = default_feeder):
total_elapsed_time = 0

stats: Any = dict()
@@ -283,10 +283,9 @@ class Engines(dict[str, Engine]):
if cfg.trainer.gc_mode == 'step':
do_gc()

batch = to_device(batch, device)

for name, engine in self.items():
#torch.cuda.synchronize()
device = engine.device

if cfg.trainer.gc_mode == 'substep':
do_gc()
|
|||
start_time = time.time()
|
||||
|
||||
tries = 4
|
||||
n_ooms = torch.zeros([], device=cfg.device)
|
||||
n_ooms = torch.zeros([], device=device)
|
||||
|
||||
if cfg.trainer.aggressive_optimizations:
|
||||
batch = to_device(batch, device)
|
||||
batch = to_device(batch, device)
|
||||
|
||||
if not cfg.trainer.check_for_oom:
|
||||
res = feeder( engine=engine, batch=batch )
|
||||
|
@@ -336,7 +334,7 @@ class Engines(dict[str, Engine]):
loss, engine_stats = res
engine_stats |= self.gather_attribute("scalar")

n_ooms = torch.zeros([], device=cfg.device)
n_ooms = torch.zeros([], device=device)

if cfg.trainer.aggressive_optimizations:
batch = to_device(batch, 'cpu')
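My reading of the `n_ooms` change: the counter now lives on the engine's own device rather than the global `cfg.device`, which matters if it later gets summed across ranks with an all-reduce (an assumption on my part; only the tensor creation is shown in these hunks), since the NCCL backend only operates on CUDA tensors local to each rank. A standalone sketch of that pattern:

```python
import torch
import torch.distributed as dist

device = "cuda" if torch.cuda.is_available() else "cpu"

# Each rank counts its own OOMs on its own device...
n_ooms = torch.zeros([], device=device)

# ...then, if a process group is running, all ranks agree on whether anyone
# hit an OOM this step before deciding to skip it together.
if dist.is_initialized():
    dist.all_reduce(n_ooms, op=dist.ReduceOp.SUM)

skip_step = n_ooms.item() > 0
```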
@@ -27,7 +27,9 @@ def main():
model = models[name]

outpath = f'{args.path}/{name}.pt'
torch.save(model, outpath)
torch.save({
'module': model.state_dict()
}, outpath)
print(f"Exported {name} to {outpath}")

if __name__ == "__main__":
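This brings the export format in line with what the loaders elsewhere in this commit expect: a plain dict with the weights under a `'module'` key (see the `torch.load(...)['module']` calls in the TTS class and in `load_engines` below). A minimal round-trip sketch with a stand-in model and a hypothetical path:

```python
import torch
from torch import nn

model = nn.Linear(4, 4)  # stand-in for an exported AR/NAR model
torch.save({"module": model.state_dict()}, "/tmp/ar.pt")  # hypothetical output path

restored = nn.Linear(4, 4)
restored.load_state_dict(torch.load("/tmp/ar.pt")["module"])
```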
@@ -9,6 +9,8 @@ from .utils import to_device

from .config import cfg
from .export import load_models
from .models import get_models
from .data import get_phone_symmap

class TTS():
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
@@ -19,11 +21,21 @@ class TTS():
self.output_sample_rate = 24000

if ar_ckpt and nar_ckpt:
self.load_ar( ar_ckpt )
self.load_nar( nar_ckpt )
self.ar_ckpt = ar_ckpt
self.nar_ckpt = nar_ckpt

models = get_models(cfg.models.get())
for name, model in models.items():
if name.startswith("ar"):
self.ar = model.to(self.device, dtype=torch.float32)
self.ar.load_state_dict(torch.load(self.ar_ckpt)['module'])
elif name.startswith("nar"):
self.nar = model.to(self.device, dtype=torch.float32)
self.nar.load_state_dict(torch.load(self.nar_ckpt)['module'])
else:
self.load_models( config )

self.symmap = get_phone_symmap()
self.ar.eval()
self.nar.eval()
@@ -39,24 +51,13 @@ class TTS():
for name in models:
model = models[name]
if name[:2] == "ar":
self.ar = model.to(self.device)
self.ar = model.to(self.device, dtype=torch.float32)
self.symmap = self.ar.phone_symmap
elif name[:3] == "nar":
self.nar = model.to(self.device)
self.nar = model.to(self.device, dtype=torch.float32)
else:
print("Unknown:", name)

def load_ar( self, ckpt ):
self.ar_ckpt = ckpt

self.ar = torch.load(self.ar_ckpt).to(self.device)
self.symmap = self.ar.phone_symmap

def load_nar( self, ckpt ):
self.nar_ckpt = nar_ckpt

self.nar = torch.load(self.nar_ckpt).to(self.device)

def encode_text( self, text, lang_marker="en" ):
text = g2p.encode(text)
phones = [f"<{lang_marker}>"] + [ " " if not p else p for p in text ] + [f"</{lang_marker}>"]
@@ -70,7 +71,7 @@ class TTS():

def inference( self, text, reference, mode="both", max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path="./.tmp.wav" ):
prom = self.encode_audio( reference )
phns = self.encode_text(text)
phns = self.encode_text( text )

prom = to_device(prom, self.device).to(torch.int16)
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
@@ -114,7 +114,7 @@ def example_usage():
from ..engines import Engine
from tqdm import tqdm

device = "cpu"
device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=2)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"):
@@ -272,6 +272,7 @@ class Base(nn.Module):
"""

batch_size = len(text_list)

x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
@@ -281,6 +282,7 @@ class Base(nn.Module):

x, m = list_to_tensor(x_list)

if self.arch_type == "transformer":
x = self.sin_emb.add_pe(x)
for block in self.blocks:
@@ -19,7 +19,7 @@ from collections import defaultdict

from tqdm import tqdm

mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cuda")
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")

def center_crop(x, len):
start = (x.shape[-1] - len) // 2
@@ -89,10 +89,10 @@ def run_eval(engines, eval_name, dl):
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
ref_audio = ref_audio[..., 0:min_length]
hyp_audio = hyp_audio[..., 0:min_length]

try:
stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
except Exception as e:
stats['loss'].append(0)
print(str(e))

for batch in tqdm(dl):
@@ -80,7 +80,7 @@ def load_engines():

if cfg.trainer.load_state_dict:
load_path = cfg.ckpt_dir / name / "fp32.pth"
model.load_state_dict(torch.load(load_path))
model.load_state_dict(torch.load(load_path)['module'])

engines[name] = Engine(
model=model,