added initial support for languages (still testing, marked as model version 3), added experimental 'context extend by limiting the resp context' (untested)
parent 6045cbce94
commit 8740cdefc6
@@ -17,6 +17,7 @@ def main():
 	parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
 	parser.add_argument("--max-nar-levels", type=int, default=7)
+	parser.add_argument("--max-ar-context", type=int, default=-1)
 	parser.add_argument("--ar-temp", type=float, default=1.0)
 	parser.add_argument("--nar-temp", type=float, default=1.0)
@@ -46,6 +47,7 @@ def main():
 		out_path=args.out_path,
 		input_prompt_length=args.input_prompt_length,
 		max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
+		max_ar_context=args.max_ar_context,
 		ar_temp=args.ar_temp, nar_temp=args.nar_temp,
 		min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
 		top_p=args.top_p, top_k=args.top_k,

@@ -120,6 +120,9 @@ class Dataset:
 	temp: list[Path] = field(default_factory=lambda: [])

 	speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
 	speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'"
+
+	speaker_languages: dict = field(default_factory=lambda: {}) # dict where keys are the language codes and values are the speaker groups
+
 	hdf5_name: str = "data.h5"
 	use_hdf5: bool = False
@@ -164,8 +167,8 @@ class Model:
 	size: str | dict = "full" # preset string or explicitly defined dimensionality
 	resp_levels: int = 1 # RVQ-bin levels this model targets for outputs
 	prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
-	tasks: int = 0 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
-	langs: int = 0 # defined languages
+	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
+	langs: int = 1 # defined languages
 	arch_type: str = "retnet" # or "transformer"
 	training: bool = True # unneeded now
 	interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
@@ -518,6 +521,10 @@ class Config(_Config):
 	def get_spkr(self):
 		return eval(self.dataset.speaker_name_getter)

+	@cached_property
+	def get_spkr_group(self):
+		return eval(self.dataset.speaker_group_getter)
+
 	@cached_property
 	def diskcache(self):
 		if self.cfg_path is not None and self.dataset.cache:

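For reference, speaker_name_getter and speaker_group_getter are stored as strings and eval'd into callables, so both derive their result purely from path components. A minimal sketch with a hypothetical dataset path layout of .../<group>/<speaker>/<utterance>:

from pathlib import Path

name_getter = eval("lambda p: f'{p.parts[-3]}_{p.parts[-2]}'")
group_getter = eval("lambda p: f'{p.parts[-3]}'")

p = Path("./data/LibriTTS/100/100_121669_000001.qnt.pt")  # hypothetical path
print(name_getter(p))   # LibriTTS_100  -> speaker name
print(group_getter(p))  # LibriTTS      -> speaker group
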
@@ -33,20 +33,26 @@ def get_phone_symmap():
 	if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
 		return json.loads( cfg.hdf5['symmap'].asstr()[()] )

-	symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185}
+	symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, 'ᵝ': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, 'oˌ': 220, 'eˈ': 221, 'ʍ': 222, 'eˌ': 223, 'uˌ': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228}
 	return symmap

+def get_lang_symmap():
+	symmap = {
+		"en": 0,
+		"ja": 1,
+	}
+	return symmap
+
 def get_task_symmap():
-	start = 1024
 	symmap = {
-		"<tts>": -100,
-		"<ns>": start + 0,
-		"<sr>": start + 1,
-		"<tse>": start + 2,
-		"<soe>": start + 3,
-		"<mask>": start + 4,
-		"<eoe>": start + 5,
-		"<tts-c>": start + 6,
+		"<tts>": 0,
+		"<tts-c>": 1,
+		"<ns>": 2,
+		"<sr>": 3,
+		"<tse>": 4,
+		"<soe>": 5,
+		"<mask>": 6,
+		"<eoe>": 7,
 	}
 	return symmap

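The new get_lang_symmap() mirrors get_task_symmap(): a fixed string-to-index table whose values index an embedding. A minimal usage sketch (the tensor wrapping is an assumption for illustration):

import torch

lang_symmap = {"en": 0, "ja": 1}  # as returned by get_lang_symmap()

# resolve a language code to the index fed into the (version >= 3) langs embedding
lang = torch.tensor([ lang_symmap["ja"] ])  # tensor([1])
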
@@ -105,7 +111,9 @@ def _get_hdf5_path(path):
 	path = str(path)
 	if path[:2] != "./":
 		path = f'./{path}'
-	return path.replace(cfg.cfg_path, "")
+
+	res = path.replace(cfg.cfg_path, "")
+	return res

 def _get_hdf5_paths( data_dir, type="training", validate=False ):
 	data_dir = str(data_dir)
@@ -206,6 +214,7 @@ class Dataset(_Dataset):

 		self.phone_symmap = phone_symmap or self._get_phone_symmap()
 		self.spkr_symmap = self._get_spkr_symmap()
+		self.lang_symmap = self._get_lang_symmap()
 		self.task_symmap = self._get_task_symmap()

 		# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
@@ -227,6 +236,21 @@ class Dataset(_Dataset):
 		res = cfg.get_spkr(path)
 		return res

+	def get_speaker_group(self, path):
+		if isinstance(path, str):
+			path = Path(path)
+		res = cfg.get_spkr_group(path)
+		return res
+
+	def get_language(self, speaker_group):
+		lang = "en"
+		for k, v in cfg.dataset.speaker_languages.items():
+			if speaker_group in v:
+				lang = k
+				break
+
+		return lang
+
 	@cached_property
 	def spkrs(self):
 		return sorted({self.get_speaker(path) for path in self.paths})

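get_language() falls back to "en" unless the speaker group appears under some language code in cfg.dataset.speaker_languages. A self-contained sketch with hypothetical config values:

# hypothetical config: language code -> speaker groups in that language
speaker_languages = {
	"ja": ["jsut", "kokoro"],
}

def get_language(speaker_group):
	lang = "en"  # default when no mapping matches
	for k, v in speaker_languages.items():
		if speaker_group in v:
			lang = k
			break
	return lang

assert get_language("jsut") == "ja"
assert get_language("LibriTTS") == "en"
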
@@ -257,13 +281,18 @@ class Dataset(_Dataset):
 	def _get_spkr_symmap(self):
 		return {s: i for i, s in enumerate(self.spkrs)}

+	def _get_lang_symmap(self):
+		return get_lang_symmap()
+
 	def _get_task_symmap(self):
 		return get_task_symmap()

+	"""
 	def get_task_token( self, token, levels=cfg.models.max_levels ):
 		if not hasattr(self, "task_symmap"):
 			self.task_symmap = self._get_task_symmap()
 		return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
+	"""

 	def sample_noise(self):
 		path = random.choice(self.noise_paths)
@@ -340,6 +369,7 @@ class Dataset(_Dataset):

 		if cfg.dataset.use_hdf5:
 			key = _get_hdf5_path(path)

 			text = cfg.hdf5[key]["text"][:]
 			resps = cfg.hdf5[key]["audio"][:, :]
@@ -351,6 +381,9 @@ class Dataset(_Dataset):
 			text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
 			resps = _load_quants(path)

+		spkr_group = self.get_speaker_group(path)
+		lang = self.lang_symmap[ self.get_language(spkr_group) ]
+
 		# append additional prompts in an attempt to artificially increase lengths / offer new data
 		if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
 			choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
@@ -565,6 +598,7 @@ class Dataset(_Dataset):
 			spkr_name=spkr_name,
 			spkr_id=spkr_id,
 			task=task,
+			lang=lang,
 			text=text,
 			proms=proms,
 			resps=resps,
@@ -799,8 +833,8 @@ def create_dataset_hdf5( skip_existing=True ):
 	# write symmap
 	if "symmap" in hf:
 		del hf['symmap']
-		hf.create_dataset('symmap', data=json.dumps(symmap))

+	hf.create_dataset('symmap', data=json.dumps(symmap))
 	hf.close()

if __name__ == "__main__":

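The reordering above makes the symmap write unconditional: the old version only rewrote it when a stale copy existed to delete. The same pattern as a self-contained h5py sketch (file name hypothetical):

import json
import h5py

symmap = {"<s>": 1, "</s>": 2}  # toy symmap

with h5py.File("data.h5", "a") as hf:
	# remove any stale copy first; create_dataset raises if the key already exists
	if "symmap" in hf:
		del hf["symmap"]
	hf.create_dataset("symmap", data=json.dumps(symmap))

with h5py.File("data.h5", "r") as hf:
	assert json.loads(hf["symmap"].asstr()[()]) == symmap
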
@@ -149,6 +149,7 @@ class TTS():
 		text,
 		references,
 		max_ar_steps=6 * 75,
+		max_ar_context=-1,
 		max_nar_levels=7,
 		input_prompt_length=0.0,
 		ar_temp=0.95,
@@ -176,7 +177,7 @@ class TTS():

 		with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
 			resps_list = self.ar(
-				text_list=[phns], proms_list=[prom], max_steps=max_ar_steps,
+				text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, max_resp_context=max_ar_context,
 				sampling_temperature=ar_temp,
 				sampling_min_temperature=min_ar_temp,
 				sampling_top_p=top_p, sampling_top_k=top_k,

@@ -39,11 +39,11 @@ class AR(Base):

 	@property
 	def n_tasks(self) -> int:
-		return cfg.models.tasks
+		return cfg.models.ar.tasks

 	@property
 	def n_langs(self) -> int:
-		return cfg.models.langs
+		return cfg.models.ar.langs

 	@property
 	def recurrent_chunk_size(self) -> int:
@@ -103,6 +103,7 @@ class AR(Base):
 		proms_list: list[Tensor],
 		resps_list: list[Tensor] | None = None,
 		max_steps: int = 1000,
+		max_resp_context: int = -1,

 		sampling_temperature: float = 1.0,
 		sampling_min_temperature: float = -1.0,
@@ -149,7 +150,11 @@ class AR(Base):
 		# get next in sequence
 		for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
-			resps_list = self._unsqueeze_list(sequence_list)
+			if max_resp_context > 0:
+				resps_list = self._unsqueeze_list([ sequence[-max_resp_context:] for sequence in sequence_list ] )
+			else:
+				resps_list = self._unsqueeze_list(sequence_list)

 			logits = super().forward(
 				text_list=text_list,
 				proms_list=proms_list,

@@ -43,7 +43,11 @@ class AR_NAR(Base):

 	@property
 	def n_tasks(self) -> int:
-		return cfg.models.tasks
+		return cfg.models.ar_nar.tasks
+
+	@property
+	def n_langs(self) -> int:
+		return cfg.models.ar_nar.langs

 	@property
 	def recurrent_chunk_size(self) -> int:
@@ -86,8 +90,13 @@ class AR_NAR(Base):
 		text_list: list[Tensor],
 		proms_list: list[Tensor],
 		resps_list: list[Tensor] | None = None,
+
+		lang_list: list[Tensor] | None = None,
+
 		max_steps: int = 1000,
 		max_levels: int = 7,
+		max_resp_context: int = -1,

 		sampling_temperature: float = 1.0,
 		sampling_min_temperature: float = -1.0,
 		sampling_top_k: int = -100,
@@ -184,7 +193,13 @@ class AR_NAR(Base):
 		# get next in sequence
 		for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
-			resps_list = self._unsqueeze_list(sequence_list)
+			# experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
+			# UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to a segment of audio without something like wav2vec2
+			if max_resp_context > 0:
+				resps_list = self._unsqueeze_list([ sequence[-max_resp_context:] for sequence in sequence_list ] )
+			else:
+				resps_list = self._unsqueeze_list(sequence_list)

 			logits = super().forward(
 				text_list=text_list,
 				proms_list=proms_list,

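The experimental rolling context simply truncates each in-progress response to its last max_resp_context tokens before every AR step; the text and prompt are left untouched. A standalone sketch of the windowing on toy sequences:

import torch

max_resp_context = 4  # keep only the 4 most recent response tokens per sample
sequence_list = [torch.arange(10), torch.arange(6)]  # toy in-progress sequences

if max_resp_context > 0:
	resps_list = [ sequence[-max_resp_context:] for sequence in sequence_list ]
else:
	resps_list = sequence_list

print(resps_list)  # [tensor([6, 7, 8, 9]), tensor([2, 3, 4, 5])]
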
@@ -191,7 +191,7 @@ class Base(nn.Module):
 			cat = torch.cat
 		else:
 			cat = partial(_join, sep=sep)
-		return [*map(cat, zip(*l))]
+		return [*map(cat, zip(*[x for x in l if x is not None]))]

 	def __init__(
 		self,

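_samplewise_merge_tensors zips several per-sample streams (text, optional language, prompt, and response embeddings) and joins each sample with a separator; the change lets an entire stream be None (e.g. no lang_list) and simply skips it. A toy sketch under that reading, with a stand-in _join (the real one lives in the same module):

import torch

def _join(x, sep):
	# concatenate one sample's components with a separator vector between them
	ret = x[0]
	for i in range(1, len(x)):
		ret = torch.cat((ret, sep[None], x[i]), dim=0)
	return ret

def samplewise_merge_tensors(*l, sep):
	return [ _join(x, sep) for x in zip(*[ s for s in l if s is not None ]) ]

text  = [torch.zeros(2, 4), torch.zeros(3, 4)]   # two samples
langs = None                                     # absent stream -> skipped
resps = [torch.ones(5, 4), torch.ones(1, 4)]
sep   = torch.full((4,), 9.0)

merged = samplewise_merge_tensors(text, langs, resps, sep=sep)
print([ m.shape for m in merged ])  # [torch.Size([8, 4]), torch.Size([5, 4])]
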
@@ -229,8 +229,9 @@ class Base(nn.Module):
 		# [1025] + [1024] * 8
 		self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model)

-		# self.langs_emb = Embedding(self.n_langs, d_model)
-		# self.tasks_emb = Embedding(self.n_tasks, d_model)
+		if self.version >= 3:
+			self.langs_emb = Embedding(self.n_langs, d_model)
+			self.tasks_emb = Embedding(self.n_tasks, d_model)

 		self.sep = nn.Parameter(torch.randn(d_model))

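With version >= 3 the language embedding is a plain embedding over n_langs indices, and the forward pass only uses it when a lang_list is supplied. A toy sketch (the repo wraps Embedding to map over lists of tensors; a comprehension stands in for that here):

import torch
from torch import nn

d_model, n_langs = 8, 2          # n_langs = 2 covers "en" and "ja"
langs_emb = nn.Embedding(n_langs, d_model)

lang_list = [ torch.tensor([1]), torch.tensor([0]) ]  # per-sample language index
embedded = [ langs_emb(l) for l in lang_list ] if lang_list is not None else None
print(embedded[0].shape)  # torch.Size([1, 8])
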
@@ -291,25 +292,27 @@ class Base(nn.Module):
 		proms_list: list[Tensor],
 		resps_list: list[Tensor],
 		targ_list: list[Tensor] | None = None,

-		#langs_list: list[Tensor] | None = None,
-		#tasks_list: list[Tensor] | None = None,
+		lang_list: list[Tensor] | None = None,

 		quant_levels: Tensor | None = None,
 		state: dict | None = None,
 	):
+		batch_size = len(text_list)
+
+		if self.langs_emb is None:
+			lang_list = None
+
 		x_list = self._samplewise_merge_tensors(
 			self.text_emb(text_list),
-			#self.langs_emb(langs_list),
+			self.langs_emb(lang_list) if lang_list is not None else None,
 			self.proms_emb(proms_list),
 			#self.tasks_emb(tasks_list),
 			self.resps_emb(resps_list, quant_levels),
 			sep=self.sep,
 		)

 		x, m = list_to_tensor(x_list)

-		batch_size = len(text_list)
 		device = x.device
@@ -349,7 +352,12 @@ class Base(nn.Module):
 			# create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not needed for computing the loss against
 			prom_list = [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ]
 			# remake input sequence
-			text_prom_list = self._samplewise_merge_tensors( text_list, prom_list, sep=ignore_sep )
+			text_prom_list = self._samplewise_merge_tensors(
+				text_list,
+				lang_list,
+				prom_list,
+				sep=ignore_sep
+			)

 			# process each batch
 			for i in range(len(text_prom_list)):

@@ -37,11 +37,11 @@ class NAR(Base):

 	@property
 	def n_tasks(self) -> int:
-		return cfg.models.tasks
+		return cfg.models.nar.tasks

 	@property
 	def n_langs(self) -> int:
-		return cfg.models.langs
+		return cfg.models.nar.langs

 	@property
 	def version(self) -> int: