set seed on inference, since it seems to be set to 0 every time

This commit is contained in:
mrq 2024-06-19 22:10:59 -05:00
parent 0b1a71430c
commit e2c9b0465f
6 changed files with 49 additions and 6 deletions

View File

@ -33,6 +33,17 @@ To start the trainer, run `python3 -m tortoise_tts.train --yaml="./path/to/your/
For training a LoRA, uncomment the `loras` block in your training YAML.
For loading an existing finetuned model, create a folder with this structure, and load its accompanying YAML:
```
./some/arbitrary/path/:
ckpt:
autoregressive:
fp32.pth # finetuned weights
config.yaml
```
For LoRAs, replace the above `fp32.pth` with `lora.pth`.
## To-Do
- [X] Reimplement original inferencing through TorToiSe (as done with `api.py`)
@ -54,12 +65,14 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
- [x] Web UI
- [ ] Feature parity with [ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning)
- Although I feel a lot of its features are the wrong way to go about it.
- [ ] Additional samplers for the autoregressive model
- [ ] Additional samplers for the diffusion model
- [ ] BigVGAN in place of the original vocoder
- [ ] Additional samplers for the autoregressive model (such as mirostat / dynamic temperature)
- [ ] Additional samplers for the diffusion model (beyond the already included DDIM)
- [X] BigVGAN in place of the original vocoder
- [X] HiFiGAN integration as well
- [ ] XFormers / flash_attention_2 for the autoregressive model
- Beyond HF's internal implementation of handling alternative attention
- Both the AR and diffusion models also do their own attention...
- [ ] Saner way of loading finetuned models / LoRAs
- [ ] Some vector embedding store to find the "best" utterance to pick
- [ ] Documentation

View File

@ -26,6 +26,8 @@ def main():
parser.add_argument("--cond-free", action="store_true")
parser.add_argument("--vocoder", type=str, default="bigvgan")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--amp", action="store_true")
@ -65,6 +67,8 @@ def main():
cond_free=args.cond_free,
vocoder_type=args.vocoder,
seed=args.seed,
)
"""
language=args.language,

View File

@ -8,9 +8,11 @@ import sys
import time
import argparse
import yaml
import random
import torch
import numpy as np
from dataclasses import asdict, dataclass, field
from functools import cached_property
@ -22,6 +24,14 @@ from .tokenizer import VoiceBpeTokenizer
# Yuck
from transformers import PreTrainedTokenizerFast
def set_seed(seed=None):
    """Seed the `random`, NumPy and PyTorch global RNGs with one value.

    Args:
        seed: Integer seed. A falsy value (``None`` or ``0``) means "pick one
            for me": the current Unix time, truncated to whole seconds, is
            used instead.

    Returns:
        The integer seed that was applied, so callers can log it and
        reproduce a run later.
    """
    if not seed:
        # time.time() is a float; np.random.seed and torch.manual_seed only
        # accept integers, so truncate to whole seconds.
        seed = int(time.time())
    seed = int(seed)

    random.seed(seed)
    # NumPy's legacy global seeding only accepts values in [0, 2**32 - 1];
    # mask so negative or very large seeds do not raise ValueError.
    np.random.seed(seed & 0xFFFFFFFF)
    torch.manual_seed(seed)
    return seed
DEFAULT_YAML = Path(__file__).parent.parent / 'data/config.yaml'
@dataclass()

View File

@ -1,6 +1,7 @@
import torch
import torchaudio
import soundfile
import time
from torch import Tensor
from einops import rearrange
@ -8,8 +9,7 @@ from pathlib import Path
from tqdm import tqdm
from .emb.mel import encode_from_files as encode_mel, trim, trim_random
from .utils import to_device
from .utils import wrapper as ml
from .utils import to_device, set_seed, wrapper as ml
from .config import cfg, DEFAULT_YAML
from .models import get_models, load_model
@ -140,6 +140,8 @@ class TTS():
vocoder_type="bigvgan",
seed=None,
out_path=None,
):
lines = text.split("\n")
@ -189,12 +191,14 @@ class TTS():
candidates = 1
set_seed(seed)
for line in lines:
if out_path is None:
output_dir = Path("./data/results/")
if not output_dir.exists():
output_dir.mkdir(parents=True, exist_ok=True)
out_path = output_dir / f"{cfg.start_time}.wav"
out_path = output_dir / f"{time.time()}.wav"
text = self.encode_text( line ).to(device=cfg.device)

View File

@ -7,4 +7,5 @@ from .utils import (
to_device,
tree_map,
do_gc,
set_seed,
)

View File

@ -7,8 +7,11 @@ from .distributed import global_rank, local_rank, global_leader_only
import gc
import logging
import pandas as pd
import numpy as np
import re
import torch
import random
import time
from coloredlogs import ColoredFormatter
from logging import StreamHandler
@ -35,6 +38,14 @@ def flatten_dict(d):
return records[0] if records else {}
def set_seed(seed=None):
    """Seed the `random`, NumPy and PyTorch global RNGs with one value.

    Args:
        seed: Integer seed. A falsy value (``None`` or ``0``) means "pick one
            for me": the current Unix time in whole seconds is used instead.

    Returns:
        The integer seed that was applied, so callers can log it and
        reproduce a run later.
    """
    if not seed:
        seed = int(time.time())

    random.seed(seed)
    # NumPy's legacy global seeding only accepts values in [0, 2**32 - 1];
    # mask so negative or very large seeds do not raise ValueError.
    np.random.seed(seed & 0xFFFFFFFF)
    torch.manual_seed(seed)
    return seed
def _get_named_modules(module, attrname):
for name, module in module.named_modules():
if hasattr(module, attrname):