This commit is contained in:
mrq 2024-07-18 23:25:32 -05:00
parent 39f961abcd
commit 28a674e0f1
3 changed files with 105 additions and 51 deletions

View File

@ -275,6 +275,9 @@ def get_task_symmap():
"<soe>": 5, "<soe>": 5,
"<mask>": 6, "<mask>": 6,
"<eoe>": 7, "<eoe>": 7,
"<nse>": 6, # fake
"<cse>": 6, # fake
} }
def _replace_file_extension(path, suffix): def _replace_file_extension(path, suffix):
@ -849,12 +852,12 @@ class Dataset(_Dataset):
if f'<{task}>' not in self.task_symmap: if f'<{task}>' not in self.task_symmap:
raise Exception(f'Task not defined: {task}') raise Exception(f'Task not defined: {task}')
# Base TTS (text + prompt => output) # Base TTS (<text><prompt> => <resp>)
if task == "tts": if task == "tts":
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps proms = self.sample_prompts(spkr_name, ignore=path)
# VALL-E Continuous (text + partial output => rest of output) # VALL-E Continuous (<text><partial resp> => <remaining resp> )
# (this could just be sampled as <text a><text b> + <audio a> => <audio b>, but I need to experiment with it) # (this could just be sampled as <text a><text b><audio a> => <audio b>, but I need to experiment with it)
elif task == "tts-c": elif task == "tts-c":
# trim a piece of the output response # trim a piece of the output response
if naive: if naive:
@ -871,14 +874,21 @@ class Dataset(_Dataset):
# <s>[original text] [new text]</s> # <s>[original text] [new text]</s>
# removes the original text's </s>, includes a space, and remove the new text's <s> # removes the original text's </s>, includes a space, and remove the new text's <s>
else: else:
text = torch.concat([ text[:-1], torch.tensor([self.phone_symmap[" "]]).to(torch.int16), txt[1:] ]) text = torch.concat([ text[:-1], torch.tensor([space_id]).to(torch.int16), txt[1:] ])
# set prompt as initial response # set prompt as initial response
proms = resps proms = resps
# set target as newly sampled response # set target as newly sampled response
resps = qnt resps = qnt
# noise suppression || speech removal # inject task token
proms = [
proms,
task,
]
# noise suppression (<text>? <resp+noise> => <resp>)
# speech removal (<text>?<resp+noise> => <noise>)
elif task == "ns" or task == "sr": elif task == "ns" or task == "sr":
# sample random noise # sample random noise
noise = self.sample_noise() noise = self.sample_noise()
@ -886,40 +896,44 @@ class Dataset(_Dataset):
noise = repeat_extend_audio(noise, resps.shape[0]) noise = repeat_extend_audio(noise, resps.shape[0])
# create the input prompt by merging the target audio with the noise # create the input prompt by merging the target audio with the noise
proms = merge_audio( resps, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device ) proms = merge_audio( resps, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = None
# inject task token
proms = [
task,
proms
]
# set the target to just be the noise if <sr> # set the target to just be the noise if <sr>
if task == "sr": if task == "sr":
resps = noise resps = noise
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
# target speech extraction # target speech extraction ( <text><prom><resp + other resp> => <resp> )
elif task == "tse": elif task == "tse":
# sample a random, clean, utterance for the target speaker # sample a prompt
clean_proms = self.sample_prompts(spkr_name, ignore=path) proms = self.sample_prompts(spkr_name, ignore=path)
# sample a random, clean utterance from a different speaker
other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="") # sample another speaker
_, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[spkr_name]))
# overlay the random speaker over the target audio # overlay the random speaker over the target audio
smallest_size = min(resps.shape[0], other_proms.shape[0]) other_resps = merge_audio( resps, other_resps, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
if other_proms.shape[0] == smallest_size:
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device ) # set the text prompt to empty to train without a guided text prompt
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] ) if random.random() < 0.5:
else: text = None
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
# stitch together the proms # stitch together the proms
proms = [ proms = [
clean_proms, proms,
task, task,
noisy_proms, other_resps,
] ]
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
# clean speech editing # clean speech editing
elif task == "cse" or task == "nse": elif task == "cse" or task == "nse":
@ -940,17 +954,21 @@ class Dataset(_Dataset):
pre_text = None pre_text = None
pre_prom = None pre_prom = None
# randomly drop out post # randomly drop out post
if random.random() < 0.125: elif random.random() < 0.125:
post_text = None post_text = None
post_prom = None post_prom = None
# create new text # create new text
text = torch.cat( text = concat_audio(
[ torch.Tensor( [ bos_id ] ).to(dtype=self.text_dtype) ] + # <s> torch.Tensor( [ bos_id ] ).to(dtype=self.text_dtype), # <s>
([ pre_text, torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype) ] if pre_text is not None else []) + # pre_text + space' pre_text,
[ edit_text ] + # 'edit text' None if pre_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
([ torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), post_text ] if post_text is not None else []) + # 'space' + edit_text edit_text,
[ torch.Tensor( [ eos_id ] ).to(dtype=self.text_dtype) ] # </s> None if post_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
post_text,
torch.Tensor( [ eos_id ] ).to(dtype=self.text_dtype), # </s>
reencode=False,
) )
if task == "nse": if task == "nse":
@ -978,23 +996,26 @@ class Dataset(_Dataset):
# create new prom # create new prom
proms = [ proms = [
pre_prom, pre_prom,
"<soe>", "soe",
"<mask>" if task == "cse" else mid_prom, "mask" if task == "cse" else mid_prom,
"<eoe>", "eoe",
post_prom, post_prom,
] ]
# create new resp # create new resp
resps = concat_audio( resps = concat_audio(
*(([ pre_prom ] if pre_prom is not None else []) + pre_prom,
[ edit_prom ] + edit_prom,
([ post_prom ] if post_prom is not None else [])), post_prom,
reencode=cfg.dataset.reencode_on_concat, reencode=cfg.dataset.reencode_on_concat,
device=cfg.dataset.reencode_device, device=cfg.dataset.reencode_device,
) )
else: else:
raise Exception(f'Undefined task: {task}') raise Exception(f'Undefined task: {task}')
if text is None:
text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
return dict( return dict(
index=index, index=index,
path=Path(path), path=Path(path),

View File

@ -482,7 +482,9 @@ def repeat_extend_audio( qnt, target ):
# interleaves between a list of audios # interleaves between a list of audios
# useful for interleaving silence # useful for interleaving silence
def interleave_audio( *args, audio=None ): def interleave_audio( *args, audio=None ):
qnts = [*args] qnts = [ *args ]
qnts = [ qnt for qnt in qnts if qnt is not None ]
if audio is None: if audio is None:
return qnts return qnts
@ -498,7 +500,8 @@ def interleave_audio( *args, audio=None ):
# concats two audios together # concats two audios together
def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ): def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ):
qnts = [*args] qnts = [ *args ]
qnts = [ qnt for qnt in qnts if qnt is not None ]
# just naively combine the codes # just naively combine the codes
if not reencode: if not reencode:
return torch.concat( qnts ) return torch.concat( qnts )
@ -510,9 +513,19 @@ def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_lev
# merges two quantized audios together # merges two quantized audios together
# requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic # requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic
def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ): def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ):
qnts = [*args] qnts = [ *args ]
qnts = [ qnt for qnt in qnts if qnt is not None ]
decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ] decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
# max length
max_length = max([ wav.shape[-1] for wav in decoded ])
for i, wav in enumerate(decoded):
delta = max_length - wav.shape[-1]
if delta <= 0:
continue
pad = torch.zeros( (1, delta), dtype=wav.dtype, device=wav.device )
decoded[i] = torch.cat( [ wav, pad ], dim=-1 )
# useful to adjust the volumes of each waveform # useful to adjust the volumes of each waveform
if len(scale) == len(decoded): if len(scale) == len(decoded):
for i in range(len(scale)): for i in range(len(scale)):

