Make more VRAM friendly

This commit is contained in:
enhuiz 2023-01-12 20:07:44 +08:00
parent f91db1a64c
commit 52998447b7
6 changed files with 27 additions and 4 deletions

View File

@ -3,3 +3,4 @@ spkr_name_getter: "lambda p: p.parts[-3]"
model: ar-quarter
batch_size: 8
eval_batch_size: 8

View File

@ -3,3 +3,4 @@ spkr_name_getter: "lambda p: p.parts[-3]"
model: nar-quarter
batch_size: 8
eval_batch_size: 8

3
scripts/run.sh Executable file
View File

@ -0,0 +1,3 @@
#!/usr/bin/env bash
# Re-run the given command until it exits with status 0, printing
# "retrying" after every failed attempt.
#
# "$@" is quoted so that each argument reaches the command exactly as the
# caller passed it; unquoted, arguments containing spaces or glob
# characters would be re-split / expanded by the shell on every attempt.
until "$@"; do echo retrying; done

14
vall_e/__main__.py Normal file
View File

@ -0,0 +1,14 @@
import argparse
from pathlib import Path
def main(argv=None):
    """Parse the command-line arguments of the VALL-E TTS entry point.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script
            entry point below relies on.

    Returns:
        The parsed ``argparse.Namespace`` with ``text``, ``output`` and
        ``reference`` attributes; ``reference`` is a ``Path`` or ``None``
        when the option is omitted.
    """
    parser = argparse.ArgumentParser("VALL-E TTS")
    parser.add_argument("text")
    parser.add_argument("output")
    parser.add_argument("--reference", type=Path)
    # Returning the namespace (instead of dropping it in an unused local)
    # lets callers and tests consume the parsed arguments.
    return parser.parse_args(argv)


if __name__ == "__main__":
    main()

View File

@ -17,7 +17,8 @@ class Config(ConfigBase):
def sample_rate(self):
return 24_000
p_additional_prompt: float = 0.5
p_additional_prompt: float = 0.8
max_prompts: int = 3
token_dim: int = 256
num_tokens: int = 1024
@ -37,6 +38,9 @@ class Config(ConfigBase):
model: str = "ar-quarter"
spkr_name_getter: str = "lambda p: p.parts[-2]"
min_phones: int = 10
max_phones: int = 50
# Evaluates the `spkr_name_getter` string as a Python expression and returns
# the result; with the default value shown above that is a lambda taking a
# path-like `p`. Presumably `cached_property` is functools' — if so the
# expression is evaluated once per instance (TODO confirm the import).
# NOTE(review): eval() executes arbitrary code taken from the config value,
# so only load configuration from trusted sources.
@cached_property
def get_spkr(self):
return eval(self.spkr_name_getter)

View File

@ -81,8 +81,8 @@ class VALLEDatset(Dataset):
paths,
phone_symmap=None,
spkr_symmap=None,
min_phones=10,
max_phones=100,
min_phones=cfg.min_phones,
max_phones=cfg.max_phones,
training=False,
extra_paths_by_spkr_name: dict[str, list] = {},
):
@ -141,7 +141,7 @@ class VALLEDatset(Dataset):
)
choices = [better_not]
for _ in range(10):
for _ in range(cfg.max_prompts):
path = random.choice(choices)
prom_list.append(_load_quants(path))
if random.random() > cfg.p_additional_prompt: