fixes...

parent 39f961abcd
commit 28a674e0f1
@@ -275,6 +275,9 @@ def get_task_symmap():
         "<soe>": 5,
         "<mask>": 6,
         "<eoe>": 7,
+
+        "<nse>": 6, # fake
+        "<cse>": 6, # fake
     }
 
 def _replace_file_extension(path, suffix):
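Note on the hunk above: the two added entries alias `<nse>` and `<cse>` to the same ID as `<mask>` (6), so the `Dataset` task check later in this diff passes for those tasks without growing the task-embedding table; the `# fake` comments mark them as aliases rather than distinct tokens. A minimal, self-contained sketch of the effect (the trimmed `get_task_symmap` copy below is for illustration only, not the repo's import):

```python
# Standalone illustration: the aliased "fake" entries satisfy the membership
# check that would otherwise raise "Task not defined".
def get_task_symmap():
    return {
        "<soe>": 5,
        "<mask>": 6,
        "<eoe>": 7,
        "<nse>": 6, # fake: alias of <mask>
        "<cse>": 6, # fake: alias of <mask>
    }

for task in ["nse", "cse"]:
    symmap = get_task_symmap()
    assert f'<{task}>' in symmap            # no longer raises downstream
    print(task, "->", symmap[f'<{task}>'])  # both resolve to token 6
```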
@@ -849,12 +852,12 @@ class Dataset(_Dataset):
         if f'<{task}>' not in self.task_symmap:
             raise Exception(f'Task not defined: {task}')
 
-        # Base TTS (text + prompt => output)
+        # Base TTS (<text><prompt> => <resp>)
         if task == "tts":
-            proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
+            proms = self.sample_prompts(spkr_name, ignore=path)
 
-        # VALL-E Continuous (text + partial output => rest of output)
-        # (this could just be sampled as <text a><text b> + <audio a> => <audio b>, but I need to experiment with it)
+        # VALL-E Continuous (<text><partial resp> => <remaining resp>)
+        # (this could just be sampled as <text a><text b><audio a> => <audio b>, but I need to experiment with it)
         elif task == "tts-c":
             # trim a piece of the output response
             if naive:
@@ -871,14 +874,21 @@ class Dataset(_Dataset):
                 # <s>[original text] [new text]</s>
                 # removes the original text's </s>, includes a space, and remove the new text's <s>
                 else:
-                    text = torch.concat([ text[:-1], torch.tensor([self.phone_symmap[" "]]).to(torch.int16), txt[1:] ])
+                    text = torch.concat([ text[:-1], torch.tensor([space_id]).to(torch.int16), txt[1:] ])
 
             # set prompt as initial response
             proms = resps
             # set target as newly sampled response
             resps = qnt
 
-        # noise suppression || speech removal
+            # inject task token
+            proms = [
+                proms,
+                task,
+            ]
+
+        # noise suppression (<text>? <resp+noise> => <resp>)
+        # speech removal (<text>? <resp+noise> => <noise>)
         elif task == "ns" or task == "sr":
             # sample random noise
             noise = self.sample_noise()
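The `tts-c` branch now wraps the prompt in a list with the task name appended, so the prompt sequence carries a task token alongside its audio segments. A hedged sketch of how such a mixed list is consumed downstream (the `resolve` function stands in for the model's embedding lookup and is hypothetical):

```python
import torch

# Toy resolver: task strings become a sentinel token, tensors pass through.
def flatten_prompt(proms, resolve):
    parts = [ proms ] if isinstance(proms, torch.Tensor) else proms
    return torch.cat([ resolve(p) for p in parts if p is not None ])

resolve = lambda p: torch.tensor([ 99 ]) if isinstance(p, str) else p
proms = [ torch.tensor([ 1, 2, 3 ]), "tts-c" ]  # audio codes, then task token
print(flatten_prompt(proms, resolve))           # tensor([ 1,  2,  3, 99])
```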
@@ -886,40 +896,44 @@ class Dataset(_Dataset):
             noise = repeat_extend_audio(noise, resps.shape[0])
             # create the input prompt by merging the target audio with the noise
             proms = merge_audio( resps, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
 
+            # set the text prompt to empty to train without a guided text prompt
+            if random.random() < 0.5:
+                text = None
+
+            # inject task token
+            proms = [
+                task,
+                proms
+            ]
+
             # set the target to just be the noise if <sr>
             if task == "sr":
                 resps = noise
 
-            # set the text prompt to empty to train without a guided text prompt
-            if random.random() < 0.5:
-                text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
-
-        # target speech extraction
+        # target speech extraction (<text><prom><resp + other resp> => <resp>)
         elif task == "tse":
-            # sample a random, clean, utterance for the target speaker
-            clean_proms = self.sample_prompts(spkr_name, ignore=path)
-            # sample a random, clean utterance from a different speaker
-            other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
+            # sample a prompt
+            proms = self.sample_prompts(spkr_name, ignore=path)
+
+            # sample another speaker
+            _, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[spkr_name]))
 
             # overlay the random speaker over the target audio
-            smallest_size = min(resps.shape[0], other_proms.shape[0])
-            if other_proms.shape[0] == smallest_size:
-                noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
-                noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
-            else:
-                noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
-                noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
+            other_resps = merge_audio( resps, other_resps, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
+
+            # set the text prompt to empty to train without a guided text prompt
+            if random.random() < 0.5:
+                text = None
 
             # stitch together the proms
             proms = [
-                clean_proms,
+                proms,
                 task,
-                noisy_proms,
+                other_resps,
             ]
 
-            # set the text prompt to empty to train without a guided text prompt
-            if random.random() < 0.5:
-                text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
-
         # clean speech editing
         elif task == "cse" or task == "nse":
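The reworked `tse` branch drops the old length-matching branches entirely: it samples a whole other utterance, then lets `merge_audio` overlay it onto the target at a random 0.5 to 0.75 gain (the padding added to `merge_audio` further down in this diff is what makes unequal lengths safe). A rough sketch of the overlay idea on raw waveforms, assuming nothing about the codec round-trip the real helper performs:

```python
import random
import torch

# Mix two waveforms of unequal length: pad the shorter with silence, then sum
# with the interfering speaker scaled down so the target stays dominant.
def overlay(target, other, other_scale):
    n = max(target.shape[-1], other.shape[-1])
    out = torch.zeros(n)
    out[:target.shape[-1]] += target
    out[:other.shape[-1]] += other * other_scale
    return out

mixed = overlay(torch.randn(16000), torch.randn(12000), random.uniform(0.5, 0.75))
print(mixed.shape)  # torch.Size([16000])
```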
@@ -940,17 +954,21 @@ class Dataset(_Dataset):
                 pre_text = None
                 pre_prom = None
             # randomly drop out post
-            if random.random() < 0.125:
+            elif random.random() < 0.125:
                 post_text = None
                 post_prom = None
 
             # create new text
-            text = torch.cat(
-                [ torch.Tensor( [ bos_id ] ).to(dtype=self.text_dtype) ] + # <s>
-                ([ pre_text, torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype) ] if pre_text is not None else []) + # pre_text + space'
-                [ edit_text ] + # 'edit text'
-                ([ torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), post_text ] if post_text is not None else []) + # 'space' + edit_text
-                [ torch.Tensor( [ eos_id ] ).to(dtype=self.text_dtype) ] # </s>
+            text = concat_audio(
+                torch.Tensor( [ bos_id ] ).to(dtype=self.text_dtype), # <s>
+                pre_text,
+                None if pre_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
+                edit_text,
+                None if post_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
+                post_text,
+                torch.Tensor( [ eos_id ] ).to(dtype=self.text_dtype), # </s>
+
+                reencode=False,
             )
 
             if task == "nse":
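Reusing `concat_audio` for the text tensor works because, with `reencode=False`, the helper (after this diff) just drops `None` entries and `torch.concat`s the rest, which also removes the old bracket-juggling around optional `pre_text`/`post_text`. A minimal sketch of that reduction, with made-up token IDs:

```python
import torch

# With reencode=False, concat_audio reduces to: filter Nones, torch.concat.
def concat_tokens(*args):
    return torch.concat([ a for a in args if a is not None ])

bos, space, eos = torch.tensor([ 1 ]), torch.tensor([ 3 ]), torch.tensor([ 2 ])
pre_text, edit_text = None, torch.tensor([ 7, 8, 9 ])  # pre randomly dropped
text = concat_tokens(
    bos,
    pre_text,
    None if pre_text is None else space,  # space only when pre_text survives
    edit_text,
    eos,
)
print(text)  # tensor([1, 7, 8, 9, 2])
```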
@@ -978,23 +996,26 @@ class Dataset(_Dataset):
             # create new prom
             proms = [
                 pre_prom,
-                "<soe>",
-                "<mask>" if task == "cse" else mid_prom,
-                "<eoe>",
+                "soe",
+                "mask" if task == "cse" else mid_prom,
+                "eoe",
                 post_prom,
             ]
 
             # create new resp
             resps = concat_audio(
-                *(([ pre_prom ] if pre_prom is not None else []) +
-                [ edit_prom ] +
-                ([ post_prom ] if post_prom is not None else [])),
+                pre_prom,
+                edit_prom,
+                post_prom,
                 reencode=cfg.dataset.reencode_on_concat,
                 device=cfg.dataset.reencode_device,
             )
         else:
             raise Exception(f'Undefined task: {task}')
 
+        if text is None:
+            text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
+
         return dict(
             index=index,
             path=Path(path),
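With several branches now setting `text = None` to drop the transcript, the fallback is handled once before `return dict(...)`: any cleared text becomes the bare `<s></s>` pair. A tiny sketch, with `bos_id`/`eos_id` values assumed purely for illustration:

```python
import torch

bos_id, eos_id = 1, 2  # assumed IDs, for illustration only
text = None            # e.g. the ns/sr/tse branches dropped the transcript
if text is None:
    text = torch.tensor([ bos_id, eos_id ]).to(torch.int16)
print(text)  # tensor([1, 2], dtype=torch.int16)
```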
@@ -482,7 +482,9 @@ def repeat_extend_audio( qnt, target ):
 # interleaves between a list of audios
 # useful for interleaving silence
 def interleave_audio( *args, audio=None ):
-    qnts = [*args]
+    qnts = [ *args ]
+    qnts = [ qnt for qnt in qnts if qnt is not None ]
+
     if audio is None:
         return qnts
 
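The same `None` filter now opens `interleave_audio`, `concat_audio`, and `merge_audio`, so callers can pass optional segments (like a dropped-out `pre_prom` or `post_prom`) without guarding each call site. A standalone sketch of the pattern:

```python
import torch

def filter_qnts(*args):
    qnts = [ *args ]
    return [ qnt for qnt in qnts if qnt is not None ]

pre_prom, edit_prom, post_prom = None, torch.zeros(4, 8), torch.zeros(2, 8)
print(len(filter_qnts(pre_prom, edit_prom, post_prom)))  # 2: None is skipped
```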
@@ -498,7 +500,8 @@ def interleave_audio( *args, audio=None ):
 
 # concats two audios together
 def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ):
-    qnts = [*args]
+    qnts = [ *args ]
+    qnts = [ qnt for qnt in qnts if qnt is not None ]
     # just naively combine the codes
     if not reencode:
         return torch.concat( qnts )
@@ -510,9 +513,19 @@ def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ):
 # merges two quantized audios together
 # requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic
 def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ):
-    qnts = [*args]
+    qnts = [ *args ]
+    qnts = [ qnt for qnt in qnts if qnt is not None ]
     decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
 
+    # max length
+    max_length = max([ wav.shape[-1] for wav in decoded ])
+    for i, wav in enumerate(decoded):
+        delta = max_length - wav.shape[-1]
+        if delta <= 0:
+            continue
+        pad = torch.zeros( (1, delta), dtype=wav.dtype, device=wav.device )
+        decoded[i] = torch.cat( [ wav, pad ], dim=-1 )
+
     # useful to adjust the volumes of each waveform
     if len(scale) == len(decoded):
         for i in range(len(scale)):
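The new block in `merge_audio` right-pads every decoded waveform with silence to the longest length before they are summed, which is what lets the reworked `tse` branch merge utterances of different durations. The same logic, runnable in isolation:

```python
import torch

# Pad each waveform (shape [1, samples]) with trailing zeros to the max length.
decoded = [ torch.randn(1, 16000), torch.randn(1, 12000) ]

max_length = max( wav.shape[-1] for wav in decoded )
for i, wav in enumerate(decoded):
    delta = max_length - wav.shape[-1]
    if delta <= 0:
        continue
    pad = torch.zeros( (1, delta), dtype=wav.dtype, device=wav.device )
    decoded[i] = torch.cat( [ wav, pad ], dim=-1 )

print([ wav.shape[-1] for wav in decoded ])  # [16000, 16000]
```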
@@ -945,8 +945,8 @@ class Base(nn.Module):
     ):
         # handles tasks where the prompt has task tokens injected in the middle
         def prompt_input_to_embedding( input, quant_level ):
-            if isinstance(inputs, str):
-                return self.tasks_emb( get_task_symmap()[f'<{input}>'] ) if self.tasks_emb is None else None
+            if isinstance(input, str):
+                return self.tasks_emb( torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(device=device, dtype=torch.int16) )
 
         # get RVQ level 0, or up to targetted RVQ level inference
         if self.version <= 4:
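Two bugs die in the hunk above: the check tested the outer `inputs` (a list, so never a `str`), and the old expression returned the embedding only when `self.tasks_emb is None`, i.e. exactly backwards. The fix also wraps the symmap ID in a tensor before the embedding call. A hedged, self-contained sketch of the corrected shape of the lookup (using a default long index tensor; the diff itself casts to `int16`):

```python
import torch
import torch.nn as nn

tasks_emb = nn.Embedding(9, 4)  # toy: 9 task tokens, embedding dim 4
symmap = { "<ns>": 2 }          # trimmed symmap for illustration

def prompt_input_to_embedding(input):
    if isinstance(input, str):  # the *local* input, not the outer list
        return tasks_emb( torch.tensor([ symmap[f'<{input}>'] ]) )

print(prompt_input_to_embedding("ns").shape)  # torch.Size([1, 4])
```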
@@ -958,6 +958,7 @@ class Base(nn.Module):
         for batch_index, batch_input in enumerate(inputs):
             batch = []
             quant_level = quant_levels[batch_index] if quant_levels is not None else 0
+
             task_type = "tts"
             for name, input in batch_input:
                 # technically can provide a map for input_name => embedding, but some embedding requires additional processing
@@ -971,13 +972,15 @@ class Base(nn.Module):
                     continue
                 elif name == "text":
                     embedding = self.text_emb( input )
+
+                    device = embedding.device
                 elif name == "quant_level" and self.rvq_l_emb is not None:
                     embedding = self.rvq_l_emb( input )
                 elif name == "lang" and self.langs_emb is not None:
                     embedding = self.langs_emb( input )
                 elif name == "prom":
                     proms = [ input ] if isinstance(input, torch.Tensor) else input
-                    embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms ] )
+                    embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
                 elif name == "tone" and self.tones_emb is not None:
                     embedding = self.tones_emb( input )
                 elif name == "resp":
@@ -1024,8 +1027,23 @@ class Base(nn.Module):
         # there's a better way
         if not self.unified_position_ids:
             x_list = []
+
+            def get_input_token_length( name, input ):
+                # task token
+                if isinstance(input, str):
+                    return 1
+
+                # list of tokens
+                if not isinstance(input, torch.Tensor):
+                    return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
+
+                return input.shape[0] + (0 if name == "resp" else 1)
+
             for batch_index, batch_input in enumerate(inputs):
-                batch = torch.cat( [ torch.Tensor([*range( input.shape[0] + (0 if name == "resp" else 1) )]) for name, input in batch_input if name != "task" ] )
+                batch = torch.cat( [
+                    torch.Tensor([*range(get_input_token_length(name, input))])
+                    for name, input in batch_input if name != "task"
+                ] )
 
                 delta = ids[batch_index].shape[0] - batch.shape[0]
                 if delta > 0:
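`get_input_token_length` centralizes the sequence-length bookkeeping for the non-unified position IDs: a bare task token occupies one slot, a list (a prompt with injected task tokens) counts its tensor pieces plus one, and everything except `resp` reserves one extra position (presumably for a separator token). A runnable sketch of the resulting restart-per-segment IDs:

```python
import torch

def get_input_token_length(name, input):
    if isinstance(input, str):               # task token
        return 1
    if not isinstance(input, torch.Tensor):  # list of tokens
        return sum( i.shape[0] for i in input if isinstance(i, torch.Tensor) ) + 1
    return input.shape[0] + (0 if name == "resp" else 1)

batch_input = [
    ("text", torch.zeros(3)),            # 3 tokens + 1 => positions 0..3
    ("prom", [ torch.zeros(2), "ns" ]),  # 2 tokens + 1 => positions 0..2
    ("resp", torch.zeros(2)),            # 2 tokens     => positions 0..1
]
ids = torch.cat([ torch.arange( get_input_token_length(name, input) )
                  for name, input in batch_input ])
print(ids)  # tensor([0, 1, 2, 3, 0, 1, 2, 0, 1])
```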
@@ -1044,10 +1062,12 @@ class Base(nn.Module):
 
         quant_levels: int | list[int] | Tensor | None = None,
     ):
+        device = logits[0].device
+
         # handles tasks where the prompt has task tokens injected in the middle
         def prompt_input_to_token( input, quant_level ):
-            if isinstance(inputs, str):
-                return get_task_symmap()[f'<{input}>']
+            if isinstance(input, str):
+                return torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device)
 
         # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
         if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
@@ -1068,7 +1088,7 @@ class Base(nn.Module):
                 task_list.append( input )
             elif name == "prom":
                 proms = [ input ] if isinstance(input, torch.Tensor) else input
-                target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] ) )
+                target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
             elif name == "resp":
                 target.append( input if input.dim() == 1 else input[:, quant_level] )
             elif name in ["text", "quant_level", "lang", "tone", "len"]:
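Mirroring the embedding side, `prompt_input_to_token` now returns a one-element tensor for task strings (on the right device, captured from the logits at the top of the previous hunk), and the `prom` target skips `None` pieces so `torch.cat` cannot trip on them. A final standalone sketch of the target assembly:

```python
import torch

symmap = { "<sr>": 3 }  # trimmed symmap for illustration

def prompt_input_to_token(input):
    if isinstance(input, str):
        return torch.tensor([ symmap[f'<{input}>'] ], dtype=torch.int16)
    return input

proms = [ "sr", torch.tensor([ 5, 6 ], dtype=torch.int16), None ]
target = torch.cat([ prompt_input_to_token(p) for p in proms if p is not None ])
print(target)  # tensor([3, 5, 6], dtype=torch.int16)
```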