re-introducing SpeechX tasks (need to validate them all, everything works with base tts anyways)
This commit is contained in:
parent
c2b8035e74
commit
97e768601c
|
@ -871,6 +871,14 @@ class NaiveTokenizer:
|
||||||
"""
|
"""
|
||||||
return {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, 'ᵝ': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, 'oˌ': 220, 'eˈ': 221, 'ʍ': 222, 'eˌ': 223, 'uˌ': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228}
|
return {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, 'ᵝ': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, 'oˌ': 220, 'eˈ': 221, 'ʍ': 222, 'eˌ': 223, 'uˌ': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _bos_token( self ):
|
||||||
|
return self.get_vocab()["<s>"]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _eos_token( self ):
|
||||||
|
return self.get_vocab()["</s>"]
|
||||||
|
|
||||||
def encode( self, s ):
|
def encode( self, s ):
|
||||||
symmap = self.get_vocab()
|
symmap = self.get_vocab()
|
||||||
phones = " ".join( list(s) )
|
phones = " ".join( list(s) )
|
||||||
|
|
254
vall_e/data.py
254
vall_e/data.py
|
@ -541,6 +541,14 @@ class Dataset(_Dataset):
|
||||||
self.tone_symmap = self._get_tone_symmap()
|
self.tone_symmap = self._get_tone_symmap()
|
||||||
self.task_symmap = self._get_task_symmap()
|
self.task_symmap = self._get_task_symmap()
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.empty_text = tokenize(" ")
|
||||||
|
if len(self.empty_text) == 4:
|
||||||
|
self.empty_text = self.empty_text[:1] + self.empty_text[1:2] + self.empty_text[-1:]
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
|
||||||
|
|
||||||
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
|
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
|
||||||
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
|
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
|
||||||
|
|
||||||
|
@ -665,7 +673,45 @@ class Dataset(_Dataset):
|
||||||
choices = set(self.spkrs) - set(ignore)
|
choices = set(self.spkrs) - set(ignore)
|
||||||
return random.choice([*choices])
|
return random.choice([*choices])
|
||||||
|
|
||||||
def sample_prompts(self, spkr_name, ignore):
|
def sample_utterance(self, spkr_name, ignore=[]):
|
||||||
|
choices = [*(set(self.paths_by_spkr_name[spkr_name]) - set(ignore))]
|
||||||
|
|
||||||
|
if len(choices) == 0:
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
path = random.choice(choices)
|
||||||
|
|
||||||
|
if cfg.dataset.use_hdf5:
|
||||||
|
key = _get_hdf5_path(path)
|
||||||
|
|
||||||
|
if key not in cfg.hdf5:
|
||||||
|
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
|
||||||
|
|
||||||
|
metadata = cfg.hdf5[key].attrs()
|
||||||
|
|
||||||
|
text = cfg.hdf5[key]["text"][:]
|
||||||
|
resps = cfg.hdf5[key]["audio"][:, :]
|
||||||
|
|
||||||
|
text = torch.from_numpy(text).to(self.text_dtype)
|
||||||
|
resps = torch.from_numpy(resps).to(torch.int16)
|
||||||
|
|
||||||
|
"""
|
||||||
|
lang = metadata["language"] if "language" in metadata else None
|
||||||
|
tone = metadata["tone"] if "tone" in metadata else None
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
resps, metadata = _load_quants(path, return_metadata=True)
|
||||||
|
text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
|
||||||
|
|
||||||
|
"""
|
||||||
|
lang = metadata["language"] if "language" in metadata else None
|
||||||
|
tone = metadata["tone"] if "tone" in metadata else None
|
||||||
|
"""
|
||||||
|
|
||||||
|
return path, text, resps
|
||||||
|
|
||||||
|
|
||||||
|
def sample_prompts(self, spkr_name, ignore, should_trim=True):
|
||||||
prom_list = []
|
prom_list = []
|
||||||
|
|
||||||
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
|
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
|
||||||
|
@ -681,7 +727,7 @@ class Dataset(_Dataset):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
prom_length = 0
|
prom_length = 0
|
||||||
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0
|
||||||
|
|
||||||
for _ in range(cfg.dataset.max_prompts):
|
for _ in range(cfg.dataset.max_prompts):
|
||||||
path = random.choice(choices)
|
path = random.choice(choices)
|
||||||
|
@ -715,6 +761,8 @@ class Dataset(_Dataset):
|
||||||
return prom
|
return prom
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
bos_id, space_id, eos_id = self.empty_text
|
||||||
|
|
||||||
if self.sampler_type == "group":
|
if self.sampler_type == "group":
|
||||||
spkr_group = self.spkr_groups[index]
|
spkr_group = self.spkr_groups[index]
|
||||||
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||||
|
@ -740,45 +788,50 @@ class Dataset(_Dataset):
|
||||||
if key not in cfg.hdf5:
|
if key not in cfg.hdf5:
|
||||||
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
|
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
|
||||||
|
|
||||||
|
metadata = cfg.hdf5[key].attrs()
|
||||||
|
|
||||||
text = cfg.hdf5[key]["text"][:]
|
text = cfg.hdf5[key]["text"][:]
|
||||||
resps = cfg.hdf5[key]["audio"][:, :]
|
resps = cfg.hdf5[key]["audio"][:, :]
|
||||||
|
|
||||||
text = torch.from_numpy(text).to(self.text_dtype)
|
text = torch.from_numpy(text).to(self.text_dtype)
|
||||||
resps = torch.from_numpy(resps).to(torch.int16)
|
resps = torch.from_numpy(resps).to(torch.int16)
|
||||||
|
|
||||||
|
lang = metadata["language"] if "language" in metadata else None
|
||||||
|
tone = metadata["tone"] if "tone" in metadata else None
|
||||||
else:
|
else:
|
||||||
resps, metadata = _load_quants(path, return_metadata=True)
|
resps, metadata = _load_quants(path, return_metadata=True)
|
||||||
text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
|
text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
|
||||||
|
|
||||||
lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
|
lang = metadata["language"] if "language" in metadata else None
|
||||||
|
tone = metadata["tone"] if "tone" in metadata else None
|
||||||
|
|
||||||
|
if not lang:
|
||||||
|
lang = self.get_language(spkr_group)
|
||||||
|
|
||||||
|
if not tone:
|
||||||
|
tone = "neutral"
|
||||||
|
|
||||||
|
lang = torch.tensor([self.lang_symmap[lang]]).to(torch.uint8)
|
||||||
|
tone = torch.tensor([self.tone_symmap[tone]]).to(torch.uint8)
|
||||||
|
|
||||||
|
naive = True
|
||||||
|
|
||||||
# append additional prompts in an attempt to artifically increase lengths / offer new data
|
|
||||||
"""
|
|
||||||
# disabled because I haven't actually needed to use it myself, and I can't be assed to validate if it still works
|
# disabled because I haven't actually needed to use it myself, and I can't be assed to validate if it still works
|
||||||
# it probably is better to pad with silence instead of just stitching utterances and ruining things
|
# it probably is better to pad with silence instead of just stitching utterances and ruining things
|
||||||
|
"""
|
||||||
|
# append additional prompts in an attempt to artifically increase lengths / offer new data
|
||||||
if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
|
if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
|
||||||
choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
|
ignore_paths = []
|
||||||
|
|
||||||
if len(choices) > 0:
|
|
||||||
for _ in range( cfg.dataset.max_resps - 1 ):
|
for _ in range( cfg.dataset.max_resps - 1 ):
|
||||||
sampled_path = random.choice(choices)
|
path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
|
||||||
choices = [*(set(choices) - {sampled_path})]
|
ignore_paths.append(path)
|
||||||
if cfg.dataset.use_hdf5:
|
|
||||||
key = _get_hdf5_path(sampled_path)
|
|
||||||
txt = cfg.hdf5[key]["text"][:]
|
|
||||||
qnt = cfg.hdf5[key]["audio"][:, :]
|
|
||||||
|
|
||||||
txt = np.array( txt )
|
|
||||||
|
|
||||||
txt = torch.from_numpy(txt).to(self.text_dtype)
|
|
||||||
qnt = torch.from_numpy(qnt).to(torch.int16)
|
|
||||||
else:
|
|
||||||
#txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
|
|
||||||
#txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
|
|
||||||
qnt, metadata = _load_quants(sampled_path, return_metadata=True)
|
|
||||||
txt = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
|
|
||||||
|
|
||||||
|
# <s>[original text]</s><s>[new text]</s>
|
||||||
|
if naive:
|
||||||
|
text = torch.concat([ text, txt ])
|
||||||
# <s>[original text] [new text]</s>
|
# <s>[original text] [new text]</s>
|
||||||
# removes the original text's </s>, includes a space, and remove the new text's <s>
|
# removes the original text's </s>, includes a space, and remove the new text's <s>
|
||||||
|
else:
|
||||||
text = torch.concat([ text[:-1], torch.tensor([self.phone_symmap[" "]]).to(torch.int16), txt[1:] ])
|
text = torch.concat([ text[:-1], torch.tensor([self.phone_symmap[" "]]).to(torch.int16), txt[1:] ])
|
||||||
|
|
||||||
# might be better to decode => concat waveforms with silence in between => reencode
|
# might be better to decode => concat waveforms with silence in between => reencode
|
||||||
|
@ -786,36 +839,48 @@ class Dataset(_Dataset):
|
||||||
resps = torch.concat([ resps, qnt ])
|
resps = torch.concat([ resps, qnt ])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
task = "tts"
|
task = "tts"
|
||||||
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
||||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||||
|
|
||||||
|
|
||||||
# Disabled until I swap over to a better method
|
|
||||||
"""
|
"""
|
||||||
task = random.choice(self.tasks)
|
|
||||||
|
|
||||||
# ensure a speaker has at least four utterances
|
"""
|
||||||
# default to tts if not
|
resps = resps[:, :cfg.model.resp_levels]
|
||||||
if len(set(self.paths_by_spkr_name[spkr_name]) - {path}) < 4:
|
proms = proms[:, :cfg.model.resp_levels]
|
||||||
task = "tts"
|
"""
|
||||||
noise_scale = 0.25
|
|
||||||
if task == "tts" or task == "tts-c":
|
task = "tts" # random.choice(self.tasks)
|
||||||
trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second)
|
|
||||||
# demote if the target is too short
|
# Base TTS (text + prompt => output)
|
||||||
if task == "tts-c" and trim_length * 2 >= resps.shape[0]:
|
if task == "tts":
|
||||||
task = "tts"
|
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||||
|
|
||||||
|
# VALL-E Continuous (text + partial output => rest of output)
|
||||||
|
# (this could just be sampled as <text a><text b> + <audio a> => <audio b>, but I need to experiment with it)
|
||||||
|
elif task == "tts-c":
|
||||||
|
# trim a piece of the output response
|
||||||
|
if naive:
|
||||||
|
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
||||||
|
|
||||||
# VALL-E continuous
|
|
||||||
# ignore if target utterance is shorter than prompt duration
|
|
||||||
# to-do: actually do this for the AR only as I don't think the paper trained the NAR for this
|
|
||||||
if task == "tts-c":
|
|
||||||
proms = resps[:trim_length, :]
|
proms = resps[:trim_length, :]
|
||||||
resps = resps[trim_length:, :]
|
resps = resps[trim_length:, :]
|
||||||
|
|
||||||
proms = torch.cat( [self.get_task_token(task), proms] )
|
|
||||||
else:
|
else:
|
||||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
path, txt, qnt = self.sample_utterance(spkr_name)
|
||||||
|
|
||||||
|
# <s>[original text]</s><s>[new text]</s>
|
||||||
|
if naive:
|
||||||
|
text = torch.concat([ text, txt ])
|
||||||
|
# <s>[original text] [new text]</s>
|
||||||
|
# removes the original text's </s>, includes a space, and remove the new text's <s>
|
||||||
|
else:
|
||||||
|
text = torch.concat([ text[:-1], torch.tensor([self.phone_symmap[" "]]).to(torch.int16), txt[1:] ])
|
||||||
|
|
||||||
|
# set prompt as initial response
|
||||||
|
proms = resps
|
||||||
|
# set target as newly sampled response
|
||||||
|
resps = qnt
|
||||||
|
|
||||||
# noise suppression || speech removal
|
# noise suppression || speech removal
|
||||||
elif task == "ns" or task == "sr":
|
elif task == "ns" or task == "sr":
|
||||||
# sample random noise
|
# sample random noise
|
||||||
|
@ -827,20 +892,19 @@ class Dataset(_Dataset):
|
||||||
# set the target to just be the noise if <sr>
|
# set the target to just be the noise if <sr>
|
||||||
if task == "sr":
|
if task == "sr":
|
||||||
resps = noise
|
resps = noise
|
||||||
# prepend the task token
|
|
||||||
proms = torch.cat( [self.get_task_token(task), proms] )
|
|
||||||
|
|
||||||
# set the text prompt to empty to train without a guided text prompt
|
# set the text prompt to empty to train without a guided text prompt
|
||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
text = torch.tensor([1, 2]).to(self.text_dtype)
|
text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
|
||||||
|
|
||||||
# target speech extraction
|
# target speech extraction
|
||||||
elif task == "tse":
|
elif task == "tse":
|
||||||
# sample a random, clean, utterance for the target speaker
|
# sample a random, clean, utterance for the target speaker
|
||||||
clean_proms = self.sample_prompts(spkr_name, ignore=path)
|
clean_proms = self.sample_prompts(spkr_name, ignore=path)
|
||||||
# sample a random, clean utterance from a different speaker
|
# sample a random, clean utterance from a different speaker
|
||||||
other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
|
other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
|
||||||
# overlay the random speaker over the target audio
|
|
||||||
|
|
||||||
|
# overlay the random speaker over the target audio
|
||||||
smallest_size = min(resps.shape[0], other_proms.shape[0])
|
smallest_size = min(resps.shape[0], other_proms.shape[0])
|
||||||
if other_proms.shape[0] == smallest_size:
|
if other_proms.shape[0] == smallest_size:
|
||||||
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
|
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
|
||||||
|
@ -849,33 +913,30 @@ class Dataset(_Dataset):
|
||||||
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
|
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
|
||||||
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
|
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
|
||||||
|
|
||||||
# stitch together the promps
|
# stitch together the proms
|
||||||
proms = torch.cat( [clean_proms, self.get_task_token(task), noisy_proms] )
|
proms = [
|
||||||
|
clean_proms,
|
||||||
|
task,
|
||||||
|
noisy_proms,
|
||||||
|
]
|
||||||
|
|
||||||
# set the text prompt to empty to train without a guided text prompt
|
# set the text prompt to empty to train without a guided text prompt
|
||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
text = torch.tensor([1, 2]).to(self.text_dtype) # <s></s>
|
text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
|
||||||
|
|
||||||
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
|
|
||||||
# as I need to get a good clean point to trim into
|
|
||||||
# clean speech editing
|
# clean speech editing
|
||||||
elif task == "cse" or task == "nse":
|
elif task == "cse" or task == "nse":
|
||||||
choices = set(self.paths_by_spkr_name[spkr_name]) - {path}
|
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
|
||||||
sampled = random.sample([*choices], 4)
|
# as I need to get a good clean point to trim into
|
||||||
|
# instead we'll just sample a bunch of utterances
|
||||||
|
|
||||||
if cfg.dataset.use_hdf5:
|
samples = []
|
||||||
texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ]
|
for _ in range( 4 ):
|
||||||
qnts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :]).to(torch.int16) for path in sampled ]
|
sampled = self.sample_utterance(spkr_name, ignore=[s[0] for s in samples])
|
||||||
else:
|
samples.append( sampled )
|
||||||
texts = [ torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) for path in sampled ]
|
|
||||||
qnts = [ _load_quants(path) for path in sampled ]
|
|
||||||
|
|
||||||
# remove <s></s>
|
pre_text, mid_text, post_text, edit_text = [ s[1][1:-1] for s in samples ]
|
||||||
for i in range(len(texts)):
|
pre_prom, mid_prom, post_prom, edit_prom = [ s[2] for s in samples ]
|
||||||
texts[i] = texts[i][1:-1]
|
|
||||||
|
|
||||||
pre_text, mid_text, post_text, edit_text = texts
|
|
||||||
pre_prom, mid_prom, post_prom, edit_prom = qnts
|
|
||||||
|
|
||||||
# randomly drop out pre
|
# randomly drop out pre
|
||||||
if random.random() < 0.125:
|
if random.random() < 0.125:
|
||||||
|
@ -888,11 +949,11 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
# create new text
|
# create new text
|
||||||
text = torch.cat(
|
text = torch.cat(
|
||||||
[ torch.Tensor( [ 1 ] ).to(dtype=self.text_dtype) ] + # <s>
|
[ torch.Tensor( [ bos_id ] ).to(dtype=self.text_dtype) ] + # <s>
|
||||||
([ pre_text, torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype) ] if pre_text is not None else []) + # pre_text + space'
|
([ pre_text, torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype) ] if pre_text is not None else []) + # pre_text + space'
|
||||||
[ edit_text ] + # 'edit text'
|
[ edit_text ] + # 'edit text'
|
||||||
([ torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype), post_text ] if post_text is not None else []) + # 'space' + edit_text
|
([ torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), post_text ] if post_text is not None else []) + # 'space' + edit_text
|
||||||
[ torch.Tensor( [ 2 ] ).to(dtype=self.text_dtype) ] # </s>
|
[ torch.Tensor( [ eos_id ] ).to(dtype=self.text_dtype) ] # </s>
|
||||||
)
|
)
|
||||||
|
|
||||||
if task == "nse":
|
if task == "nse":
|
||||||
|
@ -916,17 +977,16 @@ class Dataset(_Dataset):
|
||||||
mid_prom = noise_proms( mid_prom )
|
mid_prom = noise_proms( mid_prom )
|
||||||
post_prom = noise_proms( post_prom )
|
post_prom = noise_proms( post_prom )
|
||||||
edit_prom = noise_proms( edit_prom )
|
edit_prom = noise_proms( edit_prom )
|
||||||
else:
|
|
||||||
mid_prom = self.get_task_token("mask")
|
|
||||||
|
|
||||||
# create new proms
|
# create new prom
|
||||||
proms = torch.cat(
|
proms = [
|
||||||
([ pre_prom ] if pre_prom is not None else []) +
|
pre_prom,
|
||||||
[self.get_task_token("soe")] +
|
"<soe>",
|
||||||
[ mid_prom ] + # is <mask> if task is CSE
|
"<mask>" if task == "cse" else mid_prom,
|
||||||
[self.get_task_token("eoe")] +
|
"<eoe>",
|
||||||
([ post_prom ] if post_prom is not None else [])
|
post_prom,
|
||||||
)
|
]
|
||||||
|
|
||||||
# create new resp
|
# create new resp
|
||||||
resps = torch.cat(
|
resps = torch.cat(
|
||||||
([ pre_prom ] if pre_prom is not None else []) +
|
([ pre_prom ] if pre_prom is not None else []) +
|
||||||
|
@ -935,35 +995,6 @@ class Dataset(_Dataset):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception(f'Undefined task: {task}')
|
raise Exception(f'Undefined task: {task}')
|
||||||
"""
|
|
||||||
|
|
||||||
"""
|
|
||||||
# emulate SVC
|
|
||||||
# takes in an utterance of the target speaker, a target utterenace as a reference clip as the input prompt
|
|
||||||
# targets an utterance of the target speaker with the same tempo + pitch + etc as the reference clip
|
|
||||||
|
|
||||||
# NOTE: I do not have a clue how to go about this. I *could* dynamically generate clips through RVC here, but I imagine the penalty would be astronomical
|
|
||||||
# ahead-of-time dataset preparation of a shit ton of RVC clips might be the key.
|
|
||||||
# aside from that, I have no clue how to go about training this, as this is entirely a proof of concept task.
|
|
||||||
elif task == "svc":
|
|
||||||
# sample a random, clean utterance for the target speaker
|
|
||||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
|
||||||
# sample a reference clip from a different speaker
|
|
||||||
ref_proms = self.sample_rvc(self.sample_speakers(ignore=[spkr_name]))
|
|
||||||
#
|
|
||||||
resps =
|
|
||||||
# stitch together the promps
|
|
||||||
proms = torch.cat( [proms, self.get_task_token(task), ref_proms] )
|
|
||||||
|
|
||||||
# set the text prompt to empty to train without a guided text prompt
|
|
||||||
if random.random() < 0.5:
|
|
||||||
text = torch.tensor([1, 2]).to(self.text_dtype)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# trim to fit to requested prom/resps levels
|
|
||||||
proms = proms[:, :cfg.model.resp_levels]
|
|
||||||
resps = resps[:, :cfg.model.resp_levels]
|
|
||||||
|
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
index=index,
|
index=index,
|
||||||
|
@ -972,6 +1003,7 @@ class Dataset(_Dataset):
|
||||||
spkr_id=spkr_id,
|
spkr_id=spkr_id,
|
||||||
task=task,
|
task=task,
|
||||||
lang=lang,
|
lang=lang,
|
||||||
|
tone=tone,
|
||||||
text=text,
|
text=text,
|
||||||
proms=proms,
|
proms=proms,
|
||||||
resps=resps,
|
resps=resps,
|
||||||
|
|
|
@ -94,6 +94,7 @@ class AR_NAR(Base):
|
||||||
proms_list: list[Tensor],
|
proms_list: list[Tensor],
|
||||||
resps_list: list[Tensor] | None = None,
|
resps_list: list[Tensor] | None = None,
|
||||||
|
|
||||||
|
task_list: list[Tensor] | None = None,
|
||||||
lang_list: list[Tensor] | None = None,
|
lang_list: list[Tensor] | None = None,
|
||||||
tone_list: list[Tensor] | None = None,
|
tone_list: list[Tensor] | None = None,
|
||||||
len_list: list[Tensor] | None = None,
|
len_list: list[Tensor] | None = None,
|
||||||
|
@ -117,6 +118,10 @@ class AR_NAR(Base):
|
||||||
device = text_list[0].device
|
device = text_list[0].device
|
||||||
batch_size = len(text_list)
|
batch_size = len(text_list)
|
||||||
|
|
||||||
|
# generate task list if not provided
|
||||||
|
if task_list is None:
|
||||||
|
task_list = [ "tts" for _ in range(batch_size) ]
|
||||||
|
|
||||||
# is training or NAR
|
# is training or NAR
|
||||||
if resps_list is not None:
|
if resps_list is not None:
|
||||||
n_levels_set = {r.shape[-1] for r in resps_list}
|
n_levels_set = {r.shape[-1] for r in resps_list}
|
||||||
|
@ -127,15 +132,8 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
# is training
|
# is training
|
||||||
if training:
|
if training:
|
||||||
# to-do: make this YAML configurable
|
|
||||||
def sample_task():
|
|
||||||
return "tts"
|
|
||||||
|
|
||||||
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
||||||
|
|
||||||
# generate task list to train against
|
|
||||||
task_list = [ sample_task() for _ in range(batch_size) ]
|
|
||||||
|
|
||||||
# determines which RVQ level to target per batch
|
# determines which RVQ level to target per batch
|
||||||
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
|
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
|
||||||
|
|
||||||
|
@ -164,6 +162,8 @@ class AR_NAR(Base):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
|
# other tasks might have the prom be a list and this is just the easiest way to acknowledge that
|
||||||
|
if task_list[i] == "tts":
|
||||||
# cap quant_level if it exceeds its corresponding resp/prom
|
# cap quant_level if it exceeds its corresponding resp/prom
|
||||||
if quant_levels[i] >= resps_list[i].shape[-1]:
|
if quant_levels[i] >= resps_list[i].shape[-1]:
|
||||||
quant_levels[i] = resps_list[i].shape[-1] - 1
|
quant_levels[i] = resps_list[i].shape[-1] - 1
|
||||||
|
@ -261,7 +261,7 @@ class AR_NAR(Base):
|
||||||
stopped = torch.zeros(batch_size, device=device).bool()
|
stopped = torch.zeros(batch_size, device=device).bool()
|
||||||
|
|
||||||
stop_token = self.stop_token
|
stop_token = self.stop_token
|
||||||
task_list = [ "tts" for _ in range(batch_size) ]
|
|
||||||
|
|
||||||
state = None
|
state = None
|
||||||
mirostat = [
|
mirostat = [
|
||||||
|
|
|
@ -33,6 +33,9 @@ from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_to
|
||||||
|
|
||||||
from ..emb.qnt import encode_as_embedding
|
from ..emb.qnt import encode_as_embedding
|
||||||
|
|
||||||
|
# yuck, kind of needed
|
||||||
|
from ..data import get_task_symmap
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from ..utils.pattern import DelayedPatternProvider, VALLEPattern
|
from ..utils.pattern import DelayedPatternProvider, VALLEPattern
|
||||||
"""
|
"""
|
||||||
|
@ -877,7 +880,7 @@ class Base(nn.Module):
|
||||||
|
|
||||||
# Base-line TTS task
|
# Base-line TTS task
|
||||||
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
|
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
|
||||||
if task_type == "tts":
|
if task_type in ["tts", "tts-c", "ns", "sr"]:
|
||||||
# insert the text prompt
|
# insert the text prompt
|
||||||
if text_list is not None:
|
if text_list is not None:
|
||||||
inputs[i].append( ( "text", text_list[i] ) )
|
inputs[i].append( ( "text", text_list[i] ) )
|
||||||
|
@ -915,7 +918,6 @@ class Base(nn.Module):
|
||||||
# override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
|
# override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
|
||||||
quant_levels[i] = 0
|
quant_levels[i] = 0
|
||||||
inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
|
inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
|
||||||
|
|
||||||
# insert input audio prompt
|
# insert input audio prompt
|
||||||
if proms_list is not None:
|
if proms_list is not None:
|
||||||
inputs[i].append( ( "prom", proms_list[i] ) )
|
inputs[i].append( ( "prom", proms_list[i] ) )
|
||||||
|
@ -938,10 +940,22 @@ class Base(nn.Module):
|
||||||
inputs: list,
|
inputs: list,
|
||||||
quant_levels: int | list[int] | Tensor | None = None
|
quant_levels: int | list[int] | Tensor | None = None
|
||||||
):
|
):
|
||||||
|
# handles tasks where the prompt has task tokens injected in the middle
|
||||||
|
def prompt_input_to_embedding( input, quant_level ):
|
||||||
|
if isinstance(inputs, str):
|
||||||
|
return self.tasks_emb( get_task_symmap( input ) ) if self.langs_emb is None else None
|
||||||
|
|
||||||
|
# get RVQ level 0, or up to targetted RVQ level inference
|
||||||
|
if self.version <= 4:
|
||||||
|
return self.proms_emb( input if quant_level == 0 else input[:, :quant_level] )
|
||||||
|
|
||||||
|
return self.proms_emb( input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], offset = 0 )
|
||||||
|
|
||||||
x_list = []
|
x_list = []
|
||||||
for batch_index, batch_input in enumerate(inputs):
|
for batch_index, batch_input in enumerate(inputs):
|
||||||
batch = []
|
batch = []
|
||||||
quant_level = quant_levels[batch_index] if quant_levels is not None else 0
|
quant_level = quant_levels[batch_index] if quant_levels is not None else 0
|
||||||
|
task_type = "tts"
|
||||||
for name, input in batch_input:
|
for name, input in batch_input:
|
||||||
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
|
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
|
||||||
embedding = None
|
embedding = None
|
||||||
|
@ -950,7 +964,7 @@ class Base(nn.Module):
|
||||||
if name == "task":
|
if name == "task":
|
||||||
# noop
|
# noop
|
||||||
# *maybe* inject a token for specifying task type
|
# *maybe* inject a token for specifying task type
|
||||||
...
|
task_type = input
|
||||||
continue
|
continue
|
||||||
elif name == "text":
|
elif name == "text":
|
||||||
embedding = self.text_emb( input )
|
embedding = self.text_emb( input )
|
||||||
|
@ -959,14 +973,8 @@ class Base(nn.Module):
|
||||||
elif name == "lang" and self.langs_emb is not None:
|
elif name == "lang" and self.langs_emb is not None:
|
||||||
embedding = self.langs_emb( input )
|
embedding = self.langs_emb( input )
|
||||||
elif name == "prom":
|
elif name == "prom":
|
||||||
# get RVQ level 0, or up to targetted RVQ level inference
|
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||||
if self.version <= 4:
|
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms ] )
|
||||||
embedding = self.proms_emb( input if quant_level == 0 else input[:, :quant_level] )
|
|
||||||
else:
|
|
||||||
if quant_level == 0:
|
|
||||||
embedding = self.proms_emb( input if input.dim() == 1 else input[:, :1], offset = 0 )
|
|
||||||
else:
|
|
||||||
embedding = self.proms_emb( input if input.dim() == 1 else input[:, :quant_level], offset = 0 )
|
|
||||||
elif name == "tone" and self.tones_emb is not None:
|
elif name == "tone" and self.tones_emb is not None:
|
||||||
embedding = self.tones_emb( input )
|
embedding = self.tones_emb( input )
|
||||||
elif name == "resp":
|
elif name == "resp":
|
||||||
|
@ -1034,6 +1042,17 @@ class Base(nn.Module):
|
||||||
|
|
||||||
quant_levels: int | list[int] | Tensor | None = None,
|
quant_levels: int | list[int] | Tensor | None = None,
|
||||||
):
|
):
|
||||||
|
# handles tasks where the prompt has task tokens injected in the middle
|
||||||
|
def prompt_input_to_token( input, quant_level ):
|
||||||
|
if isinstance(inputs, str):
|
||||||
|
return get_task_symmap( input )
|
||||||
|
|
||||||
|
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
|
||||||
|
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
|
||||||
|
return torch.full_like(input[..., 0], self.ignore_index)
|
||||||
|
|
||||||
|
return input if input.dim() == 1 else input[:, quant_level]
|
||||||
|
|
||||||
# old, "naive" way, no loss factoring
|
# old, "naive" way, no loss factoring
|
||||||
if not self.config.loss_factors:
|
if not self.config.loss_factors:
|
||||||
target_list = []
|
target_list = []
|
||||||
|
@ -1046,12 +1065,8 @@ class Base(nn.Module):
|
||||||
if name == "task":
|
if name == "task":
|
||||||
task_list.append( input )
|
task_list.append( input )
|
||||||
elif name == "prom":
|
elif name == "prom":
|
||||||
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
|
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||||
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
|
target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] ) )
|
||||||
target.append( torch.full_like(input[..., 0], self.ignore_index) )
|
|
||||||
# we *CAN* directly map to proms
|
|
||||||
else:
|
|
||||||
target.append( input if input.dim() == 1 else input[:, quant_level] )
|
|
||||||
elif name == "resp":
|
elif name == "resp":
|
||||||
target.append( input if input.dim() == 1 else input[:, quant_level] )
|
target.append( input if input.dim() == 1 else input[:, quant_level] )
|
||||||
elif name in ["text", "quant_level", "lang", "tone", "len"]:
|
elif name in ["text", "quant_level", "lang", "tone", "len"]:
|
||||||
|
@ -1119,7 +1134,8 @@ class Base(nn.Module):
|
||||||
input = input if input.dim() == 1 else input[:, quant_level]
|
input = input if input.dim() == 1 else input[:, quant_level]
|
||||||
# select prom level
|
# select prom level
|
||||||
elif name == "prom":
|
elif name == "prom":
|
||||||
input = input[:, quant_level]
|
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||||
|
input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
|
||||||
# meta-input, no corresponding token at the moment
|
# meta-input, no corresponding token at the moment
|
||||||
elif name == "task":
|
elif name == "task":
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -92,6 +92,7 @@ class NAR(Base):
|
||||||
proms_list: list[Tensor],
|
proms_list: list[Tensor],
|
||||||
resps_list: list[Tensor] | None = None,
|
resps_list: list[Tensor] | None = None,
|
||||||
|
|
||||||
|
task_list: list[Tensor] | None = None,
|
||||||
lang_list: list[Tensor] | None = None,
|
lang_list: list[Tensor] | None = None,
|
||||||
tone_list: list[Tensor] | None = None,
|
tone_list: list[Tensor] | None = None,
|
||||||
len_list: list[Tensor] | None = None,
|
len_list: list[Tensor] | None = None,
|
||||||
|
|
|
@ -64,6 +64,8 @@ def train_feeder(engine, batch):
|
||||||
proms_list=batch["proms"],
|
proms_list=batch["proms"],
|
||||||
resps_list=batch["resps"],
|
resps_list=batch["resps"],
|
||||||
lang_list=batch["lang"],
|
lang_list=batch["lang"],
|
||||||
|
tone_list=batch["tone"],
|
||||||
|
task_list=batch["task"],
|
||||||
|
|
||||||
training=True,
|
training=True,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user