set seed on inference, since it seems to be set to 0 every time
This commit is contained in:
parent
0b1a71430c
commit
e2c9b0465f
19
README.md
19
README.md
|
@ -33,6 +33,17 @@ To start the trainer, run `python3 -m tortoise_tts.train --yaml="./path/to/your/
|
||||||
|
|
||||||
For training a LoRA, uncomment the `loras` block in your training YAML.
|
For training a LoRA, uncomment the `loras` block in your training YAML.
|
||||||
|
|
||||||
|
For loading an existing finetuned model, create a folder with this structure, and load its accompanying YAML:
|
||||||
|
```
|
||||||
|
./some/arbitrary/path/:
|
||||||
|
ckpt:
|
||||||
|
autoregressive:
|
||||||
|
fp32.pth # finetuned weights
|
||||||
|
config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
For LoRAs, replace the above `fp32.pth` with `lora.pth`.
|
||||||
|
|
||||||
## To-Do
|
## To-Do
|
||||||
|
|
||||||
- [X] Reimplement original inferencing through TorToiSe (as done with `api.py`)
|
- [X] Reimplement original inferencing through TorToiSe (as done with `api.py`)
|
||||||
|
@ -54,12 +65,14 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
|
||||||
- [x] Web UI
|
- [x] Web UI
|
||||||
- [ ] Feature parity with [ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning)
|
- [ ] Feature parity with [ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning)
|
||||||
- Although I feel a lot of its features are the wrong way to go about it.
|
- Although I feel a lot of its features are the wrong way to go about it.
|
||||||
- [ ] Additional samplers for the autoregressive model
|
- [ ] Additional samplers for the autoregressive model (such as mirostat / dynamic temperature)
|
||||||
- [ ] Additional samplers for the diffusion model
|
- [ ] Additional samplers for the diffusion model (beyond the already included DDIM)
|
||||||
- [ ] BigVGAN in place of the original vocoder
|
- [X] BigVGAN in place of the original vocoder
|
||||||
|
- [X] HiFiGAN integration as well
|
||||||
- [ ] XFormers / flash_attention_2 for the autoregressive model
|
- [ ] XFormers / flash_attention_2 for the autoregressive model
|
||||||
- Beyond HF's internal implementation of handling alternative attention
|
- Beyond HF's internal implementation of handling alternative attention
|
||||||
- Both the AR and diffusion models also do their own attention...
|
- Both the AR and diffusion models also do their own attention...
|
||||||
|
- [ ] Saner way of loading finetuned models / LoRAs
|
||||||
- [ ] Some vector embedding store to find the "best" utterance to pick
|
- [ ] Some vector embedding store to find the "best" utterance to pick
|
||||||
- [ ] Documentation
|
- [ ] Documentation
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,8 @@ def main():
|
||||||
parser.add_argument("--cond-free", action="store_true")
|
parser.add_argument("--cond-free", action="store_true")
|
||||||
parser.add_argument("--vocoder", type=str, default="bigvgan")
|
parser.add_argument("--vocoder", type=str, default="bigvgan")
|
||||||
|
|
||||||
|
parser.add_argument("--seed", type=int, default=None)
|
||||||
|
|
||||||
parser.add_argument("--yaml", type=Path, default=None)
|
parser.add_argument("--yaml", type=Path, default=None)
|
||||||
parser.add_argument("--device", type=str, default=None)
|
parser.add_argument("--device", type=str, default=None)
|
||||||
parser.add_argument("--amp", action="store_true")
|
parser.add_argument("--amp", action="store_true")
|
||||||
|
@ -65,6 +67,8 @@ def main():
|
||||||
cond_free=args.cond_free,
|
cond_free=args.cond_free,
|
||||||
|
|
||||||
vocoder_type=args.vocoder,
|
vocoder_type=args.vocoder,
|
||||||
|
|
||||||
|
seed=args.seed,
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
language=args.language,
|
language=args.language,
|
||||||
|
|
|
@ -8,9 +8,11 @@ import sys
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
import yaml
|
import yaml
|
||||||
|
import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
@ -22,6 +24,14 @@ from .tokenizer import VoiceBpeTokenizer
|
||||||
# Yuck
|
# Yuck
|
||||||
from transformers import PreTrainedTokenizerFast
|
from transformers import PreTrainedTokenizerFast
|
||||||
|
|
||||||
|
def set_seed(seed=None):
|
||||||
|
if not seed:
|
||||||
|
seed = time.time()
|
||||||
|
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
DEFAULT_YAML = Path(__file__).parent.parent / 'data/config.yaml'
|
DEFAULT_YAML = Path(__file__).parent.parent / 'data/config.yaml'
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import soundfile
|
import soundfile
|
||||||
|
import time
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
@ -8,8 +9,7 @@ from pathlib import Path
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .emb.mel import encode_from_files as encode_mel, trim, trim_random
|
from .emb.mel import encode_from_files as encode_mel, trim, trim_random
|
||||||
from .utils import to_device
|
from .utils import to_device, set_seed, wrapper as ml
|
||||||
from .utils import wrapper as ml
|
|
||||||
|
|
||||||
from .config import cfg, DEFAULT_YAML
|
from .config import cfg, DEFAULT_YAML
|
||||||
from .models import get_models, load_model
|
from .models import get_models, load_model
|
||||||
|
@ -140,6 +140,8 @@ class TTS():
|
||||||
|
|
||||||
vocoder_type="bigvgan",
|
vocoder_type="bigvgan",
|
||||||
|
|
||||||
|
seed=None,
|
||||||
|
|
||||||
out_path=None,
|
out_path=None,
|
||||||
):
|
):
|
||||||
lines = text.split("\n")
|
lines = text.split("\n")
|
||||||
|
@ -189,12 +191,14 @@ class TTS():
|
||||||
|
|
||||||
candidates = 1
|
candidates = 1
|
||||||
|
|
||||||
|
set_seed(seed)
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if out_path is None:
|
if out_path is None:
|
||||||
output_dir = Path("./data/results/")
|
output_dir = Path("./data/results/")
|
||||||
if not output_dir.exists():
|
if not output_dir.exists():
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
out_path = output_dir / f"{cfg.start_time}.wav"
|
out_path = output_dir / f"{time.time()}.wav"
|
||||||
|
|
||||||
text = self.encode_text( line ).to(device=cfg.device)
|
text = self.encode_text( line ).to(device=cfg.device)
|
||||||
|
|
||||||
|
|
|
@ -7,4 +7,5 @@ from .utils import (
|
||||||
to_device,
|
to_device,
|
||||||
tree_map,
|
tree_map,
|
||||||
do_gc,
|
do_gc,
|
||||||
|
set_seed,
|
||||||
)
|
)
|
|
@ -7,8 +7,11 @@ from .distributed import global_rank, local_rank, global_leader_only
|
||||||
import gc
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
from coloredlogs import ColoredFormatter
|
from coloredlogs import ColoredFormatter
|
||||||
from logging import StreamHandler
|
from logging import StreamHandler
|
||||||
|
@ -35,6 +38,14 @@ def flatten_dict(d):
|
||||||
return records[0] if records else {}
|
return records[0] if records else {}
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed=None):
|
||||||
|
if not seed:
|
||||||
|
seed = int(time.time())
|
||||||
|
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
def _get_named_modules(module, attrname):
|
def _get_named_modules(module, attrname):
|
||||||
for name, module in module.named_modules():
|
for name, module in module.named_modules():
|
||||||
if hasattr(module, attrname):
|
if hasattr(module, attrname):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user