diff --git a/config/LibriTTS/ar.yml b/config/LibriTTS/ar.yml index 5fa8950..5810dc4 100644 --- a/config/LibriTTS/ar.yml +++ b/config/LibriTTS/ar.yml @@ -3,3 +3,4 @@ spkr_name_getter: "lambda p: p.parts[-3]" model: ar-quarter batch_size: 8 +eval_batch_size: 8 diff --git a/config/LibriTTS/nar.yml b/config/LibriTTS/nar.yml index 3c89829..997b31e 100644 --- a/config/LibriTTS/nar.yml +++ b/config/LibriTTS/nar.yml @@ -3,3 +3,4 @@ spkr_name_getter: "lambda p: p.parts[-3]" model: nar-quarter batch_size: 8 +eval_batch_size: 8 diff --git a/scripts/run.sh b/scripts/run.sh new file mode 100755 index 0000000..8c709d4 --- /dev/null +++ b/scripts/run.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash + +until $@; do echo retrying; done diff --git a/vall_e/__main__.py b/vall_e/__main__.py new file mode 100644 index 0000000..9f416d0 --- /dev/null +++ b/vall_e/__main__.py @@ -0,0 +1,14 @@ +import argparse +from pathlib import Path + + +def main(): + parser = argparse.ArgumentParser("VALL-E TTS") + parser.add_argument("text") + parser.add_argument("output") + parser.add_argument("--reference", type=Path) + args = parser.parse_args() + + +if __name__ == "__main__": + main() diff --git a/vall_e/config.py b/vall_e/config.py index e72741d..b75f3cd 100644 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -17,7 +17,8 @@ class Config(ConfigBase): def sample_rate(self): return 24_000 - p_additional_prompt: float = 0.5 + p_additional_prompt: float = 0.8 + max_prompts: int = 3 token_dim: int = 256 num_tokens: int = 1024 @@ -37,6 +38,9 @@ class Config(ConfigBase): model: str = "ar-quarter" spkr_name_getter: str = "lambda p: p.parts[-2]" + min_phones: int = 10 + max_phones: int = 50 + @cached_property def get_spkr(self): return eval(self.spkr_name_getter) diff --git a/vall_e/data.py b/vall_e/data.py index 6e358ae..db492bc 100644 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -81,8 +81,8 @@ class VALLEDatset(Dataset): paths, phone_symmap=None, spkr_symmap=None, - min_phones=10, - max_phones=100, + min_phones=cfg.min_phones, + max_phones=cfg.max_phones, training=False, extra_paths_by_spkr_name: dict[str, list] = {}, ): @@ -141,7 +141,7 @@ class VALLEDatset(Dataset): ) choices = [better_not] - for _ in range(10): + for _ in range(cfg.max_prompts): path = random.choice(choices) prom_list.append(_load_quants(path)) if random.random() > cfg.p_additional_prompt: