diff --git a/vall_e/config.py b/vall_e/config.py
index 8e72771..d487d4d 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -248,11 +248,6 @@ class ModelExperimentalSettings:
 	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary
 	unified_position_ids: bool = True # False will generate position IDs partitioned for each section
 	tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
-
-	# performs token dropout to compensate for errors
-	token_dropout_error: float = 0.0 # probability to nudge a token by ±1
-	token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
-	token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
 	causal_size: int = 1 # experimental setting to see if I can just do parallel decoding in chunks instead of one-at-a-time without resorting to exotic solutions
 
 	# VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model
@@ -270,11 +265,19 @@ class ModelExperimentalSettings:
 	classifiers_bias: bool = True # base LLaMAs do not bias the output heads, but my existing weights do
 	max_position_embeddings: int = 70 * 65 * 5 # 5 minutes of audio
 
+	# these technically should be as hyperparameters
+	# performs token dropout to compensate for errors
+	token_dropout_error: float = 0.0 # probability to nudge a token by ±1
+	token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
+	token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
+	# these technically should be as hyperparameters
 	# classifier-free guidance training settings
 	cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
 	cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training
 	cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
+	use_raw_text_p: float = 0.0 # probability to use raw text as the input prompt instead
+
 	# failed experiment
 	layerskip: bool = False # layerskip compatible model (or training for)
 	#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
diff --git a/vall_e/inference.py b/vall_e/inference.py
index ea205ac..a43383d 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -23,7 +23,7 @@ from .config import cfg, Config
 from .models import get_models
 from .models.lora import enable_lora
 from .engines import load_engines, deepspeed_available
-from .data import get_phone_symmap, get_lang_symmap, tokenize, sentence_split
+from .data import get_phone_symmap, get_lang_symmap, tokenize, text_tokenize, sentence_split
 from .models import download_model, DEFAULT_MODEL_PATH
 
 if deepspeed_available:
