set seed on inference, since it seems to be set to 0 every time

This commit is contained in:
mrq 2024-06-19 22:10:59 -05:00
parent 0b1a71430c
commit e2c9b0465f
6 changed files with 49 additions and 6 deletions

View File

@ -33,6 +33,17 @@ To start the trainer, run `python3 -m tortoise_tts.train --yaml="./path/to/your/
For training a LoRA, uncomment the `loras` block in your training YAML. For training a LoRA, uncomment the `loras` block in your training YAML.
For loading an existing finetuned model, create a folder with this structure, and load its accompanying YAML:
```
./some/arbitrary/path/:
ckpt:
autoregressive:
fp32.pth # finetuned weights
config.yaml
```
For LoRAs, replace the above `fp32.pth` with `lora.pth`.
## To-Do ## To-Do
- [X] Reimplement original inferencing through TorToiSe (as done with `api.py`) - [X] Reimplement original inferencing through TorToiSe (as done with `api.py`)
@ -54,12 +65,14 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
- [x] Web UI - [x] Web UI
- [ ] Feature parity with [ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning) - [ ] Feature parity with [ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning)
- Although I feel a lot of its features are the wrong way to go about it. - Although I feel a lot of its features are the wrong way to go about it.
- [ ] Additional samplers for the autoregressive model - [ ] Additional samplers for the autoregressive model (such as mirostat / dynamic temperature)
- [ ] Additional samplers for the diffusion model - [ ] Additional samplers for the diffusion model (beyond the already included DDIM)
- [ ] BigVGAN in place of the original vocoder - [X] BigVGAN in place of the original vocoder
- [X] HiFiGAN integration as well
- [ ] XFormers / flash_attention_2 for the autoregressive model - [ ] XFormers / flash_attention_2 for the autoregressive model
- Beyond HF's internal implementation of handling alternative attention - Beyond HF's internal implementation of handling alternative attention
- Both the AR and diffusion models also do their own attention... - Both the AR and diffusion models also do their own attention...
- [ ] Saner way of loading finetuned models / LoRAs
- [ ] Some vector embedding store to find the "best" utterance to pick - [ ] Some vector embedding store to find the "best" utterance to pick
- [ ] Documentation - [ ] Documentation

View File

@ -26,6 +26,8 @@ def main():
parser.add_argument("--cond-free", action="store_true") parser.add_argument("--cond-free", action="store_true")
parser.add_argument("--vocoder", type=str, default="bigvgan") parser.add_argument("--vocoder", type=str, default="bigvgan")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--device", type=str, default=None) parser.add_argument("--device", type=str, default=None)
parser.add_argument("--amp", action="store_true") parser.add_argument("--amp", action="store_true")
@ -65,6 +67,8 @@ def main():
cond_free=args.cond_free, cond_free=args.cond_free,
vocoder_type=args.vocoder, vocoder_type=args.vocoder,
seed=args.seed,
) )
""" """
language=args.language, language=args.language,

View File

@ -8,9 +8,11 @@ import sys
import time import time
import argparse import argparse
import yaml import yaml
import random
import torch import torch
import numpy as np
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from functools import cached_property from functools import cached_property
@ -22,6 +24,14 @@ from .tokenizer import VoiceBpeTokenizer
# Yuck # Yuck
from transformers import PreTrainedTokenizerFast from transformers import PreTrainedTokenizerFast
def set_seed(seed=None):
if not seed:
seed = time.time()
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
DEFAULT_YAML = Path(__file__).parent.parent / 'data/config.yaml' DEFAULT_YAML = Path(__file__).parent.parent / 'data/config.yaml'
@dataclass() @dataclass()

View File

@ -1,6 +1,7 @@
import torch import torch
import torchaudio import torchaudio
import soundfile import soundfile
import time
from torch import Tensor from torch import Tensor
from einops import rearrange from einops import rearrange
@ -8,8 +9,7 @@ from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
from .emb.mel import encode_from_files as encode_mel, trim, trim_random from .emb.mel import encode_from_files as encode_mel, trim, trim_random
from .utils import to_device from .utils import to_device, set_seed, wrapper as ml
from .utils import wrapper as ml
from .config import cfg, DEFAULT_YAML from .config import cfg, DEFAULT_YAML
from .models import get_models, load_model from .models import get_models, load_model
@ -140,6 +140,8 @@ class TTS():
vocoder_type="bigvgan", vocoder_type="bigvgan",
seed=None,
out_path=None, out_path=None,
): ):
lines = text.split("\n") lines = text.split("\n")
@ -189,12 +191,14 @@ class TTS():
candidates = 1 candidates = 1
set_seed(seed)
for line in lines: for line in lines:
if out_path is None: if out_path is None:
output_dir = Path("./data/results/") output_dir = Path("./data/results/")
if not output_dir.exists(): if not output_dir.exists():
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
out_path = output_dir / f"{cfg.start_time}.wav" out_path = output_dir / f"{time.time()}.wav"
text = self.encode_text( line ).to(device=cfg.device) text = self.encode_text( line ).to(device=cfg.device)

View File

@ -7,4 +7,5 @@ from .utils import (
to_device, to_device,
tree_map, tree_map,
do_gc, do_gc,
set_seed,
) )

View File

@ -7,8 +7,11 @@ from .distributed import global_rank, local_rank, global_leader_only
import gc import gc
import logging import logging
import pandas as pd import pandas as pd
import numpy as np
import re import re
import torch import torch
import random
import time
from coloredlogs import ColoredFormatter from coloredlogs import ColoredFormatter
from logging import StreamHandler from logging import StreamHandler
@ -35,6 +38,14 @@ def flatten_dict(d):
return records[0] if records else {} return records[0] if records else {}
def set_seed(seed=None):
if not seed:
seed = int(time.time())
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def _get_named_modules(module, attrname): def _get_named_modules(module, attrname):
for name, module in module.named_modules(): for name, module in module.named_modules():
if hasattr(module, attrname): if hasattr(module, attrname):