fixes for re-introducing SpeechX tasks (need to actually validate if these all do the right things)
parent bccbb77a1a
commit 83a0954f85
@@ -165,7 +165,8 @@ class Dataset:
     tasks_list: list[str] = field(default_factory=lambda: ["tts"])

     reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
-    reencode_device: str = "cuda" # "cpu" is slower but saves memory
+    reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
+    noise_scale: float = 0.25 # scaling noise value

     _frames_per_second: int = 0 # allows setting your own hint
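A minimal sketch (not part of the commit) of the failure mode behind the new default: DataLoader workers are forked on Linux, and CUDA cannot be re-initialized in a forked subprocess, so any GPU work inside __getitem__ raises the RuntimeError quoted above. The probe dataset and the commented spawn workaround below are assumptions for illustration only.

import torch
from torch.utils.data import DataLoader, Dataset

class _Probe(Dataset):
    def __len__(self):
        return 2
    def __getitem__(self, index):
        # touching CUDA here (e.g. re-encoding audio on "cuda") is what fails under 'fork';
        # keeping reencode_device="cpu" sidesteps it entirely
        return torch.zeros(1)

if __name__ == "__main__":
    # the error message's suggested alternative (an assumption, not what the commit does):
    # torch.multiprocessing.set_start_method("spawn", force=True)
    for _ in DataLoader(_Probe(), num_workers=2):
        pass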
@@ -544,6 +544,11 @@ class Dataset(_Dataset):
         # grab IDs for bos, space, and eos for easy input creation later
         self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
+
+        # have it fetch at training time if any is invalid, because the tokenizer obj might not have it easily fetchable ahead of itme
+        # encoding before parallelizing things causes things to whine
+        if self.empty_text[0] is None or self.empty_text[-1] is None:
+            self.empty_text = None

         # assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
         self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
@@ -756,6 +761,9 @@ class Dataset(_Dataset):
         return prom

     def __getitem__(self, index):
+        if self.empty_text is None:
+            self.empty_text = tokenize(" ")
+
         bos_id, space_id, eos_id = self.empty_text

         if self.sampler_type == "group":
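The two hunks above form a lazy-initialization pair: the ids are looked up eagerly in __init__, and if the tokenizer cannot provide them there, the list is nulled out and rebuilt from tokenize(" ") on first __getitem__ call, i.e. inside the worker process. A self-contained sketch of the same pattern, with stand-in ids in place of the real tokenizer:

class LazyTextIDs:
    def __init__(self):
        self.empty_text = None  # deferred: tokenizer ids not available yet

    def __getitem__(self, index):
        if self.empty_text is None:
            self.empty_text = [1, 29, 2]  # stand-in for tokenize(" "): [bos, space, eos]
        bos_id, space_id, eos_id = self.empty_text
        return bos_id, space_id, eos_id

print(LazyTextIDs()[0])  # (1, 29, 2)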
@@ -836,7 +844,10 @@ class Dataset(_Dataset):
             proms = proms[:, :cfg.model.resp_levels]
         """

-        task = "tts" # random.choice(self.tasks)
+        task = random.choice(self.tasks)
+
+        if f'<{task}>' not in self.task_symmap:
+            raise Exception(f'Task not defined: {task}')

         # Base TTS (text + prompt => output)
         if task == "tts":
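A short sketch of the selection-and-validation flow re-enabled here; the symmap contents and task list below are illustrative stand-ins for self.task_symmap and cfg.dataset.tasks_list, not the repo's actual values.

import random

task_symmap = { "<tts>": 0, "<tts-c>": 1, "<ns>": 2, "<sr>": 3 }  # illustrative only
tasks = ["tts", "ns", "sr"]                                       # stand-in for cfg.dataset.tasks_list

task = random.choice(tasks)
if f'<{task}>' not in task_symmap:
    raise Exception(f'Task not defined: {task}')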
@@ -874,7 +885,7 @@ class Dataset(_Dataset):
             # extend the noise to fill the target audio
             noise = repeat_extend_audio(noise, resps.shape[0])
             # create the input prompt by merging the target audio with the noise
-            proms = merge_audio( resps, noise, scale=[1, noise_scale], device=cfg.dataset.reencode_device )
+            proms = merge_audio( resps, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
             # set the target to just be the noise if <sr>
             if task == "sr":
                 resps = noise
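What scale=[1, cfg.dataset.noise_scale] amounts to, as a hedged waveform-level sketch: the utterance is kept at unit gain and the noise is attenuated to 0.25 before mixing. The repo's merge_audio works on codec codes (and may decode/re-encode around this step); mix_waveforms below is a hypothetical helper, not the repo's API.

import torch

def mix_waveforms(*waves: torch.Tensor, scale: list[float]) -> torch.Tensor:
    # scale each signal, then sum them over the length of the longest one
    length = max(w.shape[-1] for w in waves)
    mixed = torch.zeros(length)
    for wave, gain in zip(waves, scale):
        mixed[:wave.shape[-1]] += gain * wave
    return mixed

speech = torch.randn(16000)
noise = torch.randn(16000)
noisy = mix_waveforms(speech, noise, scale=[1, 0.25])  # noise_scale: float = 0.25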
@@ -956,7 +967,7 @@ class Dataset(_Dataset):
                 # extend the noise to fill the target audio
                 n = repeat_extend_audio(noise, p.shape[0])
                 # merge the noise over the utterance
-                return merge_audio(p, n, scale=[1, noise_scale], device=cfg.dataset.reencode_device)
+                return merge_audio(p, n, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device)

             # apply noise to all pieces
             pre_prom = noise_proms( pre_prom )
@@ -975,9 +986,9 @@ class Dataset(_Dataset):

             # create new resp
             resps = concat_audio(
-                ([ pre_prom ] if pre_prom is not None else []) +
+                *(([ pre_prom ] if pre_prom is not None else []) +
                 [ edit_prom ] +
-                ([ post_prom ] if post_prom is not None else []),
+                ([ post_prom ] if post_prom is not None else [])),
                 reencode=cfg.dataset.reencode_on_concat,
                 device=cfg.dataset.reencode_device,
             )
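The added "*" matters if concat_audio is declared variadically (an assumption about its signature); without unpacking, the whole list would arrive as a single positional argument. A self-contained illustration with a stand-in helper:

import torch

def concat_audio(*audios, reencode=False, device="cpu"):
    # stand-in for the repo's helper: naive concatenation along the time axis
    return torch.cat(audios, dim=0)

pieces = [torch.zeros(2, 8), torch.ones(3, 8)]  # e.g. pre_prom / edit_prom / post_prom
out = concat_audio(*pieces)                     # unpacked: each tensor is its own argument
assert out.shape == (5, 8)
# concat_audio(pieces) would instead pass one list, which torch.cat rejects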
@@ -880,7 +880,7 @@ class Base(nn.Module):

             # Base-line TTS task
             # Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
-            if task_type in ["tts", "tts-c", "ns", "sr"]:
+            if f'<{task_type}>' in get_task_symmap():
                 # insert the text prompt
                 if text_list is not None:
                     inputs[i].append( ( "text", text_list[i] ) )
@@ -932,6 +932,9 @@ class Base(nn.Module):
                 elif resps_list is not None:
                     # yes this could be encoded better
                     inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) )
+
+            else:
+                raise Exception(f'Unrecognized task: {task_type}')

         return inputs

@@ -943,7 +946,7 @@ class Base(nn.Module):
         # handles tasks where the prompt has task tokens injected in the middle
         def prompt_input_to_embedding( input, quant_level ):
             if isinstance(inputs, str):
-                return self.tasks_emb( get_task_symmap( input ) ) if self.langs_emb is None else None
+                return self.tasks_emb( get_task_symmap()[f'<{input}>'] ) if self.tasks_emb is None else None

             # get RVQ level 0, or up to targetted RVQ level inference
             if self.version <= 4:
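A hedged sketch of the task-token embedding lookup the new line performs; the symmap and nn.Embedding below are stand-ins for get_task_symmap() and self.tasks_emb. Note the sketch only runs the lookup when the embedding table exists, whereas the line above guards on "is None", which may be one of the things the commit message says still needs validating.

import torch
import torch.nn as nn

task_symmap = { "<tts>": 0, "<ns>": 1, "<sr>": 2 }  # stand-in for get_task_symmap()
tasks_emb = nn.Embedding(len(task_symmap), 16)      # stand-in for self.tasks_emb

def prompt_task_to_embedding(task: str):
    if tasks_emb is not None:
        return tasks_emb(torch.tensor(task_symmap[f'<{task}>']))
    return None

print(prompt_task_to_embedding("ns").shape)  # torch.Size([16])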
@@ -1001,9 +1004,8 @@ class Base(nn.Module):
                 else:
                     # should probably raise an exception so things aren't processed silently
                     continue

                 batch.append(embedding)

             x_list.append( _join( batch, self.sep ) )

         return x_list
@@ -1045,7 +1047,7 @@ class Base(nn.Module):
         # handles tasks where the prompt has task tokens injected in the middle
         def prompt_input_to_token( input, quant_level ):
             if isinstance(inputs, str):
-                return get_task_symmap( input )
+                return get_task_symmap()[f'<{input}>']

             # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
             if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):