fixes for re-introducing SpeechX tasks (need to actually validate if these all do the right things)
This commit is contained in:
parent
bccbb77a1a
commit
83a0954f85
|
@ -165,7 +165,8 @@ class Dataset:
|
|||
|
||||
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
|
||||
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
|
||||
reencode_device: str = "cuda" # "cpu" is slower but saves memory
|
||||
reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
|
||||
noise_scale: float = 0.25 # scaling noise value
|
||||
|
||||
_frames_per_second: int = 0 # allows setting your own hint
|
||||
|
||||
|
|
|
@ -544,6 +544,11 @@ class Dataset(_Dataset):
|
|||
# grab IDs for bos, space, and eos for easy input creation later
|
||||
self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
|
||||
|
||||
# have it fetch at training time if any is invalid, because the tokenizer obj might not have it easily fetchable ahead of itme
|
||||
# encoding before parallelizing things causes things to whine
|
||||
if self.empty_text[0] is None or self.empty_text[-1] is None:
|
||||
self.empty_text = None
|
||||
|
||||
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
|
||||
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
|
||||
|
||||
|
@ -756,6 +761,9 @@ class Dataset(_Dataset):
|
|||
return prom
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.empty_text is None:
|
||||
self.empty_text = tokenize(" ")
|
||||
|
||||
bos_id, space_id, eos_id = self.empty_text
|
||||
|
||||
if self.sampler_type == "group":
|
||||
|
@ -836,7 +844,10 @@ class Dataset(_Dataset):
|
|||
proms = proms[:, :cfg.model.resp_levels]
|
||||
"""
|
||||
|
||||
task = "tts" # random.choice(self.tasks)
|
||||
task = random.choice(self.tasks)
|
||||
|
||||
if f'<{task}>' not in self.task_symmap:
|
||||
raise Exception(f'Task not defined: {task}')
|
||||
|
||||
# Base TTS (text + prompt => output)
|
||||
if task == "tts":
|
||||
|
@ -874,7 +885,7 @@ class Dataset(_Dataset):
|
|||
# extend the noise to fill the target audio
|
||||
noise = repeat_extend_audio(noise, resps.shape[0])
|
||||
# create the input prompt by merging the target audio with the noise
|
||||
proms = merge_audio( resps, noise, scale=[1, noise_scale], device=cfg.dataset.reencode_device )
|
||||
proms = merge_audio( resps, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
|
||||
# set the target to just be the noise if <sr>
|
||||
if task == "sr":
|
||||
resps = noise
|
||||
|
@ -956,7 +967,7 @@ class Dataset(_Dataset):
|
|||
# extend the noise to fill the target audio
|
||||
n = repeat_extend_audio(noise, p.shape[0])
|
||||
# merge the noise over the utterance
|
||||
return merge_audio(p, n, scale=[1, noise_scale], device=cfg.dataset.reencode_device)
|
||||
return merge_audio(p, n, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device)
|
||||
|
||||
# apply noise to all pieces
|
||||
pre_prom = noise_proms( pre_prom )
|
||||
|
@ -975,9 +986,9 @@ class Dataset(_Dataset):
|
|||
|
||||
# create new resp
|
||||
resps = concat_audio(
|
||||
([ pre_prom ] if pre_prom is not None else []) +
|
||||
*(([ pre_prom ] if pre_prom is not None else []) +
|
||||
[ edit_prom ] +
|
||||
([ post_prom ] if post_prom is not None else []),
|
||||
([ post_prom ] if post_prom is not None else [])),
|
||||
reencode=cfg.dataset.reencode_on_concat,
|
||||
device=cfg.dataset.reencode_device,
|
||||
)
|
||||
|
|
|
@ -880,7 +880,7 @@ class Base(nn.Module):
|
|||
|
||||
# Base-line TTS task
|
||||
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
|
||||
if task_type in ["tts", "tts-c", "ns", "sr"]:
|
||||
if f'<{task_type}>' in get_task_symmap():
|
||||
# insert the text prompt
|
||||
if text_list is not None:
|
||||
inputs[i].append( ( "text", text_list[i] ) )
|
||||
|
@ -932,6 +932,9 @@ class Base(nn.Module):
|
|||
elif resps_list is not None:
|
||||
# yes this could be encoded better
|
||||
inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) )
|
||||
else:
|
||||
|
||||
raise Exception(f'Unrecognized task: {task_type}')
|
||||
|
||||
return inputs
|
||||
|
||||
|
@ -943,7 +946,7 @@ class Base(nn.Module):
|
|||
# handles tasks where the prompt has task tokens injected in the middle
|
||||
def prompt_input_to_embedding( input, quant_level ):
|
||||
if isinstance(inputs, str):
|
||||
return self.tasks_emb( get_task_symmap( input ) ) if self.langs_emb is None else None
|
||||
return self.tasks_emb( get_task_symmap()[f'<{input}>'] ) if self.tasks_emb is None else None
|
||||
|
||||
# get RVQ level 0, or up to targetted RVQ level inference
|
||||
if self.version <= 4:
|
||||
|
@ -1001,9 +1004,8 @@ class Base(nn.Module):
|
|||
else:
|
||||
# should probably raise an exception so things aren't processed silently
|
||||
continue
|
||||
|
||||
batch.append(embedding)
|
||||
|
||||
|
||||
x_list.append( _join( batch, self.sep ) )
|
||||
|
||||
return x_list
|
||||
|
@ -1045,7 +1047,7 @@ class Base(nn.Module):
|
|||
# handles tasks where the prompt has task tokens injected in the middle
|
||||
def prompt_input_to_token( input, quant_level ):
|
||||
if isinstance(inputs, str):
|
||||
return get_task_symmap( input )
|
||||
return get_task_symmap()[f'<{input}>']
|
||||
|
||||
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
|
||||
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
|
||||
|
|
Loading…
Reference in New Issue
Block a user