From e5136613f55361ed91a863d048f42c1c72d20cbd Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 19 Jun 2024 10:08:14 -0500
Subject: [PATCH] semblance of documentation, automagic model downloading, a
 little saner inference results folder

---
 README.md                         |  33 +++++++-
 data/config.yaml                  | 133 ++++++++++++++++++++++++++++++
 data/{ => models}/mel_norms.pth   | Bin
 tortoise_tts/inference.py         |   5 +-
 tortoise_tts/models/__init__.py   |  73 ++++++++++++++--
 tortoise_tts/models/arch_utils.py |   6 +-
 tortoise_tts/tokenizer.py         |   3 +
 7 files changed, 238 insertions(+), 15 deletions(-)
 create mode 100644 data/config.yaml
 rename data/{ => models}/mel_norms.pth (100%)

diff --git a/README.md b/README.md
index df6fbbc..8a5828d 100644
--- a/README.md
+++ b/README.md
@@ -2,23 +2,49 @@
 
 An unofficial PyTorch re-implementation of [TorToise TTS](https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af).
 
+Almost all of the documentation and usage are carried over from my [VALL-E](https://github.com/e-c-k-e-r/vall-e) implementation, as documentation for this implementation is still lacking; I whipped it up over the course of two days using knowledge I hadn't touched in a year.
+
 ## Requirements
 
 A working PyTorch environment.
++ `python3 -m venv venv && source ./venv/bin/activate` is sufficient.
 
 ## Install
 
-Simply run `pip install git+https://git.ecker.tech/mrq/tortoise-tts` or `pip install git+https://github.com/e-c-k-e-r/tortoise-tts`.
+Simply run `pip install git+https://git.ecker.tech/mrq/tortoise-tts@new` or `pip install git+https://github.com/e-c-k-e-r/tortoise-tts`.
+
+## Usage
+
+### Inferencing
+
+Using the default settings: `python3 -m tortoise_tts --yaml="./data/config.yaml" "Read verse out loud for pleasure." "./path/to/a.wav"`
+
+To inference using the included Web UI: `python3 -m tortoise_tts.webui --yaml="./data/config.yaml"`
++ Pass `--listen 0.0.0.0:7860` if you're accessing the web UI from outside of `localhost` (or pass the host machine's local IP instead).
+
+### Training / Finetuning
+
+Training is as simple as copying the reference YAML from `./data/config.yaml` to any training directory of your choice (for example: `./training/` or `./training/lora-finetune/`).
+
+A pre-processed dataset is required. Refer to [the VALL-E implementation](https://github.com/e-c-k-e-r/vall-e#leverage-your-own-dataset) for more details.
+
+To start the trainer, run `python3 -m tortoise_tts.train --yaml="./path/to/your/training/config.yaml"`.
++ Type `save` to save at any time, `quit` to save and quit, and `eval` to run evaluation / validation of the model.
+
+For training a LoRA, uncomment the `loras` block in your training YAML, as sketched below.
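+
+As a sketch, an uncommented `loras` block (the values mirror the commented defaults in `./data/config.yaml`; the name is illustrative):
+
+```yaml
+loras:
+- name: "lora-finetune" # hypothetical name; pick anything
+  rank: 128             # LoRA rank
+  alpha: 128            # LoRA alpha (scaling)
+  training: True
+  parametrize: True
+```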
+
 ## To-Do
 
 - [X] Reimplement original inferencing through TorToiSe (as done with `api.py`)
+  - [ ] Reimplement candidate selection with the CLVP
 - [X] Implement training support (without DLAS)
 - [X] Feature parity with the VALL-E training setup for preparing a dataset ahead of time
-- [ ] Automagic handling of the original weights into compatible weights
+- [ ] Automagic offloading to CPU for unused models (for training and inferencing)
+- [X] Automagic handling of the original weights into compatible weights
 - [ ] Extend the original inference routine with additional features:
-  - [x] non-float32 / mixed precision
+  - [ ] non-float32 / mixed precision for the entire stack
   - [x] BitsAndBytes support
+    - Provided Linears technically aren't used because GPT2 uses Conv1D instead...
   - [x] LoRAs
   - [x] Web UI
 - [ ] Feature parity with [ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning)
@@ -27,6 +53,7 @@ Simply run `pip install git+https://git.ecker.tech/mrq/tortoise-tts` or `pip ins
 - [ ] BigVGAN in place of the original vocoder
 - [ ] XFormers / flash_attention_2 for the autoregressive model
 - [ ] Some vector embedding store to find the "best" utterance to pick
+- [ ] Documentation
 
 ## Why?
 
diff --git a/data/config.yaml b/data/config.yaml
new file mode 100644
index 0000000..f9181ab
--- /dev/null
+++ b/data/config.yaml
@@ -0,0 +1,133 @@
+models:
+- name: "autoregressive"
+  training: True
+
+#loras:
+#- name : "lora-test"
+#  rank: 128
+#  alpha: 128
+#  training: True
+#  parametrize: True
+
+hyperparameters:
+  autotune: False
+  autotune_params:
+    start_profile_step: 1
+    end_profile_step: 50
+    num_tuning_micro_batch_sizes: 8
+
+  batch_size: 4
+  gradient_accumulation_steps: 2
+  gradient_clipping: 1.0
+  warmup_steps: 0
+
+  optimizer: AdamW
+  learning_rate: 1.0e-4
+  # optimizer: Prodigy
+  # learning_rate: 1.0
+  torch_optimizer: True
+
+  scheduler: "" # ScheduleFree
+  torch_scheduler: True
+
+evaluation:
+  batch_size: 4
+  frequency: 1000
+  size: 4
+
+  steps: 500
+  ar_temperature: 0.95
+  nar_temperature: 0.25
+  load_disabled_engines: True
+
+trainer:
+  #no_logger: True
+  ddp: False
+  check_for_oom: False
+  iterations: 1_000_000
+
+  save_tag: step
+  save_on_oom: True
+  save_on_quit: True
+  save_frequency: 500
+  export_on_save: True
+
+  keep_last_checkpoints: 8
+
+  aggressive_optimizations: False
+  load_disabled_engines: False
+  gradient_checkpointing: True
+
+  #load_state_dict: True
+  strict_loading: False
+  #load_tag: "9500"
+  #load_states: False
+  #restart_step_count: True
+
+  gc_mode: None # "global_step"
+
+  weight_dtype: bfloat16
+  amp: True
+
+  backend: deepspeed
+  deepspeed:
+    inferencing: False
+    zero_optimization_level: 0
+    use_compression_training: False
+
+    amp: False
+
+  load_webui: False
+
+inference:
+  backend: deepspeed
+  normalize: False
+
+  # some steps break under blanket (B)FP16 + AMP
+  weight_dtype: float32
+  amp: False
+
+optimizations:
+  injects: False
+  replace: True
+
+  linear: False
+  embedding: False
+  optimizers: True
+
+  bitsandbytes: True
+  dadaptation: False
+  bitnet: False
+  fp8: False
+
+dataset:
+  speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
+  speaker_group_getter: "lambda p: f'{p.parts[-3]}'"
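+  # illustrative example: for an utterance at ./training/data/LibriTTS/1001/0001.wav
+  # (a hypothetical layout), p.parts[-3] and p.parts[-2] are the grandparent and
+  # parent folder names, so the name getter yields "LibriTTS_1001" and the group
+  # getter yields "LibriTTS"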
+  speaker_languages:
+    ja: []
+
+  use_hdf5: True
+  use_metadata: True
+  hdf5_flag: r
+  validate: True
+
+  workers: 6
+  cache: True
+
+  duration_range: [2.0, 3.0]
+
+  random_utterance: 1.0
+  max_prompts: 1
+  prompt_duration_range: [3.0, 3.0]
+
+  max_resps: 1
+  p_resp_append: 0.25
+
+  sample_type: path # path | speaker | group
+  sample_order: duration # duration | shuffle
+
+  tasks_list: [ "tts" ]
+
+  training: []
+  validation: []
+  noise: []
diff --git a/data/mel_norms.pth b/data/models/mel_norms.pth
similarity index 100%
rename from data/mel_norms.pth
rename to data/models/mel_norms.pth
diff --git a/tortoise_tts/inference.py b/tortoise_tts/inference.py
index 99c0be8..11cef83 100755
--- a/tortoise_tts/inference.py
+++ b/tortoise_tts/inference.py
@@ -154,7 +154,10 @@ class TTS():
 
 		for line in lines:
 			if out_path is None:
-				out_path = f"./data/{cfg.start_time}.wav"
+				output_dir = Path("./data/results/")
+				if not output_dir.exists():
+					output_dir.mkdir(parents=True, exist_ok=True)
+				out_path = output_dir / f"{cfg.start_time}.wav"
 
 			text = self.encode_text( line ).to(device=cfg.device)
diff --git a/tortoise_tts/models/__init__.py b/tortoise_tts/models/__init__.py
index d77014f..a061906 100755
--- a/tortoise_tts/models/__init__.py
+++ b/tortoise_tts/models/__init__.py
@@ -14,8 +14,61 @@ from .random_latent_generator import RandomLatentConverter
 
 import os
 import torch
+from pathlib import Path
 
-DEFAULT_MODEL_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../data/')
+DEFAULT_MODEL_PATH = Path(__file__).parent.parent.parent / 'data/models'
+DEFAULT_MODEL_URLS = {
+	'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
+	'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',
+	'clvp2.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth',
+	'cvvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth',
+	'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth',
+	'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth',
+	'dvae.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth',
+	'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
+	'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
+	'mel_norms.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/data/mel_norms.pth',
+}
+
+import requests
+from tqdm import tqdm
+
+# kludge, probably better to use HF's model downloader function
+# to-do: write to a temp file then copy so downloads can be interrupted
+def download_model( save_path, chunk_size = 1024, unit = "MiB" ):
+	# map the unit label to its byte scale for the progress bar
+	scale = 1
+	if unit == "KiB":
+		scale = (1024)
+	elif unit == "MiB":
+		scale = (1024 * 1024)
+	elif unit == "GiB":
+		scale = (1024 * 1024 * 1024)
+	elif unit == "KB":
+		scale = (1000)
+	elif unit == "MB":
+		scale = (1000 * 1000)
+	elif unit == "GB":
+		scale = (1000 * 1000 * 1000)
+
+	name = save_path.name
+	url = DEFAULT_MODEL_URLS[name] if name in DEFAULT_MODEL_URLS else None
+	if url is None:
+		raise Exception(f'Model requested for download but not defined: {name}')
+
+	if not save_path.parent.exists():
+		save_path.parent.mkdir(parents=True, exist_ok=True)
+
+	r = requests.get(url, stream=True)
+	# requests' headers are case-insensitive; default to 0 if the server omits Content-Length
+	content_length = int(r.headers.get('Content-Length', 0)) // scale
+
+	with open(save_path, 'wb') as f:
+		bar = tqdm( unit=unit, total=content_length )
+		for chunk in r.iter_content(chunk_size=chunk_size):
+			if not chunk:
+				continue
+
+			bar.update( len(chunk) / scale )
+			f.write(chunk)
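+
+# a minimal sketch of the "HF downloader" route noted above (not wired in here;
+# the repo id and filename layout are assumed from DEFAULT_MODEL_URLS):
+#   from huggingface_hub import hf_hub_download
+#   path = hf_hub_download(repo_id="jbetker/tortoise-tts-v2", filename=".models/autoregressive.pth")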
 
 # semi-necessary as a way to provide a mechanism for other portions of the program to access models
 @cache
@@ -27,26 +80,26 @@ def load_model(name, device="cuda", **kwargs):
 	if "rlg" in name:
 		if "autoregressive" in name:
 			model = RandomLatentConverter(1024, **kwargs)
-			load_path = f'{DEFAULT_MODEL_PATH}/rlg_auto.pth'
+			load_path = DEFAULT_MODEL_PATH / 'rlg_auto.pth'
 		if "diffusion" in name:
 			model = RandomLatentConverter(2048, **kwargs)
-			load_path = f'{DEFAULT_MODEL_PATH}/rlg_diffuser.pth'
+			load_path = DEFAULT_MODEL_PATH / 'rlg_diffuser.pth'
 	elif "autoregressive" in name or "unified_voice" in name:
 		strict = False
 		model = UnifiedVoice(**kwargs)
-		load_path = f'{DEFAULT_MODEL_PATH}/autoregressive.pth'
+		load_path = DEFAULT_MODEL_PATH / 'autoregressive.pth'
 	elif "diffusion" in name:
 		model = DiffusionTTS(**kwargs)
-		load_path = f'{DEFAULT_MODEL_PATH}/diffusion.pth'
+		load_path = DEFAULT_MODEL_PATH / 'diffusion.pth'
 	elif "clvp" in name:
 		model = CLVP(**kwargs)
-		load_path = f'{DEFAULT_MODEL_PATH}/clvp2.pth'
+		load_path = DEFAULT_MODEL_PATH / 'clvp2.pth'
 	elif "vocoder" in name:
 		model = UnivNetGenerator(**kwargs)
-		load_path = f'{DEFAULT_MODEL_PATH}/vocoder.pth'
+		load_path = DEFAULT_MODEL_PATH / 'vocoder.pth'
 		state_dict_key = 'model_g'
 	elif "dvae" in name:
-		load_path = f'{DEFAULT_MODEL_PATH}/dvae.pth'
+		load_path = DEFAULT_MODEL_PATH / 'dvae.pth'
 		model = DiscreteVAE(**kwargs)
 	# to-do: figure out if the below two give the exact same output
 	elif "stft" in name:
@@ -61,6 +114,10 @@
 	model = model.to(device=device)
 
 	if load_path is not None:
+		# download the weights if they don't already exist
+		if not load_path.exists():
+			download_model( load_path )
+
 		state_dict = torch.load(load_path, map_location=device)
 		if state_dict_key:
 			state_dict = state_dict[state_dict_key]
diff --git a/tortoise_tts/models/arch_utils.py b/tortoise_tts/models/arch_utils.py
index 832aadb..231cb51 100644
--- a/tortoise_tts/models/arch_utils.py
+++ b/tortoise_tts/models/arch_utils.py
@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio
+from pathlib import Path
 
 from .xtransformers import ContinuousTransformerWrapper, RelativePositionBias
 
 def zero_module(module):
@@ -289,8 +290,7 @@ class AudioMiniEncoder(nn.Module):
 		return h[:, :, 0]
 
 
-DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../data/mel_norms.pth')
-
+DEFAULT_MEL_NORM_FILE = Path(__file__).parent.parent.parent / 'data/models/mel_norms.pth'
 
 class TorchMelSpectrogram(nn.Module):
 	def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
@@ -310,7 +310,7 @@ class TorchMelSpectrogram(nn.Module):
 													 f_max=self.mel_fmax, n_mels=self.n_mel_channels, norm="slaney")
 		self.mel_norm_file = mel_norm_file
-		if self.mel_norm_file is not None:
+		if self.mel_norm_file is not None and self.mel_norm_file.exists():
 			self.mel_norms = torch.load(self.mel_norm_file)
 		else:
 			self.mel_norms = None
diff --git a/tortoise_tts/tokenizer.py b/tortoise_tts/tokenizer.py
index 59c3645..6b69aae 100644
--- a/tortoise_tts/tokenizer.py
+++ b/tortoise_tts/tokenizer.py
@@ -1,3 +1,6 @@
+# to-do: make use of the tokenizer's configurable preprocessors
+# it *might* be required to keep all of this to maintain tokenizer compatibility
+
 import os
 import re