literally had a urethra moment before going to bed with a way to implement cse/nse tasks
This commit is contained in:
parent
8f42c578c9
commit
6ca347e1e1
|
@ -300,6 +300,7 @@ class Dataset(_Dataset):
|
|||
|
||||
task = random.choice(self.tasks)
|
||||
|
||||
noise_scale = 0.125
|
||||
# text-to-speech
|
||||
if task == "tts":
|
||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
|
@ -307,11 +308,10 @@ class Dataset(_Dataset):
|
|||
elif task == "ns" or task == "sr":
|
||||
# sample random noise
|
||||
noise = self.sample_noise()
|
||||
|
||||
# extend the noise to fill the target audio
|
||||
noise = repeat_extend_audio(noise, resps.shape[0])
|
||||
# create the input prompt by merging the target audio with the noise
|
||||
proms = merge_audio(resps, noise, scale=[1, 0.125], device="cpu")
|
||||
proms = merge_audio(resps, noise, scale=[1, noise_scale], device="cpu")
|
||||
# set the target to just be the noise if <sr>
|
||||
if task == "sr":
|
||||
resps = noise
|
||||
|
@ -324,7 +324,7 @@ class Dataset(_Dataset):
|
|||
# target speech extraction
|
||||
elif task == "tse":
|
||||
# sample a random, clean, utterance for the target speaker
|
||||
clean_proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
clean_proms = self.sample_prompts(spkr_name, ignore=path)
|
||||
# sample a random, clean utterance from a different speaker
|
||||
other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
|
||||
# overlay the random speaker over the target audio
|
||||
|
@ -342,19 +342,86 @@ class Dataset(_Dataset):
|
|||
|
||||
# set the text prompt to empty to train without a guided text prompt
|
||||
if random.random() < 0.5:
|
||||
text = torch.tensor([1, 2]).to(self.text_dtype)
|
||||
text = torch.tensor([1, 2]).to(self.text_dtype) # <s></s>
|
||||
|
||||
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
|
||||
# as I need to get a good clean point to trim into
|
||||
"""
|
||||
# clean speech editing
|
||||
elif task == "cse":
|
||||
...
|
||||
# noisy speech editing
|
||||
elif task == "nse":
|
||||
...
|
||||
"""
|
||||
|
||||
elif task == "cse" or task == "nse":
|
||||
choices = set(self.paths_by_spkr_name[spkr_name]) - {path}
|
||||
sampled = random.choice([*choices], 4)
|
||||
|
||||
if cfg.dataset.use_hdf5:
|
||||
texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ]
|
||||
qnts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :cfg.models.prom_levels]).to(torch.int16) for path in sampled ]
|
||||
else:
|
||||
texts = [ torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) for path in sampled ]
|
||||
qnts = [ _load_quants(path) for path in sampled ]
|
||||
|
||||
# remove <s></s>
|
||||
for text in texts:
|
||||
text = text[1:-1]
|
||||
|
||||
pre_text, mid_text, post_text, edit_text = texts
|
||||
pre_prom, mid_prom, post_prom, edit_prom = qnts
|
||||
|
||||
# randomly drop out pre
|
||||
if random.random() < 0.125:
|
||||
pre_text = None
|
||||
pre_prom = None
|
||||
# randomly drop out post
|
||||
if random.random() < 0.125:
|
||||
post_text = None
|
||||
post_prom = None
|
||||
|
||||
# create new text
|
||||
text = torch.cat(
|
||||
[ 1 ] + # <s>
|
||||
([ pre_text ] if pre_text is not None else []) +
|
||||
[ edit_text ] +
|
||||
([ post_post ] if post_post is not None else []) +
|
||||
[ 2 ] # </s>
|
||||
)
|
||||
|
||||
if task == "nse":
|
||||
# sample random noise
|
||||
noise = self.sample_noise()
|
||||
|
||||
# it might be better to extend the noise to the sum of the pre+mid+post or pre+edit+post to keep the noise truly coherent
|
||||
# but it's noise, it's supposed to be random
|
||||
def noise_proms( proms ):
|
||||
# ignore if we turned it off
|
||||
if proms is None:
|
||||
return None
|
||||
|
||||
# extend the noise to fill the target audio
|
||||
n = repeat_extend_audio(noise, proms.shape[0])
|
||||
# merge the noise over the utterance
|
||||
return merge_audio(proms, noise, scale=[1, noise_scale], device="cpu")
|
||||
|
||||
# apply noise to all pieces
|
||||
pre_prom = noise_proms( pre_prom )
|
||||
mid_prom = noise_proms( mid_prom )
|
||||
post_prom = noise_proms( post_prom )
|
||||
edit_prom = noise_proms( edit_prom )
|
||||
else:
|
||||
mid_prom = self.get_task_token("mask")
|
||||
|
||||
# create new proms
|
||||
proms = torch.cat(
|
||||
([ pre_prom ] if pre_prom is not None else []) +
|
||||
[self.get_task_token("soe")] +
|
||||
[ mid_prom ] + # is <mask> if task is CSE
|
||||
[self.get_task_token("eoe")] +
|
||||
([ post_prom ] if post_prom is not None else [])
|
||||
)
|
||||
# create new resp
|
||||
resps = torch.cat(
|
||||
([ pre_prom ] if pre_prom is not None else []) +
|
||||
[ edit_prom ] +
|
||||
([ post_prom ] if post_prom is not None else [])
|
||||
)
|
||||
|
||||
"""
|
||||
# emulate SVC
|
||||
# takes in an utterance of the target speaker, a target utterenace as a reference clip as the input prompt
|
||||
|
|
|
@ -57,12 +57,15 @@ def run_eval(engines, eval_name, dl):
|
|||
stats['loss'] = []
|
||||
|
||||
def process( name, batch, resps_list ):
|
||||
for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
|
||||
for speaker, path, ref, hyp, prom, task in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"]):
|
||||
if len(hyp) == 0:
|
||||
continue
|
||||
|
||||
filename = f'{speaker}_{path.parts[-1]}'
|
||||
|
||||
if task != "tts":
|
||||
filename = f"{filename}_{task}"
|
||||
|
||||
# to-do, refine the output dir to be sane-er
|
||||
ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
|
||||
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
|
||||
|
|
Loading…
Reference in New Issue
Block a user