forgot the =
This commit is contained in:
parent
3ff7cf8341
commit
8e7f900210
|
@ -196,6 +196,10 @@ class Dataset(_Dataset):
|
|||
def _get_spkr_symmap(self):
|
||||
return {s: i for i, s in enumerate(self.spkrs)}
|
||||
|
||||
def sample_speakers(self, ignore=[]):
|
||||
choices = set(self.spkrs) - set(ignore)
|
||||
return random.choice([*choices])
|
||||
|
||||
def sample_prompts(self, spkr_name, ignore):
|
||||
prom_list = []
|
||||
|
||||
|
@ -279,9 +283,30 @@ class Dataset(_Dataset):
|
|||
resps = _load_quants(path)
|
||||
|
||||
task = random.choice(self.tasks)
|
||||
# text-to-speech
|
||||
if task == "tts":
|
||||
# I could probably do some logic to directly use the resps, but I'm putting my faith in python aliasing
|
||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
# noise-suppression
|
||||
"""
|
||||
elif task == "ns":
|
||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
noise = self.sample_noise()
|
||||
noise = extend_audio(noise, proms.shape[0])
|
||||
proms = merge_audio(proms, noise)
|
||||
# something to prepend a ns token to the beginning of proms
|
||||
elif task == "sr":
|
||||
proms = resps
|
||||
resps = self.sample_noise()
|
||||
resps = extend_audio(resps, proms.shape[0])
|
||||
# something to prepend a sr token to the beginning of proms
|
||||
elif task == "tse:
|
||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
other_speaker = self.sample_speaker(ignore=[spkr_name])
|
||||
other_proms = self.sample_prompts(other_speaker, ignore="")
|
||||
proms = merge_audio(proms, other_proms)
|
||||
# something to prepend a ns token to the beginning of proms
|
||||
"""
|
||||
|
||||
|
||||
return dict(
|
||||
index=index,
|
||||
|
|
|
@ -133,7 +133,7 @@ def run_eval(engines, eval_name, dl):
|
|||
process( name, batch, resps_list )
|
||||
|
||||
processed += len(batch["text"])
|
||||
if processed > cfg.evaluation.size:
|
||||
if processed >= cfg.evaluation.size:
|
||||
break
|
||||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||
|
|
Loading…
Reference in New Issue
Block a user