@@ -412,7 +412,7 @@ class TTS():
 				model = model_ar if model_ar is not None else model_nar
 				if model is not None:
 					text_list = model(
-						text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
+						text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=[task],
 						disable_tqdm=not use_tqdm,
 						use_lora=use_lora,
 						**sampling_kwargs,
@@ -423,6 +423,35 @@ class TTS():
 			text_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
 
 			return text_list[0]
+		elif task in ["phn", "un-phn"]:
+			lang = self.encode_lang( language )
+			lang = to_device(lang, device=self.device, dtype=torch.uint8)
+
+			with torch.autocast(self.device, dtype=dtype, enabled=amp):
+				model = model_ar if model_ar is not None else model_nar
+				if task == "phn":
+					text_list = None
+					raw_text_list = [ torch.tensor( text_tokenize( text ), device=self.device, dtype=torch.int16) ]
+					output_tokenizer = cfg.tokenizer
+				else:
+					text_list = [ torch.tensor( tokenize( text ), device=self.device, dtype=torch.int16) ]
+					raw_text_list = None
+					output_tokenizer = cfg.text_tokenizer
+
+				if model is not None:
+					text_list = model(
+						text_list=text_list, raw_text_list=raw_text_list, lang_list=[lang], task_list=[task],
+						disable_tqdm=not use_tqdm,
+						use_lora=use_lora,
+						**sampling_kwargs,
+					)
+				else:
+					raise Exception("!")
+
+			text_list = [ output_tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
+
+			return text_list[0]
+
 
 		# stuff for rolling context
 		prefix_context = None
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 0c8391d..4c4bbc1 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -29,7 +29,7 @@ from ..utils import get_devices, setup_logging, timer, clamp, convert_kwargs
 from .lora import enable_lora
 from ..samplers import cfg_logits
 
-text_task = [ "stt" ]
+text_task = [ "stt", "phn", "un-phn" ]
 
 class AR_NAR(Base):
 	# yikes
@@ -40,23 +40,28 @@ class AR_NAR(Base):
 	# a lot of this could be delegated back to the dataloader, but it's just easier to keep the task of the dataloader to provide sufficient data, and the model to process the data for training
 	def forward_train(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		resps_list: list[Tensor],
-		task_list: list[Tensor] | None = None,
+
+		text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
+		resps_list: list[Tensor] | None = None,
+
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 		raw_text_list: list[Tensor] | None = None,
 	):
 		# deduce batch_size
-		if text_list is not None:
-			default_task = "tts"
+		if text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		else:
-			default_task = "stt"
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
 			device = resps_list[0].device
 			batch_size = len(resps_list)
@@ -161,10 +166,7 @@ class AR_NAR(Base):
 			# only apply stop token for RVQ level 0
 			if quant_level <= 0 and timesteps[i] is None:
 				# append stop tokens for AR
-				if task in text_task:
-					#text_list[i] = torch.cat([ resps, text_stop_sequence ])
-					...
-				else:
+				if task not in text_task:
 					resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
 
 			if task == "len":
@@ -174,6 +176,7 @@ class AR_NAR(Base):
 			if task not in text_task + ["len"]:
 				drop_text = False
 				drop_audio = False
+				swap_text = False
 
 				if random.random() < cfg_prom_dropout_p:
 					drop_audio = True
@@ -181,6 +184,9 @@ class AR_NAR(Base):
 				if random.random() < cfg_cond_dropout_p:
 					drop_audio = True
 					drop_text = True
+
+				if random.random() < use_raw_text_p and raw_text_list[i] is not None:
+					swap_text = True
 
 				if drop_text:
 					text_list[i] = text_start_stop_sequence
@@ -188,6 +194,9 @@ class AR_NAR(Base):
 				if drop_audio:
 					proms_list[i] = None
 
+				if swap_text and not drop_text:
+					text_list[i] = None
+
 		inputs = self.inputs(
 			text_list=text_list,
 			proms_list=proms_list,
@@ -209,14 +218,16 @@ class AR_NAR(Base):
 
 	def forward_nar_masked(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
+		task_list: list[Tensor] | None = None,
+
+		text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,
-		task_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
+		raw_text_list: list[Tensor] | None = None,
 
 		disable_tqdm=False,
 		use_lora=None,
@@ -420,14 +431,17 @@ class AR_NAR(Base):
 
 	def forward_nar(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
+		task_list: list[Tensor] | None = None,
+
+		text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,
-		task_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
+
+		raw_text_list: list[Tensor] | None = None,
 
 		disable_tqdm=False,
 		use_lora=None,
@@ -447,12 +461,16 @@ class AR_NAR(Base):
 		)
 
 		# deduce batch_size
-		if text_list is not None:
-			default_task = "tts"
+		if text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		else:
-			default_task = "stt"
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
 			device = resps_list[0].device
 			batch_size = len(resps_list)
@@ -534,25 +552,31 @@ class AR_NAR(Base):
 
 	def forward_ar(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
+		task_list: list[Tensor],
+
+		text_list: list[Tensor] | None = None,
+		raw_text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,
-
-		task_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
+
 		disable_tqdm=False,
 		use_lora=None,
 		**sampling_kwargs,
 	):
 		# deduce batch_size
-		if text_list is not None:
-			default_task = "tts"
+		if text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		else:
-			default_task = "stt"
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
 			device = resps_list[0].device
 			batch_size = len(resps_list)
@@ -590,13 +614,17 @@ class AR_NAR(Base):
 			len_list = sequence_list
 
 		inputs = self.inputs(
+			task_list=task_list,
+
 			text_list=text_list,
 			proms_list=proms_list,
 			resps_list=resps_list,
+
 			lang_list=lang_list,
 			tone_list=tone_list,
 			len_list=len_list,
-			task_list=task_list,
+			raw_text_list=raw_text_list,
+
 			quant_levels=quant_levels,
 		)
@@ -627,7 +655,6 @@ class AR_NAR(Base):
 			# convert tokens into int
 			return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ]
 
-		# STT
 		start_slice = [ 0 for _ in range(batch_size) ]
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
@@ -684,19 +711,29 @@ class AR_NAR(Base):
 		# get next in sequence
 		iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
 		for n in iterator:
-			# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
-			text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
-			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+			if batch_size == 1 and task_list[0] in ["phn", "un-phn"]:
+				text_list = [ sequence_list[i] if task in ["phn"] else text_list[i] for i, task in enumerate(task_list) ]
+				raw_text_list = [ sequence_list[i] if task in ["un-phn"] else raw_text_list[i] for i, task in enumerate(task_list) ]
+			else:
+				text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
+				resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+
+			print( text_list, raw_text_list )
+
 			quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
 
 			inputs = self.inputs(
+				task_list=task_list,
+
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
+
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
-				task_list=task_list,
+				raw_text_list=raw_text_list,
+
 				quant_levels=quant_levels,
 			)
@@ -816,11 +853,12 @@ class AR_NAR(Base):
 
 	def forward(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
+		task_list: list[Tensor] | None = None,
+
+		text_list: list[Tensor] | None = None,
+		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,
-		task_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
@@ -833,19 +871,20 @@ class AR_NAR(Base):
 		**sampling_kwargs,
 	):
 		# deduce batch_size
-		if text_list is not None:
-			default_task = "tts"
+		# deduce batch_size
+		if text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		else:
-			default_task = "stt"
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
 			device = resps_list[0].device
 			batch_size = len(resps_list)
 
-		# generate task list if not provided
-		if task_list is None:
-			task_list = [ default_task for _ in range(batch_size) ]
-
 		# implicitly set for training
 		if training is None and text_list is not None and resps_list is not None:
 			n_levels_set = {r.shape[-1] for r in resps_list}
@@ -856,10 +895,12 @@ class AR_NAR(Base):
 		# is training
 		if training:
 			return self.forward_train(
+				task_list=task_list,
+
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
-				task_list=task_list,
+
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
@@ -869,13 +910,17 @@ class AR_NAR(Base):
 		# is NAR
 		if (len_list is not None or resps_list is not None) and text_list is not None:
 			return self.forward_nar(
+				task_list=task_list,
+
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
-				task_list=task_list,
+
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
+				raw_text_list=raw_text_list,
+
 				disable_tqdm=disable_tqdm,
 				use_lora=use_lora,
 				**sampling_kwargs,
@@ -883,13 +928,17 @@ class AR_NAR(Base):
 
 		# is AR
 		return self.forward_ar(
+			task_list=task_list,
+
 			text_list=text_list,
 			proms_list=proms_list,
 			resps_list=resps_list,
-			task_list=task_list,
+
 			lang_list=lang_list,
 			tone_list=tone_list,
 			len_list=len_list,
+			raw_text_list=raw_text_list,
+
 			disable_tqdm=disable_tqdm,
 			use_lora=use_lora,
 			**sampling_kwargs,
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index a77f29c..3f183e9 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -937,21 +937,32 @@ class Base(nn.Module):
 	# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
 	def inputs(
 		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		resps_list: list[Tensor],
+		text_list: list[Tensor] | None = None,
+		raw_text_list: list[Tensor] | None = None,
+
+		proms_list: list[Tensor] | None = None,
+		resps_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 		task_list: list[str] | None = None,
 		time_list: list[Tensor] | None = None,
-		raw_text_list: list[Tensor] | None = None,
 		quant_levels: int | list[int] | Tensor | None = None
 	):
-		device = text_list[0].device
-		batch_size = len(text_list)
+		if text_list:
+			device = text_list[0].device
+			batch_size = len(text_list)
+		elif raw_text_list:
+			device = raw_text_list[0].device
+			batch_size = len(raw_text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
+		elif resps_list:
+			device = resps_list[0].device
+			batch_size = len(resps_list)
 
 		inputs = [ [] for _ in range(batch_size) ]
 		for i in range(batch_size):
@@ -973,6 +984,8 @@ class Base(nn.Module):
 			# insert the text prompt
 			if text_list is not None and text_list[i] is not None:
 				inputs[i].append( ( "text", text_list[i] ) )
+			elif raw_text_list is not None and raw_text_list[i] is not None:
+				inputs[i].append( ( "raw_text", raw_text_list[i] ) )
 			# insert lang token if we're trained for it
 			if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 				inputs[i].append( ( "lang", lang_list[i] ) )
@@ -1022,6 +1035,8 @@ class Base(nn.Module):
 			# insert the text prompt
 			if text_list is not None and text_list[i] is not None:
 				inputs[i].append( ( "text", text_list[i] ) )
+			elif raw_text_list is not None and raw_text_list[i] is not None:
+				inputs[i].append( ( "raw_text", raw_text_list[i] ) )
 			# insert lang token if we're trained for it
 			if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 				inputs[i].append( ( "lang", lang_list[i] ) )
@@ -1070,6 +1085,8 @@ class Base(nn.Module):
 			# insert lang token if we're trained for it
 			if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 				inputs[i].append( ( "lang", lang_list[i] ) )
+			if self.rvq_l_emb is not None and not self.interleave:
+				inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
 			# insert the text prompt
 			if text_list is not None and text_list[i] is not None:
 				inputs[i].append( ( "text", text_list[i] ) )
@@ -1084,6 +1101,8 @@ class Base(nn.Module):
 			# insert lang token if we're trained for it
 			if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) ) + if self.rvq_l_emb is not None and not self.interleave: + inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) # insert the text prompt if raw_text_list is not None and raw_text_list[i] is not None: inputs[i].append( ( "raw_text", raw_text_list[i] ) ) @@ -1197,7 +1216,7 @@ class Base(nn.Module): embedding = self.text_emb( input ) device = embedding.device - elif name == "raw_text": + elif name == "raw_text" and self.raw_text_emb is not None: embedding = self.raw_text_emb( input ) device = embedding.device @@ -1643,6 +1662,10 @@ class Base(nn.Module): if quant_levels is None: quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ] + # inputs don't have quant levels added, pure AR + if len(quant_levels) != len(inputs): + quant_levels = [ 0 for _ in range(len(inputs)) ] + x_list = self.inputs_to_embeddings( inputs, quant_levels ) x, mask = list_to_tensor(x_list) @@ -1652,10 +1675,6 @@ class Base(nn.Module): device = x.device batch_size = len(x_list) - # pure AR - if quant_levels is None: - quant_levels = [ 0 for _ in range(batch_size) ] - # we only need hidden states if we're training with layerskip if self.layerskip and training: output_hidden_states = True