literally had a urethra moment before going to bed with a way to implement cse/nse tasks

This commit is contained in:
mrq 2023-08-19 01:16:46 -05:00
parent 8f42c578c9
commit 6ca347e1e1
2 changed files with 83 additions and 13 deletions

View File

@ -300,6 +300,7 @@ class Dataset(_Dataset):
task = random.choice(self.tasks)
noise_scale = 0.125
# text-to-speech
if task == "tts":
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
@ -307,11 +308,10 @@ class Dataset(_Dataset):
elif task == "ns" or task == "sr":
# sample random noise
noise = self.sample_noise()
# extend the noise to fill the target audio
noise = repeat_extend_audio(noise, resps.shape[0])
# create the input prompt by merging the target audio with the noise
proms = merge_audio(resps, noise, scale=[1, 0.125], device="cpu")
proms = merge_audio(resps, noise, scale=[1, noise_scale], device="cpu")
# set the target to just be the noise if <sr>
if task == "sr":
resps = noise
@ -324,7 +324,7 @@ class Dataset(_Dataset):
# target speech extraction
elif task == "tse":
# sample a random, clean, utterance for the target speaker
clean_proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
clean_proms = self.sample_prompts(spkr_name, ignore=path)
# sample a random, clean utterance from a different speaker
other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
# overlay the random speaker over the target audio
@ -342,19 +342,86 @@ class Dataset(_Dataset):
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(self.text_dtype)
text = torch.tensor([1, 2]).to(self.text_dtype) # <s></s>
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
# as I need to get a good clean point to trim into
"""
# clean speech editing
elif task == "cse":
...
# noisy speech editing
elif task == "nse":
...
"""
elif task == "cse" or task == "nse":
choices = set(self.paths_by_spkr_name[spkr_name]) - {path}
sampled = random.choice([*choices], 4)
if cfg.dataset.use_hdf5:
texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ]
qnts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :cfg.models.prom_levels]).to(torch.int16) for path in sampled ]
else:
texts = [ torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) for path in sampled ]
qnts = [ _load_quants(path) for path in sampled ]
# remove <s></s>
for text in texts:
text = text[1:-1]
pre_text, mid_text, post_text, edit_text = texts
pre_prom, mid_prom, post_prom, edit_prom = qnts
# randomly drop out pre
if random.random() < 0.125:
pre_text = None
pre_prom = None
# randomly drop out post
if random.random() < 0.125:
post_text = None
post_prom = None
# create new text
text = torch.cat(
[ 1 ] + # <s>
([ pre_text ] if pre_text is not None else []) +
[ edit_text ] +
([ post_post ] if post_post is not None else []) +
[ 2 ] # </s>
)
if task == "nse":
# sample random noise
noise = self.sample_noise()
# it might be better to extend the noise to the sum of the pre+mid+post or pre+edit+post to keep the noise truly coherent
# but it's noise, it's supposed to be random
def noise_proms( proms ):
# ignore if we turned it off
if proms is None:
return None
# extend the noise to fill the target audio
n = repeat_extend_audio(noise, proms.shape[0])
# merge the noise over the utterance
return merge_audio(proms, noise, scale=[1, noise_scale], device="cpu")
# apply noise to all pieces
pre_prom = noise_proms( pre_prom )
mid_prom = noise_proms( mid_prom )
post_prom = noise_proms( post_prom )
edit_prom = noise_proms( edit_prom )
else:
mid_prom = self.get_task_token("mask")
# create new proms
proms = torch.cat(
([ pre_prom ] if pre_prom is not None else []) +
[self.get_task_token("soe")] +
[ mid_prom ] + # is <mask> if task is CSE
[self.get_task_token("eoe")] +
([ post_prom ] if post_prom is not None else [])
)
# create new resp
resps = torch.cat(
([ pre_prom ] if pre_prom is not None else []) +
[ edit_prom ] +
([ post_prom ] if post_prom is not None else [])
)
"""
# emulate SVC
# takes in an utterance of the target speaker, a target utterenace as a reference clip as the input prompt

View File

@ -57,12 +57,15 @@ def run_eval(engines, eval_name, dl):
stats['loss'] = []
def process( name, batch, resps_list ):
for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
for speaker, path, ref, hyp, prom, task in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"]):
if len(hyp) == 0:
continue
filename = f'{speaker}_{path.parts[-1]}'
if task != "tts":
filename = f"{filename}_{task}"
# to-do, refine the output dir to be sane-er
ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")