added experimental option to append utterances for training target (emphasis on experimental)

This commit is contained in:
mrq 2023-10-11 17:32:45 -05:00
parent 7facacf7c9
commit 6045cbce94
2 changed files with 35 additions and 0 deletions

View File

@ -137,6 +137,9 @@ class Dataset:
max_prompts: int = 3
prompt_duration: float = 3.0
max_resps: int = 1
p_resp_append: float = 1.0
sample_type: str = "path" # path | speaker
tasks_list: list[str] = field(default_factory=lambda: ["tts"])

View File

@ -319,6 +319,8 @@ class Dataset(_Dataset):
if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
break
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
prom = torch.cat(prom_list)
if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
@ -349,10 +351,40 @@ class Dataset(_Dataset):
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
resps = _load_quants(path)
# append additional prompts in an attempt to artifically increase lengths / offer new data
if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
if len(choices) > 0:
for _ in range( cfg.dataset.max_resps - 1 ):
sampled_path = random.choice(choices)
choices = [*(set(choices) - {sampled_path})]
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
txt = cfg.hdf5[key]["text"][:]
qnt = cfg.hdf5[key]["audio"][:, :]
txt = np.array( _cleanup_phones( txt, targets=[ self.phone_symmap[" "] ] ) )
txt = torch.from_numpy(txt).to(self.text_dtype)
qnt = torch.from_numpy(qnt).to(torch.int16)
else:
txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
qnt = _load_quants(sampled_path)
# <s>[original text] [new text]</s>
# removes the original text's </s>, includes a space, and remove the new text's <s>
text = torch.concat([ text[:-1], torch.tensor([self.phone_symmap[" "]]).to(torch.int16), txt[1:] ])
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
resps = torch.concat([ resps, qnt ])
task = "tts"
trim_length = int(cfg.dataset.prompt_duration * 75)
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
# Disabled until I swap over to a better method
"""
task = random.choice(self.tasks)