re-introducing SpeechX tasks (need to validate them all; everything works with base tts anyway)

This commit is contained in:
mrq 2024-07-18 16:16:14 -05:00
parent c2b8035e74
commit 97e768601c
6 changed files with 208 additions and 149 deletions

View File

@ -871,6 +871,14 @@ class NaiveTokenizer:
"""
return {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '': 179, '': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, '': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, '': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, '': 220, 'eˈ': 221, 'ʍ': 222, '': 223, '': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228}
@cached_property
def _bos_token( self ):
return self.get_vocab()["<s>"]
@cached_property
def _eos_token( self ):
return self.get_vocab()["</s>"]
def encode( self, s ):
symmap = self.get_vocab()
phones = " ".join( list(s) )
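The remainder of encode() falls outside the shown context; as a hedged sketch only (the vocab lookup, unknown-symbol handling, and <s>/</s> wrapping below are assumptions about where it is headed, and the per-character lookup glosses over multi-character symbols like 'iː'):

import torch

def naive_encode_sketch(tokenizer, s: str) -> torch.Tensor:
    # look each symbol up in the naive vocab, dropping anything unknown,
    # and wrap the sequence with the <s>/</s> ids exposed above
    symmap = tokenizer.get_vocab()
    ids = [tokenizer._bos_token]
    ids += [symmap[c] for c in s if c in symmap]
    ids += [tokenizer._eos_token]
    return torch.tensor(ids, dtype=torch.uint8 if len(symmap) < 256 else torch.int16)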

View File

@ -541,6 +541,14 @@ class Dataset(_Dataset):
self.tone_symmap = self._get_tone_symmap()
self.task_symmap = self._get_task_symmap()
"""
self.empty_text = tokenize(" ")
if len(self.empty_text) == 4:
self.empty_text = self.empty_text[:1] + self.empty_text[1:2] + self.empty_text[-1:]
"""
self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
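For context, a small hedged sketch of how these pieces are consumed later in __getitem__ (the ids are illustrative placeholders; the real values come from cfg.tokenizer as above):

import torch

bos_id, space_id, eos_id = 1, 3, 2      # illustrative; unpacked from self.empty_text below
text_dtype = torch.uint8                # the vocab above stays under 256 entries

# tasks that train without a guided text prompt swap the transcript for just <s></s>
empty_prompt = torch.tensor([bos_id, eos_id], dtype=text_dtype)
# the space id is used when stitching two transcripts into one sequence
space = torch.tensor([space_id], dtype=text_dtype)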
@ -665,7 +673,45 @@ class Dataset(_Dataset):
choices = set(self.spkrs) - set(ignore)
return random.choice([*choices])
def sample_prompts(self, spkr_name, ignore):
def sample_utterance(self, spkr_name, ignore=[]):
choices = [*(set(self.paths_by_spkr_name[spkr_name]) - set(ignore))]
if len(choices) == 0:
return None, None, None
path = random.choice(choices)
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
if key not in cfg.hdf5:
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
metadata = cfg.hdf5[key].attrs
text = cfg.hdf5[key]["text"][:]
resps = cfg.hdf5[key]["audio"][:, :]
text = torch.from_numpy(text).to(self.text_dtype)
resps = torch.from_numpy(resps).to(torch.int16)
"""
lang = metadata["language"] if "language" in metadata else None
tone = metadata["tone"] if "tone" in metadata else None
"""
else:
resps, metadata = _load_quants(path, return_metadata=True)
text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
"""
lang = metadata["language"] if "language" in metadata else None
tone = metadata["tone"] if "tone" in metadata else None
"""
return path, text, resps
def sample_prompts(self, spkr_name, ignore, should_trim=True):
prom_list = []
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
@ -681,7 +727,7 @@ class Dataset(_Dataset):
"""
prom_length = 0
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if should_trim else 0
for _ in range(cfg.dataset.max_prompts):
path = random.choice(choices)
@ -715,6 +761,8 @@ class Dataset(_Dataset):
return prom
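Most of sample_prompts() sits outside the diff context; a condensed, hedged sketch of the accumulation loop it appears to implement (the loader callback, duration range, and exact cropping behaviour are assumptions):

import random
import torch

def sample_prompts_sketch(paths, load_quants, frames_per_second,
                          duration_range=(3.0, 6.0), max_prompts=3, should_trim=True):
    # pick a random prompt length in frames, or 0 to disable trimming entirely
    trim_length = int(random.uniform(*duration_range) * frames_per_second) if should_trim else 0
    prom_list, prom_length = [], 0
    for _ in range(max_prompts):
        qnt = load_quants(random.choice(paths))   # (frames, rvq levels) int16 codes
        prom_list.append(qnt)
        prom_length += qnt.shape[0]
        if trim_length and prom_length >= trim_length:
            break
    prom = torch.cat(prom_list)
    # crop down to the requested duration when trimming is enabled
    return prom[:trim_length] if trim_length else prom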
def __getitem__(self, index):
bos_id, space_id, eos_id = self.empty_text
if self.sampler_type == "group":
spkr_group = self.spkr_groups[index]
#spkr_group_id = self.spkr_group_symmap[spkr_group]
@ -740,82 +788,99 @@ class Dataset(_Dataset):
if key not in cfg.hdf5:
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
metadata = cfg.hdf5[key].attrs
text = cfg.hdf5[key]["text"][:]
resps = cfg.hdf5[key]["audio"][:, :]
text = torch.from_numpy(text).to(self.text_dtype)
resps = torch.from_numpy(resps).to(torch.int16)
lang = metadata["language"] if "language" in metadata else None
tone = metadata["tone"] if "tone" in metadata else None
else:
resps, metadata = _load_quants(path, return_metadata=True)
text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
lang = metadata["language"] if "language" in metadata else None
tone = metadata["tone"] if "tone" in metadata else None
if not lang:
lang = self.get_language(spkr_group)
if not tone:
tone = "neutral"
lang = torch.tensor([self.lang_symmap[lang]]).to(torch.uint8)
tone = torch.tensor([self.tone_symmap[tone]]).to(torch.uint8)
naive = True
# append additional prompts in an attempt to artificially increase lengths / offer new data
"""
# disabled because I haven't actually needed to use it myself, and I can't be assed to validate if it still works
# it probably is better to pad with silence instead of just stitching utterances and ruining things
"""
# append additional prompts in an attempt to artificially increase lengths / offer new data
if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
ignore_paths = []
for _ in range( cfg.dataset.max_resps - 1 ):
path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
ignore_paths.append(path)
if len(choices) > 0:
for _ in range( cfg.dataset.max_resps - 1 ):
sampled_path = random.choice(choices)
choices = [*(set(choices) - {sampled_path})]
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(sampled_path)
txt = cfg.hdf5[key]["text"][:]
qnt = cfg.hdf5[key]["audio"][:, :]
txt = np.array( txt )
txt = torch.from_numpy(txt).to(self.text_dtype)
qnt = torch.from_numpy(qnt).to(torch.int16)
else:
#txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
#txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
qnt, metadata = _load_quants(sampled_path, return_metadata=True)
txt = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
# <s>[original text] [new text]</s>
# removes the original text's </s>, includes a space, and remove the new text's <s>
# <s>[original text]</s><s>[new text]</s>
if naive:
text = torch.concat([ text, txt ])
# <s>[original text] [new text]</s>
# removes the original text's </s>, includes a space, and remove the new text's <s>
else:
text = torch.concat([ text[:-1], torch.tensor([self.phone_symmap[" "]]).to(torch.int16), txt[1:] ])
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
resps = torch.concat([ resps, qnt ])
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
resps = torch.concat([ resps, qnt ])
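As a hedged illustration of the "pad with silence instead of stitching code sequences" idea from the comment above: decode both utterances to waveforms, insert a short silence, and re-encode, rather than concatenating EnCodec sequences directly (decode/encode stand in for whatever codec helpers the project exposes; names, signatures, and the silence length are assumptions):

import torch

def append_with_silence_sketch(resps, qnt, decode, encode, sample_rate=24_000, silence_sec=0.5):
    # assumed: decode(codes) -> 1D float waveform, encode(wav) -> (frames, rvq levels) codes
    wav_a = decode(resps)
    wav_b = decode(qnt)
    silence = torch.zeros(int(sample_rate * silence_sec))
    # join at the waveform level so the codec never sees a discontinuous code sequence
    stitched = torch.cat([wav_a, silence, wav_b])
    return encode(stitched)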
"""
"""
task = "tts"
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
# Disabled until I swap over to a better method
"""
task = random.choice(self.tasks)
"""
resps = resps[:, :cfg.model.resp_levels]
proms = proms[:, :cfg.model.resp_levels]
"""
task = "tts" # random.choice(self.tasks)
# ensure a speaker has at least four utterances
# default to tts if not
if len(set(self.paths_by_spkr_name[spkr_name]) - {path}) < 4:
task = "tts"
noise_scale = 0.25
if task == "tts" or task == "tts-c":
trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second)
# demote if the target is too short
if task == "tts-c" and trim_length * 2 >= resps.shape[0]:
task = "tts"
# Base TTS (text + prompt => output)
if task == "tts":
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
# VALL-E Continuous (text + partial output => rest of output)
# (this could just be sampled as <text a><text b> + <audio a> => <audio b>, but I need to experiment with it)
elif task == "tts-c":
# trim a piece of the output response
if naive:
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
# VALL-E continuous
# ignore if target utterance is shorter than prompt duration
# to-do: actually do this for the AR only as I don't think the paper trained the NAR for this
if task == "tts-c":
proms = resps[:trim_length, :]
resps = resps[trim_length:, :]
proms = torch.cat( [self.get_task_token(task), proms] )
else:
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
path, txt, qnt = self.sample_utterance(spkr_name)
# <s>[original text]</s><s>[new text]</s>
if naive:
text = torch.concat([ text, txt ])
# <s>[original text] [new text]</s>
# removes the original text's </s>, includes a space, and remove the new text's <s>
else:
text = torch.concat([ text[:-1], torch.tensor([self.phone_symmap[" "]]).to(torch.int16), txt[1:] ])
# set prompt as initial response
proms = resps
# set target as newly sampled response
resps = qnt
# noise suppression || speech removal
elif task == "ns" or task == "sr":
# sample random noise
@ -827,20 +892,19 @@ class Dataset(_Dataset):
# set the target to just be the noise if <sr>
if task == "sr":
resps = noise
# prepend the task token
proms = torch.cat( [self.get_task_token(task), proms] )
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(self.text_dtype)
text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
# target speech extraction
elif task == "tse":
# sample a random, clean, utterance for the target speaker
clean_proms = self.sample_prompts(spkr_name, ignore=path)
# sample a random, clean utterance from a different speaker
other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
# overlay the random speaker over the target audio
smallest_size = min(resps.shape[0], other_proms.shape[0])
if other_proms.shape[0] == smallest_size:
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
@ -849,33 +913,30 @@ class Dataset(_Dataset):
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
# stitch together the promps
proms = torch.cat( [clean_proms, self.get_task_token(task), noisy_proms] )
# stitch together the proms
proms = [
clean_proms,
task,
noisy_proms,
]
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(self.text_dtype) # <s></s>
text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
# as I need to get a good clean point to trim into
# clean speech editing
elif task == "cse" or task == "nse":
choices = set(self.paths_by_spkr_name[spkr_name]) - {path}
sampled = random.sample([*choices], 4)
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
# as I need to get a good clean point to trim into
# instead we'll just sample a bunch of utterances
if cfg.dataset.use_hdf5:
texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ]
qnts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :]).to(torch.int16) for path in sampled ]
else:
texts = [ torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) for path in sampled ]
qnts = [ _load_quants(path) for path in sampled ]
samples = []
for _ in range( 4 ):
sampled = self.sample_utterance(spkr_name, ignore=[s[0] for s in samples])
samples.append( sampled )
# remove <s></s>
for i in range(len(texts)):
texts[i] = texts[i][1:-1]
pre_text, mid_text, post_text, edit_text = texts
pre_prom, mid_prom, post_prom, edit_prom = qnts
pre_text, mid_text, post_text, edit_text = [ s[1][1:-1] for s in samples ]
pre_prom, mid_prom, post_prom, edit_prom = [ s[2] for s in samples ]
# randomly drop out pre
if random.random() < 0.125:
@ -888,11 +949,11 @@ class Dataset(_Dataset):
# create new text
text = torch.cat(
[ torch.Tensor( [ 1 ] ).to(dtype=self.text_dtype) ] + # <s>
([ pre_text, torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype) ] if pre_text is not None else []) + # pre_text + space'
[ torch.Tensor( [ bos_id ] ).to(dtype=self.text_dtype) ] + # <s>
([ pre_text, torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype) ] if pre_text is not None else []) + # pre_text + 'space'
[ edit_text ] + # 'edit text'
([ torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype), post_text ] if post_text is not None else []) + # 'space' + edit_text
[ torch.Tensor( [ 2 ] ).to(dtype=self.text_dtype) ] # </s>
([ torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), post_text ] if post_text is not None else []) + # 'space' + post_text
[ torch.Tensor( [ eos_id ] ).to(dtype=self.text_dtype) ] # </s>
)
if task == "nse":
@ -916,17 +977,16 @@ class Dataset(_Dataset):
mid_prom = noise_proms( mid_prom )
post_prom = noise_proms( post_prom )
edit_prom = noise_proms( edit_prom )
else:
mid_prom = self.get_task_token("mask")
# create new proms
proms = torch.cat(
([ pre_prom ] if pre_prom is not None else []) +
[self.get_task_token("soe")] +
[ mid_prom ] + # is <mask> if task is CSE
[self.get_task_token("eoe")] +
([ post_prom ] if post_prom is not None else [])
)
# create new proms
proms = [
pre_prom,
"<soe>",
"<mask>" if task == "cse" else mid_prom,
"<eoe>",
post_prom,
]
# create new resp
resps = torch.cat(
([ pre_prom ] if pre_prom is not None else []) +
@ -935,35 +995,6 @@ class Dataset(_Dataset):
)
else:
raise Exception(f'Undefined task: {task}')
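To restate the new tts-c branch earlier in __getitem__ in one place (a hedged condensation, not additional behaviour): a second utterance from the same speaker is sampled, the original utterance becomes the audio prompt, the new one becomes the target, and the two transcripts are stitched into a single text sequence:

import torch

def tts_continuation_sketch(dataset, spkr_name, path, text, resps):
    # sample another utterance from the same speaker (ignoring the current path)
    new_path, new_text, new_resps = dataset.sample_utterance(spkr_name, ignore=[path])
    # naive stitch: <s>[original text]</s><s>[new text]</s>
    text = torch.concat([text, new_text])
    proms = resps        # original utterance serves as the prompt
    resps = new_resps    # newly sampled utterance becomes the target
    return text, proms, resps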
"""
"""
# emulate SVC
# takes in an utterance of the target speaker and a target utterance as a reference clip for the input prompt
# targets an utterance of the target speaker with the same tempo + pitch + etc as the reference clip
# NOTE: I do not have a clue how to go about this. I *could* dynamically generate clips through RVC here, but I imagine the penalty would be astronomical
# ahead-of-time dataset preparation of a shit ton of RVC clips might be the key.
# aside from that, I have no clue how to go about training this, as this is entirely a proof of concept task.
elif task == "svc":
# sample a random, clean utterance for the target speaker
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
# sample a reference clip from a different speaker
ref_proms = self.sample_rvc(self.sample_speakers(ignore=[spkr_name]))
#
resps =
# stitch together the proms
proms = torch.cat( [proms, self.get_task_token(task), ref_proms] )
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(self.text_dtype)
"""
# trim to fit to requested prom/resps levels
proms = [ p[:, :cfg.model.resp_levels] if isinstance(p, torch.Tensor) else p for p in proms ] if isinstance(proms, list) else proms[:, :cfg.model.resp_levels]
resps = resps[:, :cfg.model.resp_levels]
return dict(
index=index,
@ -972,6 +1003,7 @@ class Dataset(_Dataset):
spkr_id=spkr_id,
task=task,
lang=lang,
tone=tone,
text=text,
proms=proms,
resps=resps,
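One structural consequence of this commit worth spelling out: for the SpeechX tasks, proms is no longer necessarily a tensor but may be a list interleaving audio code tensors with bare task-token strings ("tse", "<soe>", "<mask>", "<eoe>"), which the model resolves later. A hedged sketch of what a tse sample could look like (shapes and values are illustrative only):

import torch

# illustrative (frames, rvq levels) code tensors
clean_proms = torch.zeros(150, 8, dtype=torch.int16)
noisy_proms = torch.zeros(300, 8, dtype=torch.int16)

sample = dict(
    task="tse",
    # audio tensors interleaved with a task-token string; the model's input
    # processing turns each piece into an embedding (or a mock loss target)
    proms=[clean_proms, "tse", noisy_proms],
    resps=torch.zeros(300, 8, dtype=torch.int16),
)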

View File

@ -94,6 +94,7 @@ class AR_NAR(Base):
proms_list: list[Tensor],
resps_list: list[Tensor] | None = None,
task_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
@ -116,6 +117,10 @@ class AR_NAR(Base):
):
device = text_list[0].device
batch_size = len(text_list)
# generate task list if not provided
if task_list is None:
task_list = [ "tts" for _ in range(batch_size) ]
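A tiny hedged sketch of what this defaulting amounts to (the helper is illustrative, not part of the commit): training batches can now pass the dataset's per-item task through, while any caller that omits task_list keeps the old all-"tts" behaviour:

def default_task_list(task_list, batch_size):
    # mirrors the guard above: fall back to plain "tts" for every batch item
    return ["tts" for _ in range(batch_size)] if task_list is None else task_list

assert default_task_list(None, 3) == ["tts", "tts", "tts"]
assert default_task_list(["ns", "tts"], 2) == ["ns", "tts"]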
# is training or NAR
if resps_list is not None:
@ -127,15 +132,8 @@ class AR_NAR(Base):
# is training
if training:
# to-do: make this YAML configurable
def sample_task():
return "tts"
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
# generate task list to train against
task_list = [ sample_task() for _ in range(batch_size) ]
# determines which RVQ level to target per batch
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
@ -164,12 +162,14 @@ class AR_NAR(Base):
"""
for i in range(batch_size):
# cap quant_level if it exceeds its corresponding resp/prom
if quant_levels[i] >= resps_list[i].shape[-1]:
quant_levels[i] = resps_list[i].shape[-1] - 1
# other tasks might have the prom be a list and this is just the easiest way to acknowledge that
if task_list[i] == "tts":
# cap quant_level if it exceeds its corresponding resp/prom
if quant_levels[i] >= resps_list[i].shape[-1]:
quant_levels[i] = resps_list[i].shape[-1] - 1
if quant_levels[i] >= proms_list[i].shape[-1]:
quant_levels[i] = proms_list[i].shape[-1] - 1
if quant_levels[i] >= proms_list[i].shape[-1]:
quant_levels[i] = proms_list[i].shape[-1] - 1
# only apply stop token for RVQ level 0
if quant_levels[i] > 0:
@ -261,7 +261,7 @@ class AR_NAR(Base):
stopped = torch.zeros(batch_size, device=device).bool()
stop_token = self.stop_token
task_list = [ "tts" for _ in range(batch_size) ]
state = None
mirostat = [

View File

@ -33,6 +33,9 @@ from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_to
from ..emb.qnt import encode_as_embedding
# yuck, kind of needed
from ..data import get_task_symmap
"""
from ..utils.pattern import DelayedPatternProvider, VALLEPattern
"""
@ -877,7 +880,7 @@ class Base(nn.Module):
# Base-line TTS task
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
if task_type == "tts":
if task_type in ["tts", "tts-c", "ns", "sr"]:
# insert the text prompt
if text_list is not None:
inputs[i].append( ( "text", text_list[i] ) )
@ -915,7 +918,6 @@ class Base(nn.Module):
# override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
quant_levels[i] = 0
inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
# insert input audio prompt
if proms_list is not None:
inputs[i].append( ( "prom", proms_list[i] ) )
@ -938,10 +940,22 @@ class Base(nn.Module):
inputs: list,
quant_levels: int | list[int] | Tensor | None = None
):
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_embedding( input, quant_level ):
if isinstance(input, str):
return self.tasks_emb( get_task_symmap( input ) ) if self.tasks_emb is not None else None
# get RVQ level 0, or up to targeted RVQ level inference
if self.version <= 4:
return self.proms_emb( input if quant_level == 0 else input[:, :quant_level] )
return self.proms_emb( input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], offset = 0 )
x_list = []
for batch_index, batch_input in enumerate(inputs):
batch = []
quant_level = quant_levels[batch_index] if quant_levels is not None else 0
task_type = "tts"
for name, input in batch_input:
# technically can provide a map for input_name => embedding, but some embeddings require additional processing
embedding = None
@ -950,7 +964,7 @@ class Base(nn.Module):
if name == "task":
# noop
# *maybe* inject a token for specifying task type
...
task_type = input
continue
elif name == "text":
embedding = self.text_emb( input )
@ -959,14 +973,8 @@ class Base(nn.Module):
elif name == "lang" and self.langs_emb is not None:
embedding = self.langs_emb( input )
elif name == "prom":
# get RVQ level 0, or up to targeted RVQ level inference
if self.version <= 4:
embedding = self.proms_emb( input if quant_level == 0 else input[:, :quant_level] )
else:
if quant_level == 0:
embedding = self.proms_emb( input if input.dim() == 1 else input[:, :1], offset = 0 )
else:
embedding = self.proms_emb( input if input.dim() == 1 else input[:, :quant_level], offset = 0 )
proms = [ input ] if isinstance(input, torch.Tensor) else input
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms ] )
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
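A self-contained, hedged restatement of what the prompt_input_to_embedding indirection buys: each prom entry can be either a code tensor or a task-token string, and the prompt embedding is the concatenation of the per-piece embeddings (the two callbacks stand in for the real proms/tasks embedding modules):

import torch

def embed_prompt_sketch(prom, embed_codes, embed_task):
    # prom may be a single (frames, levels) tensor, or a list mixing tensors
    # with task-token strings such as "tse", "<soe>", "<mask>", "<eoe>"
    pieces = [prom] if isinstance(prom, torch.Tensor) else prom
    embedded = [embed_task(p) if isinstance(p, str) else embed_codes(p) for p in pieces]
    return torch.cat(embedded)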
@ -1034,6 +1042,17 @@ class Base(nn.Module):
quant_levels: int | list[int] | Tensor | None = None,
):
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
if isinstance(input, str):
return get_task_symmap( input )
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
return torch.full_like(input[..., 0], self.ignore_index)
return input if input.dim() == 1 else input[:, quant_level]
# old, "naive" way, no loss factoring
if not self.config.loss_factors:
target_list = []
@ -1046,12 +1065,8 @@ class Base(nn.Module):
if name == "task":
task_list.append( input )
elif name == "prom":
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
target.append( torch.full_like(input[..., 0], self.ignore_index) )
# we *CAN* directly map to proms
else:
target.append( input if input.dim() == 1 else input[:, quant_level] )
proms = [ input ] if isinstance(input, torch.Tensor) else input
target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] ) )
elif name == "resp":
target.append( input if input.dim() == 1 else input[:, quant_level] )
elif name in ["text", "quant_level", "lang", "tone", "len"]:
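Likewise for the loss targets: a prom piece either contributes a run of ignore_index (when prompt embeddings can't be mapped back to discrete tokens, e.g. with summed audio embeddings) or its codes at the current quant level, while a task-token string maps straight to its symbol id. A hedged sketch (the ignore value and symmap are illustrative):

import torch

IGNORE_INDEX = -100  # illustrative; the model uses self.ignore_index

def prom_piece_to_target_sketch(piece, quant_level, task_ids, maskable=True):
    # task-token strings map directly to their symbol id
    if isinstance(piece, str):
        return torch.tensor([task_ids[piece]])
    # prom codes that can't be recovered from the embedding get masked out of the loss
    if maskable:
        return torch.full_like(piece[..., 0], IGNORE_INDEX)
    # otherwise target the codes at the current RVQ level directly
    return piece if piece.dim() == 1 else piece[:, quant_level]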
@ -1119,7 +1134,8 @@ class Base(nn.Module):
input = input if input.dim() == 1 else input[:, quant_level]
# select prom level
elif name == "prom":
input = input[:, quant_level]
proms = [ input ] if isinstance(input, torch.Tensor) else input
input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
# meta-input, no corresponding token at the moment
elif name == "task":
continue

View File

@ -92,6 +92,7 @@ class NAR(Base):
proms_list: list[Tensor],
resps_list: list[Tensor] | None = None,
task_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,

View File

@ -64,6 +64,8 @@ def train_feeder(engine, batch):
proms_list=batch["proms"],
resps_list=batch["resps"],
lang_list=batch["lang"],
tone_list=batch["tone"],
task_list=batch["task"],
training=True,
)
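For completeness, a hedged sketch of the batch keys train_feeder now forwards into the engine, matching the dict returned by Dataset.__getitem__ above (the collation of items into per-key lists is assumed):

# assumed shape of a collated batch: each key holds one entry per batch item
batch = {
    "text":  [],  # tokenized phoneme ids
    "proms": [],  # code tensors, or tensor / task-token-string lists for SpeechX tasks
    "resps": [],  # target code tensors
    "lang":  [],  # language ids
    "tone":  [],  # tone ids (newly forwarded by this commit)
    "task":  [],  # per-item task strings (newly forwarded by this commit)
}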