Better prompt sampling
This commit is contained in:
parent
3a4d5be18b
commit
958c2df660
@ -72,7 +72,7 @@ def _validate(path, min_phones, max_phones):
|
|||||||
|
|
||||||
|
|
||||||
def _get_spkr_name(path) -> str:
|
def _get_spkr_name(path) -> str:
|
||||||
return path.parts[-2] # spkr/*.wav
|
return cfg.get_spkr(path)
|
||||||
|
|
||||||
|
|
||||||
class VALLEDatset(Dataset):
|
class VALLEDatset(Dataset):
|
||||||
@ -128,16 +128,24 @@ class VALLEDatset(Dataset):
|
|||||||
def _get_spkr_symmap(self):
|
def _get_spkr_symmap(self):
|
||||||
return {s: i for i, s in enumerate(self.spkrs)}
|
return {s: i for i, s in enumerate(self.spkrs)}
|
||||||
|
|
||||||
def sample_prompts(self, spkr_name):
|
def sample_prompts(self, spkr_name, better_not):
|
||||||
prom_list = []
|
prom_list = []
|
||||||
|
|
||||||
while (
|
choices = set(self.paths_by_spkr_name[spkr_name]) - {better_not}
|
||||||
len(prom_list) == 0
|
choices = [*choices]
|
||||||
or random.random() < cfg.p_additional_prompt
|
|
||||||
and len(prom_list) < 10
|
if len(choices) == 0:
|
||||||
):
|
_logger.info(
|
||||||
path = random.choice(self.paths_by_spkr_name[spkr_name])
|
f"Failed to find another different utterance for {spkr_name}, "
|
||||||
|
"using the same audio as prompt."
|
||||||
|
)
|
||||||
|
choices = [better_not]
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
path = random.choice(choices)
|
||||||
prom_list.append(_load_quants(path))
|
prom_list.append(_load_quants(path))
|
||||||
|
if random.random() > cfg.p_additional_prompt:
|
||||||
|
break
|
||||||
|
|
||||||
prom = torch.cat(prom_list)
|
prom = torch.cat(prom_list)
|
||||||
|
|
||||||
@ -152,7 +160,7 @@ class VALLEDatset(Dataset):
|
|||||||
|
|
||||||
spkr_name = _get_spkr_name(path)
|
spkr_name = _get_spkr_name(path)
|
||||||
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))])
|
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))])
|
||||||
proms = self.sample_prompts(spkr_name)
|
proms = self.sample_prompts(spkr_name, better_not=path)
|
||||||
resps = _load_quants(path)
|
resps = _load_quants(path)
|
||||||
resp = resps[..., 0]
|
resp = resps[..., 0]
|
||||||
|
|
||||||
|
@ -40,6 +40,8 @@ def main():
|
|||||||
|
|
||||||
for path in tqdm(paths):
|
for path in tqdm(paths):
|
||||||
phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
|
phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
|
||||||
|
if phone_path.exists():
|
||||||
|
continue
|
||||||
graphs = _get_graphs(path)
|
graphs = _get_graphs(path)
|
||||||
phones = encode(graphs)
|
phones = encode(graphs)
|
||||||
with open(phone_path, "w") as f:
|
with open(phone_path, "w") as f:
|
||||||
|
Loading…
Reference in New Issue
Block a user