fixes...

This commit is contained in:
parent 39f961abcd
commit 28a674e0f1

 vall_e/data.py | 103
```diff
@@ -275,6 +275,9 @@ def get_task_symmap():
         "<soe>": 5,
+        "<mask>": 6,
+        "<eoe>": 7,
+
         "<nse>": 6, # fake
         "<cse>": 6, # fake
     }
 
 def _replace_file_extension(path, suffix):
```
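For context, a minimal sketch of how these task tokens resolve. Bare task names get wrapped in angle brackets before the symmap lookup (mirroring the `get_task_symmap()[f'<{input}>']` calls in the `base.py` hunks further down), which is why quoting `"<soe>"` directly in a prompt list was a bug; only the entries shown in the hunk above are assumed here:

```python
# Subset of the task symmap from the hunk above.
TASK_SYMMAP = {
    "<soe>": 5,
    "<mask>": 6,
    "<eoe>": 7,
}

def task_token_id(name: str) -> int:
    # name is a bare task string, e.g. "soe"; the lookup adds the brackets,
    # so passing "<soe>" would produce the bogus key "<<soe>>".
    return TASK_SYMMAP[f"<{name}>"]

assert task_token_id("soe") == 5
```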
```diff
@@ -849,12 +852,12 @@ class Dataset(_Dataset):
         if f'<{task}>' not in self.task_symmap:
             raise Exception(f'Task not defined: {task}')
 
-        # Base TTS (text + prompt => output)
+        # Base TTS (<text><prompt> => <resp>)
         if task == "tts":
-            proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
+            proms = self.sample_prompts(spkr_name, ignore=path)
 
-        # VALL-E Continuous (text + partial output => rest of output)
-        # (this could just be sampled as <text a><text b> + <audio a> => <audio b>, but I need to experiment with it)
+        # VALL-E Continuous (<text><partial resp> => <remaining resp>)
+        # (this could just be sampled as <text a><text b><audio a> => <audio b>, but I need to experiment with it)
         elif task == "tts-c":
             # trim a piece of the output response
             if naive:
```
```diff
@@ -871,14 +874,21 @@ class Dataset(_Dataset):
             # <s>[original text] [new text]</s>
             # removes the original text's </s>, includes a space, and remove the new text's <s>
             else:
-                text = torch.concat([ text[:-1], torch.tensor([self.phone_symmap[" "]]).to(torch.int16), txt[1:] ])
+                text = torch.concat([ text[:-1], torch.tensor([space_id]).to(torch.int16), txt[1:] ])
 
             # set prompt as initial response
             proms = resps
             # set target as newly sampled response
             resps = qnt
 
-        # noise suppression || speech removal
+            # inject task token
+            proms = [
+                proms,
+                task,
+            ]
+
+        # noise suppression (<text>? <resp+noise> => <resp>)
+        # speech removal (<text>? <resp+noise> => <noise>)
         elif task == "ns" or task == "sr":
             # sample random noise
             noise = self.sample_noise()
```
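After this change a prompt is no longer necessarily one tensor: for task-conditioned samples it becomes a heterogeneous list mixing quantized-code tensors with bare task-name strings, which downstream code maps to task-token embeddings (see the `base.py` hunks below). A rough illustration of the structure, with placeholder shapes:

```python
import torch

# Hypothetical quantized audio codes: [timesteps, rvq_levels].
proms_codes = torch.zeros(500, 8, dtype=torch.int16)

# For tts-c the prompt is the original codes followed by the bare task name.
task = "tts-c"
proms = [
    proms_codes,
    task,
]

# Consumers branch on type, roughly:
for p in proms:
    if isinstance(p, str):
        print("task token:", p)
    else:
        print("audio codes:", tuple(p.shape))
```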
```diff
@@ -886,40 +896,44 @@ class Dataset(_Dataset):
             noise = repeat_extend_audio(noise, resps.shape[0])
             # create the input prompt by merging the target audio with the noise
             proms = merge_audio( resps, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
 
+            # set the text prompt to empty to train without a guided text prompt
+            if random.random() < 0.5:
+                text = None
+
             # inject task token
             proms = [
                 task,
                 proms
             ]
 
             # set the target to just be the noise if <sr>
             if task == "sr":
                 resps = noise
 
-        # target speech extraction
-        elif task == "tse":
-            # sample a random, clean, utterance for the target speaker
-            clean_proms = self.sample_prompts(spkr_name, ignore=path)
-            # sample a random, clean utterance from a different speaker
-            other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
-
-            # overlay the random speaker over the target audio
-            smallest_size = min(resps.shape[0], other_proms.shape[0])
-            if other_proms.shape[0] == smallest_size:
-                noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
-                noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
-            else:
-                noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
-                noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
-                text = None
+        # target speech extraction ( <text><prom><resp + other resp> => <resp> )
+        elif task == "tse":
+            # sample a prompt
+            proms = self.sample_prompts(spkr_name, ignore=path)
+
+            # sample another speaker
+            _, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[spkr_name]))
+
+            # overlay the random speaker over the target audio
+            other_resps = merge_audio( resps, other_resps, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
 
             # stitch together the proms
             proms = [
-                clean_proms,
+                proms,
                 task,
-                noisy_proms,
+                other_resps,
             ]
 
             # set the text prompt to empty to train without a guided text prompt
             if random.random() < 0.5:
                 text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
 
         # clean speech editing
         elif task == "cse" or task == "nse":
```
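The noise-suppression path above boils down to: overlay noise on the target, prepend the task token, and (for `sr`) swap the target for the noise. A toy version with stand-in tensors; `mix` here is a naive weighted sum of waveforms, not the project's re-encoding `merge_audio`:

```python
import random
import torch

def mix(a: torch.Tensor, b: torch.Tensor, scale=(1.0, 0.25)) -> torch.Tensor:
    # Stand-in for merge_audio: a plain weighted sum over the overlapping span.
    n = min(a.shape[-1], b.shape[-1])
    return scale[0] * a[..., :n] + scale[1] * b[..., :n]

resps = torch.randn(1, 24000)   # clean target waveform (placeholder)
noise = torch.randn(1, 24000)   # sampled noise (placeholder)

task = random.choice(["ns", "sr"])
proms = [task, mix(resps, noise)]   # task token first, then the noisy audio
if task == "sr":
    resps = noise                   # speech removal: the *noise* is the target
```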
```diff
@@ -940,17 +954,21 @@ class Dataset(_Dataset):
                 pre_text = None
                 pre_prom = None
             # randomly drop out post
-            if random.random() < 0.125:
+            elif random.random() < 0.125:
                 post_text = None
                 post_prom = None
 
             # create new text
-            text = torch.cat(
-                [ torch.Tensor( [ bos_id ] ).to(dtype=self.text_dtype) ] + # <s>
-                ([ pre_text, torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype) ] if pre_text is not None else []) + # pre_text + space'
-                [ edit_text ] + # 'edit text'
-                ([ torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), post_text ] if post_text is not None else []) + # 'space' + edit_text
-                [ torch.Tensor( [ eos_id ] ).to(dtype=self.text_dtype) ] # </s>
+            text = concat_audio(
+                torch.Tensor( [ bos_id ] ).to(dtype=self.text_dtype), # <s>
+                pre_text,
+                None if pre_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
+                edit_text,
+                None if post_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
+                post_text,
+                torch.Tensor( [ eos_id ] ).to(dtype=self.text_dtype), # </s>
+
+                reencode=False,
             )
 
             if task == "nse":
```
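The rewritten text assembly leans on `concat_audio` now dropping `None` entries (see the helper hunks below), so optional pre/post segments no longer need list-building gymnastics. The filtering itself is just this, sketched with the re-encode path omitted:

```python
import torch

def concat_audio(*args, reencode=False):
    # Sketch of the None-filtering concat: optional segments are passed as
    # None and silently dropped, so callers can list every slot positionally.
    qnts = [q for q in args if q is not None]
    if not reencode:
        return torch.concat(qnts)
    raise NotImplementedError("re-encode path omitted in this sketch")

bos = torch.tensor([1], dtype=torch.int16)
eos = torch.tensor([2], dtype=torch.int16)
pre_text = None  # a randomly dropped-out segment
edit_text = torch.tensor([10, 11, 12], dtype=torch.int16)

text = concat_audio(bos, pre_text, edit_text, eos, reencode=False)
assert text.tolist() == [1, 10, 11, 12, 2]
```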
```diff
@@ -978,23 +996,26 @@ class Dataset(_Dataset):
             # create new prom
             proms = [
                 pre_prom,
-                "<soe>",
-                "<mask>" if task == "cse" else mid_prom,
-                "<eoe>",
+                "soe",
+                "mask" if task == "cse" else mid_prom,
+                "eoe",
                 post_prom,
             ]
 
             # create new resp
             resps = concat_audio(
-                *(([ pre_prom ] if pre_prom is not None else []) +
-                [ edit_prom ] +
-                ([ post_prom ] if post_prom is not None else [])),
+                pre_prom,
+                edit_prom,
+                post_prom,
                 reencode=cfg.dataset.reencode_on_concat,
                 device=cfg.dataset.reencode_device,
             )
         else:
             raise Exception(f'Undefined task: {task}')
 
+        if text is None:
+            text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
+
         return dict(
             index=index,
             path=Path(path),
```
```diff
@@ -482,7 +482,9 @@ def repeat_extend_audio( qnt, target ):
 
 # interleaves between a list of audios
 # useful for interleaving silence
 def interleave_audio( *args, audio=None ):
-    qnts = [*args]
+    qnts = [ *args ]
+    qnts = [ qnt for qnt in qnts if qnt is not None ]
+
     if audio is None:
         return qnts
```
```diff
@@ -498,7 +500,8 @@ def interleave_audio( *args, audio=None ):
 
 # concats two audios together
 def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ):
-    qnts = [*args]
+    qnts = [ *args ]
+    qnts = [ qnt for qnt in qnts if qnt is not None ]
     # just naively combine the codes
     if not reencode:
         return torch.concat( qnts )
```
```diff
@@ -510,9 +513,19 @@ def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ):
 # merges two quantized audios together
 # requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic
 def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ):
-    qnts = [*args]
+    qnts = [ *args ]
+    qnts = [ qnt for qnt in qnts if qnt is not None ]
     decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
 
+    # max length
+    max_length = max([ wav.shape[-1] for wav in decoded ])
+    for i, wav in enumerate(decoded):
+        delta = max_length - wav.shape[-1]
+        if delta <= 0:
+            continue
+        pad = torch.zeros( (1, delta), dtype=wav.dtype, device=wav.device )
+        decoded[i] = torch.cat( [ wav, pad ], dim=-1 )
+
     # useful to adjust the volumes of each waveform
     if len(scale) == len(decoded):
         for i in range(len(scale)):
```
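The new padding loop exists because the decoded waveforms can differ in length, and an element-wise mix of unequal tensors would fail; shorter waveforms are zero-padded up to the longest. The same logic in isolation, with placeholder waveforms shaped `[1, samples]` as in the hunk:

```python
import torch

decoded = [torch.randn(1, 8000), torch.randn(1, 12000)]  # unequal lengths

# Zero-pad every waveform up to the longest so they can be summed.
max_length = max(wav.shape[-1] for wav in decoded)
for i, wav in enumerate(decoded):
    delta = max_length - wav.shape[-1]
    if delta <= 0:
        continue
    pad = torch.zeros((1, delta), dtype=wav.dtype, device=wav.device)
    decoded[i] = torch.cat([wav, pad], dim=-1)

assert all(wav.shape[-1] == max_length for wav in decoded)
mixed = sum(decoded)  # now a valid element-wise mix
```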
```diff
@@ -945,8 +945,8 @@ class Base(nn.Module):
     ):
         # handles tasks where the prompt has task tokens injected in the middle
         def prompt_input_to_embedding( input, quant_level ):
-            if isinstance(inputs, str):
-                return self.tasks_emb( get_task_symmap()[f'<{input}>'] ) if self.tasks_emb is None else None
+            if isinstance(input, str):
+                return self.tasks_emb( torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(device=device, dtype=torch.int16) )
 
             # get RVQ level 0, or up to targetted RVQ level inference
             if self.version <= 4:
```
```diff
@@ -958,6 +958,7 @@ class Base(nn.Module):
         for batch_index, batch_input in enumerate(inputs):
             batch = []
             quant_level = quant_levels[batch_index] if quant_levels is not None else 0
 
+            task_type = "tts"
             for name, input in batch_input:
                 # technically can provide a map for input_name => embedding, but some embedding requires additional processing
```
```diff
@@ -971,13 +972,15 @@ class Base(nn.Module):
                     continue
                 elif name == "text":
                     embedding = self.text_emb( input )
+
+                    device = embedding.device
                 elif name == "quant_level" and self.rvq_l_emb is not None:
                     embedding = self.rvq_l_emb( input )
                 elif name == "lang" and self.langs_emb is not None:
                     embedding = self.langs_emb( input )
                 elif name == "prom":
                     proms = [ input ] if isinstance(input, torch.Tensor) else input
-                    embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms ] )
+                    embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
                 elif name == "tone" and self.tones_emb is not None:
                     embedding = self.tones_emb( input )
                 elif name == "resp":
```
```diff
@@ -1024,8 +1027,23 @@ class Base(nn.Module):
         # there's a better way
         if not self.unified_position_ids:
             x_list = []
 
+            def get_input_token_length( name, input ):
+                # task token
+                if isinstance(input, str):
+                    return 1
+
+                # list of tokens
+                if not isinstance(input, torch.Tensor):
+                    return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
+
+                return input.shape[0] + (0 if name == "resp" else 1)
+
             for batch_index, batch_input in enumerate(inputs):
-                batch = torch.cat( [ torch.Tensor([*range( input.shape[0] + (0 if name == "resp" else 1) )]) for name, input in batch_input if name != "task" ] )
+                batch = torch.cat( [
+                    torch.Tensor([*range(get_input_token_length(name, input))])
+                    for name, input in batch_input if name != "task"
+                ] )
 
                 delta = ids[batch_index].shape[0] - batch.shape[0]
                 if delta > 0:
```
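The new `get_input_token_length` is needed because an input may now be a task string (one token) or a heterogeneous list of tensors and strings, not just a tensor; the old one-liner assumed `input.shape[0]` everywhere. A standalone sketch of that accounting, with placeholder inputs:

```python
import torch

def get_input_token_length(name, input):
    # a bare task string occupies a single position
    if isinstance(input, str):
        return 1
    # heterogeneous prompt list: tensor lengths plus one extra slot
    # (assumed here to be the embedded task token, per the hunk above)
    if not isinstance(input, torch.Tensor):
        return sum(i.shape[0] for i in input if isinstance(i, torch.Tensor)) + 1
    # plain tensor; non-resp inputs reserve one extra position
    return input.shape[0] + (0 if name == "resp" else 1)

batch_input = [
    ("text", torch.zeros(5, dtype=torch.int16)),
    ("prom", [torch.zeros(10, 8, dtype=torch.int16), "ns"]),
    ("resp", torch.zeros(7, 8, dtype=torch.int16)),
]
ids = torch.cat([
    torch.arange(get_input_token_length(name, input))
    for name, input in batch_input
])
print(ids.shape)  # position ids restart at 0 for each named segment
```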
```diff
@@ -1044,10 +1062,12 @@ class Base(nn.Module):
 
         quant_levels: int | list[int] | Tensor | None = None,
     ):
+        device = logits[0].device
+
         # handles tasks where the prompt has task tokens injected in the middle
         def prompt_input_to_token( input, quant_level ):
-            if isinstance(inputs, str):
-                return get_task_symmap()[f'<{input}>']
+            if isinstance(input, str):
+                return torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device)
 
             # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
             if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
```
```diff
@@ -1068,7 +1088,7 @@ class Base(nn.Module):
                 task_list.append( input )
             elif name == "prom":
                 proms = [ input ] if isinstance(input, torch.Tensor) else input
-                target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] ) )
+                target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
             elif name == "resp":
                 target.append( input if input.dim() == 1 else input[:, quant_level] )
             elif name in ["text", "quant_level", "lang", "tone", "len"]:
```