forgot the =
This commit is contained in:
parent
3ff7cf8341
commit
8e7f900210
|
@ -196,6 +196,10 @@ class Dataset(_Dataset):
|
||||||
def _get_spkr_symmap(self):
|
def _get_spkr_symmap(self):
|
||||||
return {s: i for i, s in enumerate(self.spkrs)}
|
return {s: i for i, s in enumerate(self.spkrs)}
|
||||||
|
|
||||||
|
def sample_speakers(self, ignore=[]):
|
||||||
|
choices = set(self.spkrs) - set(ignore)
|
||||||
|
return random.choice([*choices])
|
||||||
|
|
||||||
def sample_prompts(self, spkr_name, ignore):
|
def sample_prompts(self, spkr_name, ignore):
|
||||||
prom_list = []
|
prom_list = []
|
||||||
|
|
||||||
|
@ -279,9 +283,30 @@ class Dataset(_Dataset):
|
||||||
resps = _load_quants(path)
|
resps = _load_quants(path)
|
||||||
|
|
||||||
task = random.choice(self.tasks)
|
task = random.choice(self.tasks)
|
||||||
|
# text-to-speech
|
||||||
if task == "tts":
|
if task == "tts":
|
||||||
# I could probably do some logic to directly use the resps, but I'm putting my faith in python aliasing
|
|
||||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||||
|
# noise-suppression
|
||||||
|
"""
|
||||||
|
elif task == "ns":
|
||||||
|
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||||
|
noise = self.sample_noise()
|
||||||
|
noise = extend_audio(noise, proms.shape[0])
|
||||||
|
proms = merge_audio(proms, noise)
|
||||||
|
# something to prepend a ns token to the beginning of proms
|
||||||
|
elif task == "sr":
|
||||||
|
proms = resps
|
||||||
|
resps = self.sample_noise()
|
||||||
|
resps = extend_audio(resps, proms.shape[0])
|
||||||
|
# something to prepend a sr token to the beginning of proms
|
||||||
|
elif task == "tse:
|
||||||
|
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||||
|
other_speaker = self.sample_speaker(ignore=[spkr_name])
|
||||||
|
other_proms = self.sample_prompts(other_speaker, ignore="")
|
||||||
|
proms = merge_audio(proms, other_proms)
|
||||||
|
# something to prepend a ns token to the beginning of proms
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
index=index,
|
index=index,
|
||||||
|
|
|
@ -133,7 +133,7 @@ def run_eval(engines, eval_name, dl):
|
||||||
process( name, batch, resps_list )
|
process( name, batch, resps_list )
|
||||||
|
|
||||||
processed += len(batch["text"])
|
processed += len(batch["text"])
|
||||||
if processed > cfg.evaluation.size:
|
if processed >= cfg.evaluation.size:
|
||||||
break
|
break
|
||||||
|
|
||||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user