From e2c9b0465f9c0fb6550915959a863bcb9b54f982 Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 19 Jun 2024 22:10:59 -0500 Subject: [PATCH] set seed on inference, since it seems to be set to 0 every time --- README.md | 19 ++++++++++++++++--- tortoise_tts/__main__.py | 4 ++++ tortoise_tts/config.py | 10 ++++++++++ tortoise_tts/inference.py | 10 +++++++--- tortoise_tts/utils/__init__.py | 1 + tortoise_tts/utils/utils.py | 11 +++++++++++ 6 files changed, 49 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 38564e3..24ea1c6 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,17 @@ To start the trainer, run `python3 -m tortoise_tts.train --yaml="./path/to/your/ For training a LoRA, uncomment the `loras` block in your training YAML. +For loading an existing finetuned model, create a folder with this structure, and load its accompanying YAML: +``` +./some/arbitrary/path/: + ckpt: + autoregressive: + fp32.pth # finetuned weights + config.yaml +``` + +For LoRAs, replace the above `fp32.pth` with `lora.pth`. + ## To-Do - [X] Reimplement original inferencing through TorToiSe (as done with `api.py`) @@ -54,12 +65,14 @@ For training a LoRA, uncomment the `loras` block in your training YAML. - [x] Web UI - [ ] Feature parity with [ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning) - Although I feel a lot of its features are the wrong way to go about it. 
- - [ ] Additional samplers for the autoregressive model - - [ ] Additional samplers for the diffusion model - - [ ] BigVGAN in place of the original vocoder + - [ ] Additional samplers for the autoregressive model (such as mirostat / dynamic temperature) + - [ ] Additional samplers for the diffusion model (beyond the already included DDIM) + - [X] BigVGAN in place of the original vocoder + - [X] HiFiGAN integration as well - [ ] XFormers / flash_attention_2 for the autoregressive model - Beyond HF's internal implementation of handling alternative attention - Both the AR and diffusion models also do their own attention... + - [ ] Saner way of loading finetuned models / LoRAs - [ ] Some vector embedding store to find the "best" utterance to pick - [ ] Documentation diff --git a/tortoise_tts/__main__.py b/tortoise_tts/__main__.py index c60dda6..481fc5c 100755 --- a/tortoise_tts/__main__.py +++ b/tortoise_tts/__main__.py @@ -25,6 +25,8 @@ def main(): parser.add_argument("--diffusion-sampler", type=str, default="ddim") parser.add_argument("--cond-free", action="store_true") parser.add_argument("--vocoder", type=str, default="bigvgan") + + parser.add_argument("--seed", type=int, default=None) parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--device", type=str, default=None) @@ -65,6 +67,8 @@ def main(): cond_free=args.cond_free, vocoder_type=args.vocoder, + + seed=args.seed, ) """ language=args.language, diff --git a/tortoise_tts/config.py b/tortoise_tts/config.py index 7cbaa84..18f10ca 100755 --- a/tortoise_tts/config.py +++ b/tortoise_tts/config.py @@ -8,9 +8,11 @@ import sys import time import argparse import yaml +import random import torch +import numpy as np from dataclasses import asdict, dataclass, field from functools import cached_property @@ -22,6 +24,14 @@ from .tokenizer import VoiceBpeTokenizer # Yuck from transformers import PreTrainedTokenizerFast +def set_seed(seed=None): + if not seed: + seed = int(time.time()) + + 
random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + DEFAULT_YAML = Path(__file__).parent.parent / 'data/config.yaml' @dataclass() diff --git a/tortoise_tts/inference.py b/tortoise_tts/inference.py index b739321..eb4dfd2 100755 --- a/tortoise_tts/inference.py +++ b/tortoise_tts/inference.py @@ -1,6 +1,7 @@ import torch import torchaudio import soundfile +import time from torch import Tensor from einops import rearrange @@ -8,8 +9,7 @@ from pathlib import Path from tqdm import tqdm from .emb.mel import encode_from_files as encode_mel, trim, trim_random -from .utils import to_device -from .utils import wrapper as ml +from .utils import to_device, set_seed, wrapper as ml from .config import cfg, DEFAULT_YAML from .models import get_models, load_model @@ -140,6 +140,8 @@ class TTS(): vocoder_type="bigvgan", + seed=None, + out_path=None, ): lines = text.split("\n") @@ -189,12 +191,14 @@ class TTS(): candidates = 1 + set_seed(seed) + for line in lines: if out_path is None: output_dir = Path("./data/results/") if not output_dir.exists(): output_dir.mkdir(parents=True, exist_ok=True) - out_path = output_dir / f"{cfg.start_time}.wav" + out_path = output_dir / f"{time.time()}.wav" text = self.encode_text( line ).to(device=cfg.device) diff --git a/tortoise_tts/utils/__init__.py b/tortoise_tts/utils/__init__.py index 96929f3..b2f2ef9 100755 --- a/tortoise_tts/utils/__init__.py +++ b/tortoise_tts/utils/__init__.py @@ -7,4 +7,5 @@ from .utils import ( to_device, tree_map, do_gc, + set_seed, ) \ No newline at end of file diff --git a/tortoise_tts/utils/utils.py b/tortoise_tts/utils/utils.py index e92239a..ee37edc 100755 --- a/tortoise_tts/utils/utils.py +++ b/tortoise_tts/utils/utils.py @@ -7,8 +7,11 @@ from .distributed import global_rank, local_rank, global_leader_only import gc import logging import pandas as pd +import numpy as np import re import torch +import random +import time from coloredlogs import ColoredFormatter from logging import StreamHandler @@ 
-35,6 +38,14 @@ def flatten_dict(d): return records[0] if records else {} +def set_seed(seed=None): + if not seed: + seed = int(time.time()) + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + def _get_named_modules(module, attrname): for name, module in module.named_modules(): if hasattr(module, attrname):