View File

@ -945,8 +945,8 @@ class Base(nn.Module):
): ):
# handles tasks where the prompt has task tokens injected in the middle # handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_embedding( input, quant_level ): def prompt_input_to_embedding( input, quant_level ):
if isinstance(inputs, str): if isinstance(input, str):
return self.tasks_emb( get_task_symmap()[f'<{input}>'] ) if self.tasks_emb is None else None return self.tasks_emb( torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(device=device, dtype=torch.int16) )
# get RVQ level 0, or up to targetted RVQ level inference # get RVQ level 0, or up to targetted RVQ level inference
if self.version <= 4: if self.version <= 4:
@ -958,6 +958,7 @@ class Base(nn.Module):
for batch_index, batch_input in enumerate(inputs): for batch_index, batch_input in enumerate(inputs):
batch = [] batch = []
quant_level = quant_levels[batch_index] if quant_levels is not None else 0 quant_level = quant_levels[batch_index] if quant_levels is not None else 0
task_type = "tts" task_type = "tts"
for name, input in batch_input: for name, input in batch_input:
# technically can provide a map for input_name => embedding, but some embedding requires additional processing # technically can provide a map for input_name => embedding, but some embedding requires additional processing
@ -971,13 +972,15 @@ class Base(nn.Module):
continue continue
elif name == "text": elif name == "text":
embedding = self.text_emb( input ) embedding = self.text_emb( input )
device = embedding.device
elif name == "quant_level" and self.rvq_l_emb is not None: elif name == "quant_level" and self.rvq_l_emb is not None:
embedding = self.rvq_l_emb( input ) embedding = self.rvq_l_emb( input )
elif name == "lang" and self.langs_emb is not None: elif name == "lang" and self.langs_emb is not None:
embedding = self.langs_emb( input ) embedding = self.langs_emb( input )
elif name == "prom": elif name == "prom":
proms = [ input ] if isinstance(input, torch.Tensor) else input proms = [ input ] if isinstance(input, torch.Tensor) else input
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms ] ) embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
elif name == "tone" and self.tones_emb is not None: elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input ) embedding = self.tones_emb( input )
elif name == "resp": elif name == "resp":
@ -1024,8 +1027,23 @@ class Base(nn.Module):
# there's a better way # there's a better way
if not self.unified_position_ids: if not self.unified_position_ids:
x_list = [] x_list = []
def get_input_token_length( name, input ):
# task token
if isinstance(input, str):
return 1
# list of tokens
if not isinstance(input, torch.Tensor):
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
return input.shape[0] + (0 if name == "resp" else 1)
for batch_index, batch_input in enumerate(inputs): for batch_index, batch_input in enumerate(inputs):
batch = torch.cat( [ torch.Tensor([*range( input.shape[0] + (0 if name == "resp" else 1) )]) for name, input in batch_input if name != "task" ] ) batch = torch.cat( [
torch.Tensor([*range(get_input_token_length(name, input))])
for name, input in batch_input if name != "task"
] )
delta = ids[batch_index].shape[0] - batch.shape[0] delta = ids[batch_index].shape[0] - batch.shape[0]
if delta > 0: if delta > 0:
@ -1044,10 +1062,12 @@ class Base(nn.Module):
quant_levels: int | list[int] | Tensor | None = None, quant_levels: int | list[int] | Tensor | None = None,
): ):
device = logits[0].device
# handles tasks where the prompt has task tokens injected in the middle # handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ): def prompt_input_to_token( input, quant_level ):
if isinstance(inputs, str): if isinstance(input, str):
return get_task_symmap()[f'<{input}>'] return torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device)
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums): if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
@ -1068,7 +1088,7 @@ class Base(nn.Module):
task_list.append( input ) task_list.append( input )
elif name == "prom": elif name == "prom":
proms = [ input ] if isinstance(input, torch.Tensor) else input proms = [ input ] if isinstance(input, torch.Tensor) else input
target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] ) ) target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
elif name == "resp": elif name == "resp":
target.append( input if input.dim() == 1 else input[:, quant_level] ) target.append( input if input.dim() == 1 else input[:, quant_level] )
elif name in ["text", "quant_level", "lang", "tone", "len"]: elif name in ["text", "quant_level", "lang", "tone", "len"]: