Better prompt sampling
This commit is contained in:
parent
3a4d5be18b
commit
958c2df660
@ -72,7 +72,7 @@ def _validate(path, min_phones, max_phones):
|
||||
|
||||
|
||||
def _get_spkr_name(path) -> str:
|
||||
return path.parts[-2] # spkr/*.wav
|
||||
return cfg.get_spkr(path)
|
||||
|
||||
|
||||
class VALLEDatset(Dataset):
|
||||
@ -128,16 +128,24 @@ class VALLEDatset(Dataset):
|
||||
def _get_spkr_symmap(self):
|
||||
return {s: i for i, s in enumerate(self.spkrs)}
|
||||
|
||||
def sample_prompts(self, spkr_name):
|
||||
def sample_prompts(self, spkr_name, better_not):
|
||||
prom_list = []
|
||||
|
||||
while (
|
||||
len(prom_list) == 0
|
||||
or random.random() < cfg.p_additional_prompt
|
||||
and len(prom_list) < 10
|
||||
):
|
||||
path = random.choice(self.paths_by_spkr_name[spkr_name])
|
||||
choices = set(self.paths_by_spkr_name[spkr_name]) - {better_not}
|
||||
choices = [*choices]
|
||||
|
||||
if len(choices) == 0:
|
||||
_logger.info(
|
||||
f"Failed to find another different utterance for {spkr_name}, "
|
||||
"using the same audio as prompt."
|
||||
)
|
||||
choices = [better_not]
|
||||
|
||||
for _ in range(10):
|
||||
path = random.choice(choices)
|
||||
prom_list.append(_load_quants(path))
|
||||
if random.random() > cfg.p_additional_prompt:
|
||||
break
|
||||
|
||||
prom = torch.cat(prom_list)
|
||||
|
||||
@ -152,7 +160,7 @@ class VALLEDatset(Dataset):
|
||||
|
||||
spkr_name = _get_spkr_name(path)
|
||||
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))])
|
||||
proms = self.sample_prompts(spkr_name)
|
||||
proms = self.sample_prompts(spkr_name, better_not=path)
|
||||
resps = _load_quants(path)
|
||||
resp = resps[..., 0]
|
||||
|
||||
|
@ -40,6 +40,8 @@ def main():
|
||||
|
||||
for path in tqdm(paths):
|
||||
phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
|
||||
if phone_path.exists():
|
||||
continue
|
||||
graphs = _get_graphs(path)
|
||||
phones = encode(graphs)
|
||||
with open(phone_path, "w") as f:
|
||||
|
Loading…
Reference in New Issue
Block a user