Make more VRAM friendly
This commit is contained in:
parent
f91db1a64c
commit
52998447b7
|
@ -3,3 +3,4 @@ spkr_name_getter: "lambda p: p.parts[-3]"
|
|||
|
||||
model: ar-quarter
|
||||
batch_size: 8
|
||||
eval_batch_size: 8
|
||||
|
|
|
@ -3,3 +3,4 @@ spkr_name_getter: "lambda p: p.parts[-3]"
|
|||
|
||||
model: nar-quarter
|
||||
batch_size: 8
|
||||
eval_batch_size: 8
|
||||
|
|
3
scripts/run.sh
Executable file
3
scripts/run.sh
Executable file
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
until $@; do echo retrying; done
|
14
vall_e/__main__.py
Normal file
14
vall_e/__main__.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser("VALL-E TTS")
|
||||
parser.add_argument("text")
|
||||
parser.add_argument("output")
|
||||
parser.add_argument("--reference", type=Path)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -17,7 +17,8 @@ class Config(ConfigBase):
|
|||
def sample_rate(self):
|
||||
return 24_000
|
||||
|
||||
p_additional_prompt: float = 0.5
|
||||
p_additional_prompt: float = 0.8
|
||||
max_prompts: int = 3
|
||||
|
||||
token_dim: int = 256
|
||||
num_tokens: int = 1024
|
||||
|
@ -37,6 +38,9 @@ class Config(ConfigBase):
|
|||
model: str = "ar-quarter"
|
||||
spkr_name_getter: str = "lambda p: p.parts[-2]"
|
||||
|
||||
min_phones: int = 10
|
||||
max_phones: int = 50
|
||||
|
||||
@cached_property
|
||||
def get_spkr(self):
|
||||
return eval(self.spkr_name_getter)
|
||||
|
|
|
@ -81,8 +81,8 @@ class VALLEDatset(Dataset):
|
|||
paths,
|
||||
phone_symmap=None,
|
||||
spkr_symmap=None,
|
||||
min_phones=10,
|
||||
max_phones=100,
|
||||
min_phones=cfg.min_phones,
|
||||
max_phones=cfg.max_phones,
|
||||
training=False,
|
||||
extra_paths_by_spkr_name: dict[str, list] = {},
|
||||
):
|
||||
|
@ -141,7 +141,7 @@ class VALLEDatset(Dataset):
|
|||
)
|
||||
choices = [better_not]
|
||||
|
||||
for _ in range(10):
|
||||
for _ in range(cfg.max_prompts):
|
||||
path = random.choice(choices)
|
||||
prom_list.append(_load_quants(path))
|
||||
if random.random() > cfg.p_additional_prompt:
|
||||
|
|
Loading…
Reference in New Issue
Block a user