From 6045cbce945c2a2301ca2f9fe21b9ff5f21eeae5 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 11 Oct 2023 17:32:45 -0500
Subject: [PATCH] added experimental option to append utterances for training
 target (emphasis on experimental)

---
 vall_e/config.py |  3 +++
 vall_e/data.py   | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/vall_e/config.py b/vall_e/config.py
index 2e6204e..8fdd591 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -136,6 +136,9 @@ class Dataset:
 	random_utterance: float = 1.0
 	max_prompts: int = 3
 	prompt_duration: float = 3.0
+
+	max_resps: int = 1
+	p_resp_append: float = 1.0
 
 	sample_type: str = "path" # path | speaker
 
diff --git a/vall_e/data.py b/vall_e/data.py
index 08244c9..15833a6 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -319,6 +319,8 @@ class Dataset(_Dataset):
 			if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
 				break
 
+		# might be better to decode => concat waveforms with silence in between => reencode
+		# as you technically can't just append encodec sequences together like this without issues
 		prom = torch.cat(prom_list)
 
 		if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
@@ -348,11 +350,41 @@ class Dataset(_Dataset):
 		else:
 			text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
 			resps = _load_quants(path)
+
+		# append additional prompts in an attempt to artificially increase lengths / offer new data
+		if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
+			choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
+
+			if len(choices) > 0:
+				for _ in range( cfg.dataset.max_resps - 1 ):
+					sampled_path = random.choice(choices)
+					choices = [*(set(choices) - {sampled_path})]
+					if cfg.dataset.use_hdf5:
+						key = _get_hdf5_path(sampled_path) # NOTE: was _get_hdf5_path(path), which reloaded the original utterance every iteration
+						txt = cfg.hdf5[key]["text"][:]
+						qnt = cfg.hdf5[key]["audio"][:, :]
+
+						txt = np.array( _cleanup_phones( txt, targets=[ self.phone_symmap[" "] ] ) )
+
+						txt = torch.from_numpy(txt).to(self.text_dtype)
+						qnt = torch.from_numpy(qnt).to(torch.int16)
+					else:
+						txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
+						qnt = _load_quants(sampled_path)
+
+					# [original text]<space>[new text]
+					# removes the original text's <eos>, includes a space, and removes the new text's <bos>
+					text = torch.concat([ text[:-1], torch.tensor([self.phone_symmap[" "]]).to(torch.int16), txt[1:] ])
+
+					# might be better to decode => concat waveforms with silence in between => reencode
+					# as you technically can't just append encodec sequences together like this without issues
+					resps = torch.concat([ resps, qnt ])
 
 		task = "tts"
 		trim_length = int(cfg.dataset.prompt_duration * 75)
 
 		proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
+
 		# Disabled until I swap over to a better method
 		"""
 		task = random.choice(self.tasks)