diff --git a/README.md b/README.md
index 6a8c509..6ccec8d 100755
--- a/README.md
+++ b/README.md
@@ -8,8 +8,6 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
 > **Note** this is highly experimental. While I've seem to have audited and tighened down as much as I can, I'm still trying to produce a decent model out of it. You're free to train your own model if you happen to have the massive compute for it, but it's quite the beast to properly feed. This README won't get much love until I truly nail out a quasi-decent model.
 
-> **Note** Distributed training seems broken? I'm not really sure how to test it, as my two 6800XTs have been redistributed for now, and the last time I tried using them for this, things weren't good.
-
 > **Note** You can follow along with my pseudo-blog in an issue [here](https://git.ecker.tech/mrq/ai-voice-cloning/issues/152). I currently have a dataset clocking in at 3400+ trimmed hours.
 
 ### Requirements
 
@@ -49,6 +47,8 @@ Training is very dependent on:
 
 #### Leverage Your Own
 
+> **Note** It is highly recommended to utilize [mrq/ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning) with `--tts-backend="vall-e"` to handle transcription and dataset preparation.
+
 1. Put your data into a folder, e.g. `./data/custom`. Audio files should be named with the suffix `.wav` and text files with `.txt`.
 
 2. Quantize the data:
@@ -66,6 +66,8 @@ python -m vall_e.emb.g2p ./data/custom
 
 4. Customize your configuration and define the dataset by modifying `./data/config.yml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
 
+> **Note** Be sure to set `distributed: True` to ensure the `DistributedSampler` is used. In the future, I'll have it automagically detect this.
+
 If you're interested in creating an HDF5 copy of your dataset, simply invoke:
 
 ```
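
The note above leaves `distributed` as a manual flag. For the auto-detection it promises, a minimal sketch is to key off the environment variables that `torchrun` and DeepSpeed launchers export; the helper name `detect_distributed` below is hypothetical, not part of this patch:

```python
import os

def detect_distributed() -> bool:
    """Hypothetical helper: infer distributed mode from launcher env vars.

    torchrun / torch.distributed.launch / DeepSpeed all export WORLD_SIZE;
    a value greater than 1 means multiple processes are participating.
    """
    return int(os.environ.get("WORLD_SIZE", "1")) > 1

# e.g. during config load: cfg.distributed = cfg.distributed or detect_distributed()
```
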
diff --git a/data/config.yaml b/data/config.yaml
index 93e14be..a483998 100755
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -1,10 +1,8 @@
 dataset:
   training: [
-
   ]
 
   validation: [
-
   ]
 
   speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
@@ -15,30 +13,30 @@ dataset:
   workers: 8
   cache: True
 
-  phones_range: [4, 512]
-  duration_range: [1.0, 24.0]
+  phones_range: [4, 128]
+  duration_range: [1.0, 12.0]
 
   random_utterance: 1.0
-  max_prompts: 6
-  prompt_duration: 6.0
+  max_prompts: 3
+  prompt_duration: 3.0
 
 models:
   _models:
   - name: "ar"
-    size: "quarter"
+    size: "full"
     resp_levels: 1
     arch_type: "retnet"
 
   - name: "nar"
-    size: "quarter"
+    size: "full"
     resp_levels: 1
     arch_type: "retnet"
     prom_levels: 2
 
 hyperparameters:
 
-  batch_size: 32
-  gradient_accumulation_steps: 4
+  batch_size: 8
+  gradient_accumulation_steps: 16
   gradient_clipping: 100
 
   optimizer: Adamw
@@ -90,15 +88,18 @@ trainer:
 
   gc_mode: None # "global_step"
 
-  weight_dtype: bfloat16
+  weight_dtype: bfloat16 # float16, float32
 
   backend: deepspeed
   deepspeed:
-    zero_optimization_level: 0
+    zero_optimization_level: 2
     use_compression_training: True
 
 inference:
   use_vocos: True
 
 bitsandbytes:
-  enabled: false
\ No newline at end of file
+  enabled: false
+
+device: cuda
+distributed: False
\ No newline at end of file
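
The hyperparameter change above trades per-step batch size for accumulation: the optimizer still sees an effective batch of 128, but each forward/backward pass holds a quarter of the activations, which is what makes room for the quarter-to-full model size bump. A quick check of the arithmetic:

```python
# effective batch = per-step batch size × gradient accumulation steps
old_effective = 32 * 4   # quarter-size AR/NAR
new_effective = 8 * 16   # full-size AR/NAR
assert old_effective == new_effective == 128
```
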
diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index 113cd2d..91673a1 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -17,7 +17,7 @@ def main():
     args = parser.parse_args()
 
     tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
-    tts.inference( text=args.text, reference=args.reference, out_path=args.out_path, max_ar_samples=args.max_ar_samples, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
+    tts.inference( text=args.text, reference=args.reference, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
 
 if __name__ == "__main__":
     main()
diff --git a/vall_e/config.py b/vall_e/config.py
index 9813164..442b5f3 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -7,6 +7,8 @@ import subprocess
 import sys
 import time
 
+import torch
+
 from dataclasses import asdict, dataclass
 from dataclasses import dataclass, field
 
@@ -379,7 +381,7 @@ class Trainer:
     def dtype(self):
         if self.weight_dtype == "float16":
             return torch.float16
-        if cfg.trainer.weight_dtype == "bfloat16":
+        if self.weight_dtype == "bfloat16":
             return torch.bfloat16
         return torch.float32
 
@@ -399,6 +401,7 @@ class BitsAndBytes:
 @dataclass()
 class Config(_Config):
     device: str = "cuda"
+    distributed: bool = False
 
     dataset: Dataset = field(default_factory=lambda: Dataset)
     models: Models = field(default_factory=lambda: Models)
@@ -433,21 +436,24 @@ class Config(_Config):
 
 cfg = Config.from_cli()
 
-# OmegaConf doesn't actually coerce the dicts into the @dataclass decorated classes, for some god forsaken reason, so we coerce them ourselves
-cfg.dataset = Dataset(**cfg.dataset)
-cfg.models = Models(**cfg.models)
-cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
-cfg.evaluation = Evaluation(**cfg.evaluation)
-cfg.trainer = Trainer(**cfg.trainer)
-cfg.inference = Inference(**cfg.inference)
-cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
+# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
+try:
+    cfg.dataset = Dataset(**cfg.dataset)
+    cfg.models = Models(**cfg.models)
+    cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
+    cfg.evaluation = Evaluation(**cfg.evaluation)
+    cfg.trainer = Trainer(**cfg.trainer)
+    cfg.inference = Inference(**cfg.inference)
+    cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
 
-cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
+    cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
+except Exception as e:
+    pass
 
 # cached_property stopped working...
 if cfg.dataset.use_hdf5:
     try:
-        cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
+        cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a')
     except Exception as e:
         print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
         cfg.dataset.use_hdf5 = False
@@ -457,4 +463,4 @@ if not cfg.dataset.use_hdf5:
     cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
 
 if __name__ == "__main__":
-    print(cfg)
\ No newline at end of file
+    print(cfg)
diff --git a/vall_e/data.py b/vall_e/data.py
index 53388f1..57a9c99 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -20,6 +20,7 @@ from typing import Any
 
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset as _Dataset
+from torch.utils.data.distributed import DistributedSampler
 from tqdm.auto import tqdm
 
 # torch.multiprocessing.set_sharing_strategy("file_system")
@@ -312,13 +313,14 @@ def _create_dataloader(dataset, training):
     return DataLoader(
         dataset=dataset,
         batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
-        shuffle=True, # training
+        shuffle=False if cfg.distributed else True, # training
         drop_last=training,
         num_workers=cfg.dataset.workers,
         collate_fn=collate_fn,
        persistent_workers=True,
         pin_memory=False, # True,
         worker_init_fn=_seed_worker,
+        sampler=DistributedSampler(dataset) if cfg.distributed else dataset.sampler
     )
 
 def _load_dataset_paths():
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 1900d89..dba7938 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -169,7 +169,8 @@ def encode_from_file(path, device="cuda"):
     if isinstance( path, list ):
         return encode_from_files( path, device )
     else:
-        wav, sr = torchaudio.load(str(path), format=path[-3:])
+        path = str(path)
+        wav, sr = torchaudio.load(path, format=path[-3:])
 
     if wav.shape[0] == 2:
         wav = wav[:1]
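
The `encode_from_file` change above works because `pathlib.Path` objects are not subscriptable, so `path[-3:]` raised a `TypeError` whenever a `Path` was passed in; casting to `str` first avoids that. Slicing the last three characters still assumes a three-letter extension, though. A sketch of a more general variant (not what the patch does) using `Path.suffix`:

```python
from pathlib import Path

import torchaudio

def load_any_audio(path):
    path = Path(path)
    # Path.suffix copes with extensions of any length (".wav", ".flac", ".ogg")
    return torchaudio.load(str(path), format=path.suffix.lstrip(".") or None)
```
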
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 8b9dc04..fdf4fae 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -275,7 +275,7 @@ class Engines(dict[str, Engine]):
             stats.update(flatten_dict({ name.split("-")[0]: stat }))
         return stats
 
-    def step(self, batch, feeder: TrainFeeder = default_feeder, device=torch.cuda.current_device()):
+    def step(self, batch, feeder: TrainFeeder = default_feeder):
         total_elapsed_time = 0
 
         stats: Any = dict()
@@ -283,10 +283,9 @@ class Engines(dict[str, Engine]):
         if cfg.trainer.gc_mode == 'step':
             do_gc()
 
-        batch = to_device(batch, device)
 
         for name, engine in self.items():
-            #torch.cuda.synchronize()
+            device = engine.device
 
             if cfg.trainer.gc_mode == 'substep':
                 do_gc()
@@ -294,10 +293,9 @@ class Engines(dict[str, Engine]):
             start_time = time.time()
 
             tries = 4
-            n_ooms = torch.zeros([], device=cfg.device)
+            n_ooms = torch.zeros([], device=device)
 
-            if cfg.trainer.aggressive_optimizations:
-                batch = to_device(batch, device)
+            batch = to_device(batch, device)
 
             if not cfg.trainer.check_for_oom:
                 res = feeder( engine=engine, batch=batch )
@@ -336,7 +334,7 @@ class Engines(dict[str, Engine]):
             loss, engine_stats = res
             engine_stats |= self.gather_attribute("scalar")
 
-            n_ooms = torch.zeros([], device=cfg.device)
+            n_ooms = torch.zeros([], device=device)
 
             if cfg.trainer.aggressive_optimizations:
                 batch = to_device(batch, 'cpu')
diff --git a/vall_e/export.py b/vall_e/export.py
index f46ab15..093f143 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -27,7 +27,9 @@ def main():
         model = models[name]
 
         outpath = f'{args.path}/{name}.pt'
-        torch.save(model, outpath)
+        torch.save({
+            'module': model.state_dict()
+        }, outpath)
         print(f"Exported {name} to {outpath}")
 
 if __name__ == "__main__":
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 691fea0..e6c8193 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -9,6 +9,8 @@ from .utils import to_device
 
 from .config import cfg
 from .export import load_models
+from .models import get_models
+from .data import get_phone_symmap
 
 class TTS():
     def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
@@ -19,11 +21,21 @@ class TTS():
         self.output_sample_rate = 24000
 
         if ar_ckpt and nar_ckpt:
-            self.load_ar( ar_ckpt )
-            self.load_nar( nar_ckpt )
+            self.ar_ckpt = ar_ckpt
+            self.nar_ckpt = nar_ckpt
+
+            models = get_models(cfg.models.get())
+            for name, model in models.items():
+                if name.startswith("ar"):
+                    self.ar = model.to(self.device, dtype=torch.float32)
+                    self.ar.load_state_dict(torch.load(self.ar_ckpt)['module'])
+                elif name.startswith("nar"):
+                    self.nar = model.to(self.device, dtype=torch.float32)
+                    self.nar.load_state_dict(torch.load(self.nar_ckpt)['module'])
         else:
             self.load_models( config )
 
+        self.symmap = get_phone_symmap()
         self.ar.eval()
         self.nar.eval()
 
@@ -39,24 +51,13 @@ class TTS():
         for name in models:
             model = models[name]
             if name[:2] == "ar":
-                self.ar = model.to(self.device)
+                self.ar = model.to(self.device, dtype=torch.float32)
                 self.symmap = self.ar.phone_symmap
             elif name[:3] == "nar":
-                self.nar = model.to(self.device)
+                self.nar = model.to(self.device, dtype=torch.float32)
             else:
                 print("Unknown:", name)
 
-    def load_ar( self, ckpt ):
-        self.ar_ckpt = ckpt
-
-        self.ar = torch.load(self.ar_ckpt).to(self.device)
-        self.symmap = self.ar.phone_symmap
-
-    def load_nar( self, ckpt ):
-        self.nar_ckpt = nar_ckpt
-
-        self.nar = torch.load(self.nar_ckpt).to(self.device)
-
     def encode_text( self, text, lang_marker="en" ):
         text = g2p.encode(text)
         phones = [f"<{lang_marker}>"] + [ " " if not p else p for p in text ] + [f"</{lang_marker}>"]
@@ -70,7 +71,7 @@ class TTS():
     def inference( self, text, reference, mode="both", max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path="./.tmp.wav" ):
         prom = self.encode_audio( reference )
-        phns = self.encode_text(text)
+        phns = self.encode_text( text )
 
         prom = to_device(prom, self.device).to(torch.int16)
         phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
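
The export and inference changes above converge on one checkpoint layout: a plain dict with the weights under a `'module'` key, mirroring what DeepSpeed writes. A minimal round-trip sketch, with `nn.Linear` standing in for the actual AR/NAR models:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for an AR/NAR model

# export side (vall_e/export.py): wrap the state dict under 'module'
torch.save({'module': model.state_dict()}, 'ar.pt')

# load side (vall_e/inference.py, vall_e/utils/trainer.py): unwrap the same key
model.load_state_dict(torch.load('ar.pt')['module'])
```
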
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 4c5e293..7ef8158 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -114,7 +114,7 @@ def example_usage():
     from ..engines import Engine
     from tqdm import tqdm
 
-    device = "cpu"
+    device = "cuda"
     x8 = partial(repeat, pattern="t -> t l", l=2)
     symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
     def tokenize(content, lang_marker="en"):
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index aef2255..71fd3c2 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -272,6 +272,7 @@ class Base(nn.Module):
         """
 
         batch_size = len(text_list)
+
         x_list = self._samplewise_merge_tensors(
             self.text_emb(text_list),
             self.proms_emb(proms_list),
@@ -281,6 +282,7 @@ class Base(nn.Module):
 
         x, m = list_to_tensor(x_list)
 
+
         if self.arch_type == "transformer":
             x = self.sin_emb.add_pe(x)
             for block in self.blocks:
diff --git a/vall_e/train.py b/vall_e/train.py
index 51ab2cc..b5e4499 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -19,7 +19,7 @@ from collections import defaultdict
 
 from tqdm import tqdm
 
-mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cuda")
+mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
 
 def center_crop(x, len):
     start = (x.shape[-1] - len) // 2
@@ -89,10 +89,10 @@ def run_eval(engines, eval_name, dl):
             min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
             ref_audio = ref_audio[..., 0:min_length]
             hyp_audio = hyp_audio[..., 0:min_length]
-
             try:
                 stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
             except Exception as e:
+                stats['loss'].append(0)
                 print(str(e))
 
     for batch in tqdm(dl):
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 89de136..1174952 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -80,7 +80,7 @@ def load_engines():
 
         if cfg.trainer.load_state_dict:
             load_path = cfg.ckpt_dir / name / "fp32.pth"
-            model.load_state_dict(torch.load(load_path))
+            model.load_state_dict(torch.load(load_path)['module'])
 
         engines[name] = Engine(
             model=model,
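
One side effect of the `'module'`-keyed layout: checkpoints exported before this change were whole pickled models, so the new `torch.load(load_path)['module']` lookup will fail on them. If backwards compatibility matters, a tolerant loader could look like the sketch below; `load_module_state` is a hypothetical helper, not part of this patch:

```python
import torch

def load_module_state(model, path):
    """Load either the new {'module': state_dict} layout or a legacy export."""
    ckpt = torch.load(path, map_location='cpu')
    if isinstance(ckpt, dict) and 'module' in ckpt:
        model.load_state_dict(ckpt['module'])      # new layout
    elif isinstance(ckpt, dict):
        model.load_state_dict(ckpt)                # bare state dict
    else:
        model.load_state_dict(ckpt.state_dict())   # legacy whole-model pickle
    return model
```
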