fixes for re-introducing SpeechX tasks (need to actually validate if these all do the right things)

mrq 2024-07-18 17:16:32 -05:00
parent bccbb77a1a
commit 83a0954f85
3 changed files with 25 additions and 11 deletions

View File

@@ -165,7 +165,8 @@ class Dataset:
     tasks_list: list[str] = field(default_factory=lambda: ["tts"])
     reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
-    reencode_device: str = "cuda" # "cpu" is slower but saves memory
+    reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
+    noise_scale: float = 0.25 # scaling noise value
     _frames_per_second: int = 0 # allows setting your own hint
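Why the default flips to "cpu": PyTorch DataLoader workers are forked on Linux, and a forked child cannot re-initialize CUDA once the parent has touched it. A minimal sketch (not from the repo) of the failure mode and the two workarounds:

    # Sketch: CUDA cannot be re-initialized in a fork()ed worker, so either
    # keep re-encoding on CPU or switch multiprocessing to "spawn".
    import torch
    import torch.multiprocessing as mp
    from torch.utils.data import DataLoader, Dataset

    class Codes(Dataset):
        def __len__(self): return 4
        def __getitem__(self, i):
            device = "cpu"  # "cuda" raises here if the parent initialized CUDA before fork()
            return torch.zeros(8, device=device)

    if __name__ == "__main__":
        # alternative workaround: spawn fresh workers that may initialize CUDA
        # mp.set_start_method("spawn", force=True)
        for batch in DataLoader(Codes(), num_workers=2):
            pass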

View File

@@ -544,6 +544,11 @@ class Dataset(_Dataset):
         # grab IDs for bos, space, and eos for easy input creation later
         self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
+        # have it fetch at training time if any is invalid, because the tokenizer obj might not have it easily fetchable ahead of time
+        # encoding before parallelizing things causes things to whine
+        if self.empty_text[0] is None or self.empty_text[-1] is None:
+            self.empty_text = None
+
         # assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
         self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
@@ -756,6 +761,9 @@ class Dataset(_Dataset):
         return prom

     def __getitem__(self, index):
+        if self.empty_text is None:
+            self.empty_text = tokenize(" ")
+
         bos_id, space_id, eos_id = self.empty_text

         if self.sampler_type == "group":
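Together with the __init__ guard above, this is a deferred-lookup pattern: if the tokenizer can't resolve the special tokens before the dataset is handed off to forked workers, the IDs are fetched on first access inside __getitem__. A self-contained sketch with hypothetical names:

    # Deferred lookup of special-token IDs (hypothetical tokenizer API).
    def tokenize(text):
        return [1, 3, 2]  # stand-in: the real call yields [bos, *ids, eos]

    class LazySpecials:
        def __init__(self, bos, space, eos):
            ids = [bos, space, eos]
            # anything unresolvable pre-fork defers the whole lookup
            self.empty_text = None if None in ids else ids

        def __getitem__(self, index):
            if self.empty_text is None:
                self.empty_text = tokenize(" ")  # safe: runs inside the worker
            bos_id, space_id, eos_id = self.empty_text
            return bos_id, space_id, eos_id

    LazySpecials(None, 3, None)[0]  # triggers the deferred tokenize(" ")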
@@ -836,7 +844,10 @@ class Dataset(_Dataset):
            proms = proms[:, :cfg.model.resp_levels]
        """

-       task = "tts" # random.choice(self.tasks)
+       task = random.choice(self.tasks)
+
+       if f'<{task}>' not in self.task_symmap:
+           raise Exception(f'Task not defined: {task}')

        # Base TTS (text + prompt => output)
        if task == "tts":
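Validating the sampled task against the symmap up front turns a typo in tasks_list into an immediate, descriptive error rather than a failure deep inside the task branches below. Illustrative check (the IDs here are made up):

    # Illustrative symmap; the real one comes from the model's task symmap.
    task_symmap = { "<tts>": 0, "<tts-c>": 1, "<ns>": 2, "<sr>": 3 }

    task = "stt"  # not defined => fails fast with a clear message
    if f'<{task}>' not in task_symmap:
        raise Exception(f'Task not defined: {task}')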
@@ -874,7 +885,7 @@ class Dataset(_Dataset):
            # extend the noise to fill the target audio
            noise = repeat_extend_audio(noise, resps.shape[0])
            # create the input prompt by merging the target audio with the noise
-           proms = merge_audio( resps, noise, scale=[1, noise_scale], device=cfg.dataset.reencode_device )
+           proms = merge_audio( resps, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
            # set the target to just be the noise if <sr>
            if task == "sr":
                resps = noise
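The actual fix in this hunk is a NameError: noise_scale was an undefined local name and now reads the new config field. As for what the call does, a waveform-level sketch of the assumed merge semantics (the real merge_audio works on codes and may decode/re-encode on reencode_device):

    # Assumed semantics of merge_audio's scale argument: a per-source gain
    # applied before summing (utterance at 1.0, noise at noise_scale).
    import torch

    def merge_waveforms(clean: torch.Tensor, noise: torch.Tensor, scale=(1.0, 0.25)) -> torch.Tensor:
        assert clean.shape == noise.shape
        return clean * scale[0] + noise * scale[1]

    mixed = merge_waveforms(torch.zeros(16), torch.rand(16), scale=(1.0, 0.25))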
@@ -956,7 +967,7 @@ class Dataset(_Dataset):
                # extend the noise to fill the target audio
                n = repeat_extend_audio(noise, p.shape[0])
                # merge the noise over the utterance
-               return merge_audio(p, n, scale=[1, noise_scale], device=cfg.dataset.reencode_device)
+               return merge_audio(p, n, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device)

            # apply noise to all pieces
            pre_prom = noise_proms( pre_prom )
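Same noise_scale fix on the speech-editing path. Both noise paths first tile the noise clip out to the target length; an assumed implementation of that helper:

    # Assumed behavior of repeat_extend_audio: tile along the time axis
    # until the clip covers target_len frames, then trim the excess.
    import torch

    def repeat_extend(noise: torch.Tensor, target_len: int) -> torch.Tensor:
        reps = -(-target_len // noise.shape[0])  # ceiling division
        tiled = noise.repeat(reps, *[1] * (noise.dim() - 1))
        return tiled[:target_len]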
@@ -975,9 +986,9 @@ class Dataset(_Dataset):
            # create new resp
            resps = concat_audio(
-               ([ pre_prom ] if pre_prom is not None else []) +
+               *(([ pre_prom ] if pre_prom is not None else []) +
                [ edit_prom ] +
-               ([ post_prom ] if post_prom is not None else []),
+               ([ post_prom ] if post_prom is not None else [])),
                reencode=cfg.dataset.reencode_on_concat,
                device=cfg.dataset.reencode_device,
            )
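This one is a call-shape bug: assuming concat_audio takes variadic positional segments plus keyword options, passing the assembled list without * hands it a single list argument rather than separate audio segments. A minimal reproduction (signature assumed):

    # Minimal reproduction of the unpacking bug (signature assumed).
    def concat_audio(*audios, reencode=False, device="cpu"):
        return sum(len(a) for a in audios)  # stand-in for decode => concat => encode

    pieces = [ [0, 1] ] + [ [2] ] + [ [3, 4] ]
    concat_audio(*pieces)  # correct: three segments
    concat_audio(pieces)   # wrong: one "segment" that is a list of lists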

View File

@@ -880,7 +880,7 @@ class Base(nn.Module):
            # Base-line TTS task
            # Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
-           if task_type in ["tts", "tts-c", "ns", "sr"]:
+           if f'<{task_type}>' in get_task_symmap():
                # insert the text prompt
                if text_list is not None:
                    inputs[i].append( ( "text", text_list[i] ) )
@@ -932,6 +932,9 @@ class Base(nn.Module):
            elif resps_list is not None:
                # yes this could be encoded better
                inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) )
+           else:
+               raise Exception(f'Unrecognized task: {task_type}')

        return inputs
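On the model side, the same idea as the dataloader check: membership in the task symmap gates the main sequence builder instead of a hardcoded task list, and an unrecognized task now raises instead of silently producing empty inputs. A condensed sketch of the dispatch (the symmap contents and the "len" branch's role are assumptions read off the surrounding code):

    # Condensed sketch of the dispatch above (contents assumed).
    def get_task_symmap():
        return { "<tts>": 0, "<tts-c>": 1, "<ns>": 2, "<sr>": 3 }

    def build_inputs(task_type):
        if f'<{task_type}>' in get_task_symmap():
            return [ "text", "rvq lvl", "prom", "resp" ]  # main TTS-style sequence
        elif task_type == "len":
            return [ "text", "prom", "len" ]              # duration-prediction sequence
        else:
            raise Exception(f'Unrecognized task: {task_type}')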
@@ -943,7 +946,7 @@ class Base(nn.Module):
        # handles tasks where the prompt has task tokens injected in the middle
        def prompt_input_to_embedding( input, quant_level ):
            if isinstance(inputs, str):
-               return self.tasks_emb( get_task_symmap( input ) ) if self.langs_emb is None else None
+               return self.tasks_emb( get_task_symmap()[f'<{input}>'] ) if self.tasks_emb is None else None

            # get RVQ level 0, or up to targetted RVQ level inference
            if self.version <= 4:
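The corrected lookup reflects that get_task_symmap() takes no arguments and returns a dict keyed by the wrapped token string, so a string prompt entry is indexed as f'<{input}>' before being embedded. (As committed, the line still guards on self.tasks_emb is None and the isinstance check reads inputs rather than input; the sketch below assumes the apparent intent, in line with the commit message's own "need to actually validate".)

    # Assumed intent: map the injected task string to its ID, then embed it.
    symmap = { "<ns>": 2, "<sr>": 3 }  # illustrative subset

    def prompt_input_to_embedding(input, tasks_emb):
        if isinstance(input, str):
            token_id = symmap[f'<{input}>']
            return tasks_emb(token_id) if tasks_emb is not None else None
        return input  # otherwise: an audio prom, embedded elsewhere

    prompt_input_to_embedding("ns", tasks_emb=lambda t: [t])  # => [2]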
@@ -1001,9 +1004,8 @@ class Base(nn.Module):
                else:
                    # should probably raise an exception so things aren't processed silently
                    continue

                batch.append(embedding)

            x_list.append( _join( batch, self.sep ) )

        return x_list
@@ -1045,7 +1047,7 @@ class Base(nn.Module):
        # handles tasks where the prompt has task tokens injected in the middle
        def prompt_input_to_token( input, quant_level ):
            if isinstance(inputs, str):
-               return get_task_symmap( input )
+               return get_task_symmap()[f'<{input}>']

            # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
            if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
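prompt_input_to_token gets the matching fix on the token path: when prom embeddings can't be mapped back to discrete codes, an injected task string still resolves to its single symmap ID. Sketch under the same assumptions:

    # Same dict-indexing fix on the token path (symmap assumed as above).
    def prompt_input_to_token(input):
        symmap = { "<tts>": 0, "<ns>": 2, "<sr>": 3 }
        if isinstance(input, str):
            return symmap[f'<{input}>']
        return input  # already discrete codes; passed through as-is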