set seed on inference, since it seems to be set to 0 every time
This commit is contained in:
parent
0b1a71430c
commit
e2c9b0465f
19
README.md
19
README.md
|
@ -33,6 +33,17 @@ To start the trainer, run `python3 -m tortoise_tts.train --yaml="./path/to/your/
|
|||
|
||||
For training a LoRA, uncomment the `loras` block in your training YAML.
|
||||
|
||||
For loading an existing finetuned model, create a folder with this structure, and load its accompanying YAML:
|
||||
```
|
||||
./some/arbitrary/path/:
|
||||
ckpt:
|
||||
autoregressive:
|
||||
fp32.pth # finetuned weights
|
||||
config.yaml
|
||||
```
|
||||
|
||||
For LoRAs, replace the above `fp32.pth` with `lora.pth`.
|
||||
|
||||
## To-Do
|
||||
|
||||
- [X] Reimplement original inferencing through TorToiSe (as done with `api.py`)
|
||||
|
@ -54,12 +65,14 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
|
|||
- [x] Web UI
|
||||
- [ ] Feature parity with [ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning)
|
||||
- Although I feel a lot of its features are the wrong way to go about it.
|
||||
- [ ] Additional samplers for the autoregressive model
|
||||
- [ ] Additional samplers for the diffusion model
|
||||
- [ ] BigVGAN in place of the original vocoder
|
||||
- [ ] Additional samplers for the autoregressive model (such as mirostat / dynamic temperature)
|
||||
- [ ] Additional samplers for the diffusion model (beyond the already included DDIM)
|
||||
- [X] BigVGAN in place of the original vocoder
|
||||
- [X] HiFiGAN integration as well
|
||||
- [ ] XFormers / flash_attention_2 for the autoregressive model
|
||||
- Beyond HF's internal implementation of handling alternative attention
|
||||
- Both the AR and diffusion models also do their own attention...
|
||||
- [ ] Saner way of loading finetuned models / LoRAs
|
||||
- [ ] Some vector embedding store to find the "best" utterance to pick
|
||||
- [ ] Documentation
|
||||
|
||||
|
|
|
@ -26,6 +26,8 @@ def main():
|
|||
parser.add_argument("--cond-free", action="store_true")
|
||||
parser.add_argument("--vocoder", type=str, default="bigvgan")
|
||||
|
||||
parser.add_argument("--seed", type=int, default=None)
|
||||
|
||||
parser.add_argument("--yaml", type=Path, default=None)
|
||||
parser.add_argument("--device", type=str, default=None)
|
||||
parser.add_argument("--amp", action="store_true")
|
||||
|
@ -65,6 +67,8 @@ def main():
|
|||
cond_free=args.cond_free,
|
||||
|
||||
vocoder_type=args.vocoder,
|
||||
|
||||
seed=args.seed,
|
||||
)
|
||||
"""
|
||||
language=args.language,
|
||||
|
|
|
@ -8,9 +8,11 @@ import sys
|
|||
import time
|
||||
import argparse
|
||||
import yaml
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
from functools import cached_property
|
||||
|
@ -22,6 +24,14 @@ from .tokenizer import VoiceBpeTokenizer
|
|||
# Yuck
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
def set_seed(seed=None):
|
||||
if not seed:
|
||||
seed = time.time()
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
DEFAULT_YAML = Path(__file__).parent.parent / 'data/config.yaml'
|
||||
|
||||
@dataclass()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
import torchaudio
|
||||
import soundfile
|
||||
import time
|
||||
|
||||
from torch import Tensor
|
||||
from einops import rearrange
|
||||
|
@ -8,8 +9,7 @@ from pathlib import Path
|
|||
from tqdm import tqdm
|
||||
|
||||
from .emb.mel import encode_from_files as encode_mel, trim, trim_random
|
||||
from .utils import to_device
|
||||
from .utils import wrapper as ml
|
||||
from .utils import to_device, set_seed, wrapper as ml
|
||||
|
||||
from .config import cfg, DEFAULT_YAML
|
||||
from .models import get_models, load_model
|
||||
|
@ -140,6 +140,8 @@ class TTS():
|
|||
|
||||
vocoder_type="bigvgan",
|
||||
|
||||
seed=None,
|
||||
|
||||
out_path=None,
|
||||
):
|
||||
lines = text.split("\n")
|
||||
|
@ -189,12 +191,14 @@ class TTS():
|
|||
|
||||
candidates = 1
|
||||
|
||||
set_seed(seed)
|
||||
|
||||
for line in lines:
|
||||
if out_path is None:
|
||||
output_dir = Path("./data/results/")
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = output_dir / f"{cfg.start_time}.wav"
|
||||
out_path = output_dir / f"{time.time()}.wav"
|
||||
|
||||
text = self.encode_text( line ).to(device=cfg.device)
|
||||
|
||||
|
|
|
@ -7,4 +7,5 @@ from .utils import (
|
|||
to_device,
|
||||
tree_map,
|
||||
do_gc,
|
||||
set_seed,
|
||||
)
|
|
@ -7,8 +7,11 @@ from .distributed import global_rank, local_rank, global_leader_only
|
|||
import gc
|
||||
import logging
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import re
|
||||
import torch
|
||||
import random
|
||||
import time
|
||||
|
||||
from coloredlogs import ColoredFormatter
|
||||
from logging import StreamHandler
|
||||
|
@ -35,6 +38,14 @@ def flatten_dict(d):
|
|||
return records[0] if records else {}
|
||||
|
||||
|
||||
def set_seed(seed=None):
|
||||
if not seed:
|
||||
seed = int(time.time())
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def _get_named_modules(module, attrname):
|
||||
for name, module in module.named_modules():
|
||||
if hasattr(module, attrname):
|
||||
|
|
Loading…
Reference in New Issue
Block a user