forgot the =

This commit is contained in:
mrq 2023-08-17 19:07:59 -05:00
parent 3ff7cf8341
commit 8e7f900210
2 changed files with 27 additions and 2 deletions

View File

@ -196,6 +196,10 @@ class Dataset(_Dataset):
def _get_spkr_symmap(self):
    """Map every entry of `self.spkrs` to its position in that sequence.

    Returns:
        A dict keyed by speaker whose value is the speaker's enumeration
        index. If a speaker occurs more than once, the last index wins.
    """
    symmap = {}
    for position, speaker in enumerate(self.spkrs):
        symmap[speaker] = position
    return symmap
def sample_speakers(self, ignore=None):
    """Pick one speaker at random, skipping any listed in `ignore`.

    Args:
        ignore: Optional iterable of speakers to exclude from the draw.
            When omitted (or `None`) no speaker is excluded.

    Returns:
        One element of `self.spkrs` that is not contained in `ignore`.

    Raises:
        IndexError: If no candidate is left after the exclusion, since
            `random.choice` cannot draw from an empty sequence.
    """
    # A `None` sentinel replaces the former `ignore=[]` default so that no
    # list object is shared between calls.
    excluded = set() if ignore is None else set(ignore)
    choices = set(self.spkrs) - excluded
    # `random.choice` needs an indexable sequence, so unpack the set first.
    return random.choice([*choices])
def sample_prompts(self, spkr_name, ignore):
prom_list = []
@ -279,9 +283,30 @@ class Dataset(_Dataset):
resps = _load_quants(path)
task = random.choice(self.tasks)
# text-to-speech
if task == "tts":
# I could probably do some logic to directly use the resps, but I'm putting my faith in python aliasing
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
# noise-suppression
"""
elif task == "ns":
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
noise = self.sample_noise()
noise = extend_audio(noise, proms.shape[0])
proms = merge_audio(proms, noise)
# something to prepend a ns token to the beginning of proms
elif task == "sr":
proms = resps
resps = self.sample_noise()
resps = extend_audio(resps, proms.shape[0])
# something to prepend a sr token to the beginning of proms
elif task == "tse":
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
other_speaker = self.sample_speakers(ignore=[spkr_name])
other_proms = self.sample_prompts(other_speaker, ignore="")
proms = merge_audio(proms, other_proms)
# something to prepend a tse token to the beginning of proms
"""
return dict(
index=index,

View File

@ -133,7 +133,7 @@ def run_eval(engines, eval_name, dl):
process( name, batch, resps_list )
processed += len(batch["text"])
if processed > cfg.evaluation.size:
if processed >= cfg.evaluation.size:
break
stats = {k: sum(v) / len(v) for k, v in stats.items()}