From 6ca347e1e14d78c4c12d60e3639c348adc165220 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 19 Aug 2023 01:16:46 -0500 Subject: [PATCH] literally had a urethra moment before going to bed with a way to implement cse/nse tasks --- vall_e/data.py | 91 ++++++++++++++++++++++++++++++++++++++++++------- vall_e/train.py | 5 ++- 2 files changed, 83 insertions(+), 13 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index 2d54080..241c90e 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -300,6 +300,7 @@ class Dataset(_Dataset): task = random.choice(self.tasks) + noise_scale = 0.125 # text-to-speech if task == "tts": proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps @@ -307,11 +308,10 @@ class Dataset(_Dataset): elif task == "ns" or task == "sr": # sample random noise noise = self.sample_noise() - # extend the noise to fill the target audio noise = repeat_extend_audio(noise, resps.shape[0]) # create the input prompt by merging the target audio with the noise - proms = merge_audio(resps, noise, scale=[1, 0.125], device="cpu") + proms = merge_audio(resps, noise, scale=[1, noise_scale], device="cpu") # set the target to just be the noise if if task == "sr": resps = noise @@ -324,7 +324,7 @@ class Dataset(_Dataset): # target speech extraction elif task == "tse": # sample a random, clean, utterance for the target speaker - clean_proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps + clean_proms = self.sample_prompts(spkr_name, ignore=path) # sample a random, clean utterance from a different speaker other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="") # overlay the random speaker over the target audio @@ -342,19 +342,86 @@ class Dataset(_Dataset): # set the text prompt to empty to train without a guided text prompt if random.random() < 0.5: - text = torch.tensor([1, 2]).to(self.text_dtype) + text = torch.tensor([1, 2]).to(self.text_dtype) # # speech editing would require higher quality transcription data (phoneme level/word level) unfortunately # as I need to get a good clean point to trim into - """ # clean speech editing - elif task == "cse": - ... - # noisy speech editing - elif task == "nse": - ... - """ - + elif task == "cse" or task == "nse": + choices = set(self.paths_by_spkr_name[spkr_name]) - {path} + sampled = random.choice([*choices], 4) + + if cfg.dataset.use_hdf5: + texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ] + qnts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :cfg.models.prom_levels]).to(torch.int16) for path in sampled ] + else: + texts = [ torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) for path in sampled ] + qnts = [ _load_quants(path) for path in sampled ] + + # remove + for text in texts: + text = text[1:-1] + + pre_text, mid_text, post_text, edit_text = texts + pre_prom, mid_prom, post_prom, edit_prom = qnts + + # randomly drop out pre + if random.random() < 0.125: + pre_text = None + pre_prom = None + # randomly drop out post + if random.random() < 0.125: + post_text = None + post_prom = None + + # create new text + text = torch.cat( + [ 1 ] + # + ([ pre_text ] if pre_text is not None else []) + + [ edit_text ] + + ([ post_post ] if post_post is not None else []) + + [ 2 ] # + ) + + if task == "nse": + # sample random noise + noise = self.sample_noise() + + # it might be better to extend the noise to the sum of the pre+mid+post or pre+edit+post to keep the noise truly coherent + # but it's noise, it's supposed to be random + def noise_proms( proms ): + # ignore if we turned it off + if proms is None: + return None + + # extend the noise to fill the target audio + n = repeat_extend_audio(noise, proms.shape[0]) + # merge the noise over the utterance + return merge_audio(proms, noise, scale=[1, noise_scale], device="cpu") + + # apply noise to all pieces + pre_prom = noise_proms( pre_prom ) + mid_prom = noise_proms( mid_prom ) + post_prom = noise_proms( post_prom ) + edit_prom = noise_proms( edit_prom ) + else: + mid_prom = self.get_task_token("mask") + + # create new proms + proms = torch.cat( + ([ pre_prom ] if pre_prom is not None else []) + + [self.get_task_token("soe")] + + [ mid_prom ] + # is if task is CSE + [self.get_task_token("eoe")] + + ([ post_prom ] if post_prom is not None else []) + ) + # create new resp + resps = torch.cat( + ([ pre_prom ] if pre_prom is not None else []) + + [ edit_prom ] + + ([ post_prom ] if post_prom is not None else []) + ) + """ # emulate SVC # takes in an utterance of the target speaker, a target utterenace as a reference clip as the input prompt diff --git a/vall_e/train.py b/vall_e/train.py index f59c4ec..052986a 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -57,12 +57,15 @@ def run_eval(engines, eval_name, dl): stats['loss'] = [] def process( name, batch, resps_list ): - for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]): + for speaker, path, ref, hyp, prom, task in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"]): if len(hyp) == 0: continue filename = f'{speaker}_{path.parts[-1]}' + if task != "tts": + filename = f"{filename}_{task}" + # to-do, refine the output dir to be sane-er ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav") hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")