From 8e7f900210031e067df2bd7cb13c22716d26b958 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 17 Aug 2023 19:07:59 -0500 Subject: [PATCH] forgot the = --- vall_e/data.py | 27 ++++++++++++++++++++++++++- vall_e/train.py | 2 +- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index eef2009..0c72ea0 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -196,6 +196,10 @@ class Dataset(_Dataset): def _get_spkr_symmap(self): return {s: i for i, s in enumerate(self.spkrs)} + def sample_speakers(self, ignore=[]): + choices = set(self.spkrs) - set(ignore) + return random.choice([*choices]) + def sample_prompts(self, spkr_name, ignore): prom_list = [] @@ -279,9 +283,30 @@ class Dataset(_Dataset): resps = _load_quants(path) task = random.choice(self.tasks) + # text-to-speech if task == "tts": - # I could probably do some logic to directly use the resps, but I'm putting my faith in python aliasing proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps + # noise-suppression + """ + elif task == "ns": + proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps + noise = self.sample_noise() + noise = extend_audio(noise, proms.shape[0]) + proms = merge_audio(proms, noise) + # something to prepend a ns token to the beginning of proms + elif task == "sr": + proms = resps + resps = self.sample_noise() + resps = extend_audio(resps, proms.shape[0]) + # something to prepend a sr token to the beginning of proms + elif task == "tse": + proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps + other_speaker = self.sample_speakers(ignore=[spkr_name]) + other_proms = self.sample_prompts(other_speaker, ignore="") + proms = merge_audio(proms, other_proms) + # something to prepend a tse token to the beginning of proms + """ + return dict( index=index, diff --git a/vall_e/train.py
b/vall_e/train.py index 3799b5e..a53f58a 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -133,7 +133,7 @@ def run_eval(engines, eval_name, dl): process( name, batch, resps_list ) processed += len(batch["text"]) - if processed > cfg.evaluation.size: + if processed >= cfg.evaluation.size: break stats = {k: sum(v) / len(v) for k, v in stats.items()}