Better prompt sampling

This commit is contained in:
enhuiz 2023-01-12 19:46:05 +08:00
parent 3a4d5be18b
commit 958c2df660
2 changed files with 19 additions and 9 deletions

View File

@@ -72,7 +72,7 @@ def _validate(path, min_phones, max_phones):
def _get_spkr_name(path) -> str:
return path.parts[-2] # spkr/*.wav
return cfg.get_spkr(path)
class VALLEDatset(Dataset):
@@ -128,16 +128,24 @@ class VALLEDatset(Dataset):
def _get_spkr_symmap(self):
return {s: i for i, s in enumerate(self.spkrs)}
def sample_prompts(self, spkr_name):
def sample_prompts(self, spkr_name, better_not):
prom_list = []
while (
len(prom_list) == 0
or random.random() < cfg.p_additional_prompt
and len(prom_list) < 10
):
path = random.choice(self.paths_by_spkr_name[spkr_name])
choices = set(self.paths_by_spkr_name[spkr_name]) - {better_not}
choices = [*choices]
if len(choices) == 0:
_logger.info(
f"Failed to find another different utterance for {spkr_name}, "
"using the same audio as prompt."
)
choices = [better_not]
for _ in range(10):
path = random.choice(choices)
prom_list.append(_load_quants(path))
if random.random() > cfg.p_additional_prompt:
break
prom = torch.cat(prom_list)
@@ -152,7 +160,7 @@ class VALLEDatset(Dataset):
spkr_name = _get_spkr_name(path)
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))])
proms = self.sample_prompts(spkr_name)
proms = self.sample_prompts(spkr_name, better_not=path)
resps = _load_quants(path)
resp = resps[..., 0]

View File

@@ -40,6 +40,8 @@ def main():
for path in tqdm(paths):
phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
if phone_path.exists():
continue
graphs = _get_graphs(path)
phones = encode(graphs)
with open(phone_path, "w") as f: