diff --git a/vall_e/config.py b/vall_e/config.py index 180260f..2143897 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -382,6 +382,12 @@ class Model: def text_tokens(self): if isinstance(self.size, dict) and hasattr(self.size, "text_tokens"): return self.size['text_tokens'] + return 8575 + + @property + def phoneme_tokens(self): + if isinstance(self.size, dict) and hasattr(self.size, "phoneme_tokens"): + return self.size['phoneme_tokens'] return 256 @property diff --git a/vall_e/data.py b/vall_e/data.py index 963cb24..9200bc2 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -1611,10 +1611,11 @@ class Dataset(_Dataset): task=task, lang=lang, tone=tone, - text=text, proms=proms, resps=resps, - raw_text=raw_text, + + phns=text, + text=raw_text, metadata=metadata, ) diff --git a/vall_e/inference.py b/vall_e/inference.py index 3e63fde..14b127d 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -306,7 +306,7 @@ class TTS(): seed = set_seed(seed) batch_size = len(texts) input_kwargs = dict( - text_list=texts, + phns_list=texts, proms_list=proms, lang_list=langs, disable_tqdm=not use_tqdm, @@ -421,8 +421,8 @@ class TTS(): with torch.autocast(self.device, dtype=dtype, enabled=amp): model = model_ar if model_ar is not None else model_nar if model is not None: - text_list = model( - text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=[task], + phns_list = model( + phns_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=[task], disable_tqdm=not use_tqdm, use_lora=use_lora, **sampling_kwargs, @@ -430,9 +430,9 @@ class TTS(): else: raise Exception("!") - text_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ] + phns_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in phns_list ] - return text_list[0] + return phns_list[0] elif task in ["phn", "un-phn"]: lang = self.encode_lang( language ) lang = to_device(lang, device=self.device, dtype=torch.uint8) @@ -440,17 +440,17 @@ class TTS(): with torch.autocast(self.device, dtype=dtype, enabled=amp): model = model_ar if model_ar is not None else model_nar if task == "phn": - text_list = None - raw_text_list = [ self.encode_text( text, phonemize=False ).to(device=self.device, dtype=torch.int16) ] + phns_list = None + text_list = [ self.encode_text( text, phonemize=False ).to(device=self.device, dtype=torch.int16) ] output_tokenizer = cfg.tokenizer else: - text_list = [ self.encode_text( text ).to(device=self.device, dtype=torch.int16) ] - raw_text_list = None + phns_list = [ self.encode_text( text ).to(device=self.device, dtype=torch.int16) ] + text_list = None output_tokenizer = cfg.text_tokenizer if model is not None: - text_list = model( - text_list=text_list, raw_text_list=raw_text_list, lang_list=[lang], task_list=[task], + phns_list = model( + phns_list=phns_list, text_list=text_list, lang_list=[lang], task_list=[task], disable_tqdm=not use_tqdm, use_lora=use_lora, **sampling_kwargs, @@ -458,9 +458,9 @@ class TTS(): else: raise Exception("!") - text_list = [ output_tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ] + phns_list = [ output_tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in phns_list ] - return text_list[0] + return phns_list[0] # stuff for rolling context @@ -504,8 +504,8 @@ class TTS(): with torch.autocast(self.device, dtype=dtype, enabled=amp): input_kwargs = 
dict( - text_list=[phns] if phonemize else None, - raw_text_list=[phns] if not phonemize else None, + phns_list=[phns] if phonemize else None, + text_list=[phns] if not phonemize else None, proms_list=[prom], lang_list=[lang], disable_tqdm=not use_tqdm, diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py index f635889..7b33b10 100755 --- a/vall_e/models/__init__.py +++ b/vall_e/models/__init__.py @@ -59,11 +59,18 @@ def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ): def get_model(config, training=True, **model_kwargs): - from .ar_nar import AR_NAR # import here because reasons - name = config.name - model = AR_NAR( - n_text_tokens=config.text_tokens, + # crunge + if config.version < 7: + from .ar_nar import AR_NAR + ModelClass = AR_NAR + else: + from .ar_nar_v2 import AR_NAR_V2 + ModelClass = AR_NAR_V2 + + cfg_kwargs = dict( + n_phn_tokens=config.phoneme_tokens, n_audio_tokens=config.audio_tokens, + n_text_tokens=config.text_tokens, d_model=config.dim, n_heads=config.heads, n_layers=config.layers, @@ -75,9 +82,11 @@ def get_model(config, training=True, **model_kwargs): training = training, config = config, - **model_kwargs ) + name = config.name + model = ModelClass(**(cfg_kwargs | model_kwargs)) + _logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters") return model diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index f6bc4d6..afd80be 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -42,22 +42,22 @@ class AR_NAR(Base): self, task_list: list[Tensor] | None = None, - text_list: list[Tensor] | None = None, + phns_list: list[Tensor] | None = None, proms_list: list[Tensor] | None = None, resps_list: list[Tensor] | None = None, lang_list: list[Tensor] | None = None, tone_list: list[Tensor] | None = None, len_list: list[Tensor] | None = None, - raw_text_list: list[Tensor] | None = None, + text_list: list[Tensor] | None = None, ): # deduce batch_size - if text_list: + if phns_list: + device = phns_list[0].device + batch_size = len(phns_list) + elif text_list: device = text_list[0].device batch_size = len(text_list) - elif raw_text_list: - device = raw_text_list[0].device - batch_size = len(raw_text_list) elif proms_list: device = proms_list[0].device batch_size = len(proms_list) @@ -75,9 +75,6 @@ class AR_NAR(Base): token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels # RVQ levels to apply masking training on masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels - - if self.version >= 7: - masking_train_rvq_levels = [0,self.n_resp_levels] if cfg.audio_backend == "nemo": rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ] @@ -134,13 +131,12 @@ class AR_NAR(Base): timesteps[i] = (timesteps[i] * 0.6) + 0.2 # trim resps to only contain all levels below the target level - if self.version < 7: - resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)] + resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)] # tensor to cat for RVQ level 0 text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16) text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16) - audio_stop_sequence = torch.tensor([[self.stop_token] * (1 if self.version < 7 else self.n_resp_levels)], device=device, dtype=torch.int16) + audio_stop_sequence = 
torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16) # final validations and stuff for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list): @@ -175,10 +171,10 @@ class AR_NAR(Base): """ # only apply stop token for RVQ level 0 - if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None) or (self.predict_causally): + if (quant_level <= 0 and timesteps[i] is None) or (self.predict_causally): # append stop tokens for AR if task not in text_task: - resps_list[i] = torch.cat([ resps, audio_stop_sequence ]) + resps_list[i] = torch.cat([ resps, audio_stop_sequence.repeat((1, resps.shape[-1])) ]) if task == "len": quant_levels[i] = 0 @@ -196,26 +192,26 @@ class AR_NAR(Base): drop_audio = True drop_text = True - if random.random() < use_raw_text_p and raw_text_list[i] is not None: + if random.random() < use_raw_text_p and text_list[i] is not None: swap_text = True if drop_text: - text_list[i] = text_start_stop_sequence + phns_list[i] = text_start_stop_sequence if drop_audio: proms_list[i] = None if swap_text and not drop_text: - text_list[i] = None + phns_list[i] = None inputs = self.inputs( - text_list=text_list, + phns_list=phns_list, proms_list=proms_list, resps_list=resps_list, lang_list=lang_list, tone_list=tone_list, task_list=task_list, - raw_text_list=raw_text_list, + text_list=text_list, time_list=timesteps, quant_levels=quant_levels, @@ -231,22 +227,22 @@ class AR_NAR(Base): task_list: list[Tensor] | None = None, - text_list: list[Tensor] | None = None, + phns_list: list[Tensor] | None = None, proms_list: list[Tensor] | None = None, resps_list: list[Tensor] | None = None, lang_list: list[Tensor] | None = None, tone_list: list[Tensor] | None = None, len_list: list[Tensor] | None = None, - raw_text_list: list[Tensor] | None = None, + text_list: list[Tensor] | None = None, quant_levels: list[int] | None = None, disable_tqdm=False, use_lora=None, **sampling_kwargs, ): - device = text_list[0].device - batch_size = len(text_list) + device = phns_list[0].device + batch_size = len(phns_list) if quant_levels is None: level = 0 @@ -304,7 +300,7 @@ class AR_NAR(Base): prefix_context = sampling_kwargs.get("prefix_context", None) # we can get away with just providing a list of resps to prefix later, and it will magically get removed anyways when masking and scoring if prefix_context is not None: - text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_context[0], text_list ) ] + phns_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_context[0], phns_list ) ] prefix_resps_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_context[1] ] # if we're denoising from an existing sequence @@ -379,7 +375,7 @@ class AR_NAR(Base): # setup inputs inputs = super().inputs( - text_list=text_list, + phns_list=phns_list, proms_list=proms_list, resps_list=input_resps_list, lang_list=lang_list, @@ -396,7 +392,7 @@ class AR_NAR(Base): if cfg_strength > 0: null_inputs = super().inputs( - text_list=null_text, + phns_list=null_text, proms_list=null_prom, resps_list=input_resps_list, lang_list=lang_list, @@ -446,188 +442,11 @@ class AR_NAR(Base): return resps_list - # handles doing demasking inferencing in parallel to inference all tokens - # it works if the underlying model is trained properly (which is a pain) - def forward_nar_masked_parallel( - self, - - task_list: list[Tensor] | None = None, - - text_list: 
list[Tensor] | None = None, - proms_list: list[Tensor] | None = None, - resps_list: list[Tensor] | None = None, - - lang_list: list[Tensor] | None = None, - tone_list: list[Tensor] | None = None, - len_list: list[Tensor] | None = None, - raw_text_list: list[Tensor] | None = None, - - disable_tqdm=False, - use_lora=None, - **sampling_kwargs, - ): - device = text_list[0].device - batch_size = len(text_list) - - level = 0 - if cfg.lora is not None: - enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora ) - - # convert (N)AR specific args - sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" ) - - min_length = sampling_kwargs.pop("min_duration", 1) - max_length = sampling_kwargs.pop("max_duration", 500) - max_steps = sampling_kwargs.get("max_steps", 25) - refine_on_stop = sampling_kwargs.get("refine_on_stop", False) - entropix_sampling = sampling_kwargs.get("entropix_sampling", False) - annealed_sampling = sampling_kwargs.get("annealed_sampling", True) - - # greedy sampling is very, very much preferred, but using greedy logit scores later helps enough - temperature = sampling_kwargs.pop("temperature", 0.0) - minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5) - # this really helps keep audio coherent so far - cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength) - cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75) - start_noise = sampling_kwargs.get("denoise_start", 0.0) - end_noise = sampling_kwargs.get("denoise_end", 1.0) - remasking = sampling_kwargs.get("remasking", True) - max_steps = math.floor(max_steps * (end_noise - start_noise)) - - # to specify the initial mask used - vc_list = sampling_kwargs.pop("vc_list", None) - vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25) - vc_mask_p = sampling_kwargs.pop("vc_mask_p", 0.25) - - len_list = [ clamp(l, min_length, max_length) for l in len_list ] - - # force set CFG because too low / no CFG causes issues - original_cfg_strength = cfg_strength - cfg_strength = max( cfg_strength, minimum_cfg_strength ) - - prefix_context = sampling_kwargs.get("prefix_context", None) - # fill with masked tokens (even though they get masked anyways) - resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ] - # fill scores - scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ] - - quant_levels = [ level for _ in range(batch_size) ] - null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ] - null_prom = [ None for _ in range(batch_size) ] - - iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm) - for timestep in iterator: - # update previous list of tokens - prev_list = resps_list - # ramp down over time - annealing = 1.0 - timestep - # get noise level, per cosine scheduling - noise_p = math.cos( timestep * math.pi * 0.5 ) - # proportion of tokens to remask - remask_p = 1.0 / (max_steps * 2) if remasking else 0 - # pick the worst scoring tokens to mask off - masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ] - # normal masking - # mask off inputs - resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ] - # boolean mask - is_masked = [ 
resps == self.mask_token for resps in resps_list ] - # timestep inputs - time_list = [ timestep for _ in range(batch_size) ] - - sampling_temperature = temperature * annealing if annealed_sampling else temperature - sampling_cfg = cfg_strength * timestep if annealed_sampling else cfg_strength - - input_resps_list = resps_list - - # setup inputs - inputs = super().inputs( - text_list=text_list, - proms_list=proms_list, - resps_list=input_resps_list, - lang_list=lang_list, - tone_list=tone_list, - time_list=time_list, - quant_levels=quant_levels, - ) - output = super().forward( - inputs=inputs, - quant_levels=quant_levels, - ) - - logits = output.logits - if cfg_strength > 0: - null_inputs = super().inputs( - text_list=null_text, - proms_list=null_prom, - resps_list=input_resps_list, - lang_list=lang_list, - tone_list=tone_list, - time_list=time_list, - quant_levels=quant_levels, - ) - null_output = super().forward( - inputs=null_inputs, - quant_levels=quant_levels, - ) - - logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] ) - - l_scores = [] - l_resps_list = [] - # cringe hack because we're able to sample multiple levels at once - for l in range(self.n_resp_levels): - # sample with sampler settings - filtered_sampled = super().sample( - logits=[ logit[l] for logit in logits ], - prev_list=[ resp[..., l] for resp in prev_list ], - quant_levels=quant_levels, - - temperature=sampling_temperature, - **sampling_kwargs, - ) - - # retrieves unfiltered logits - unfiltered_sampled = super().sample( - logits=[ logit[l] for logit in logits ], - prev_list=[ resp[..., l] for resp in prev_list ], - quant_levels=quant_levels, - - temperature=0.0, - **sampling_kwargs, - ) - - # get sampled tokens - sampled_ids = filtered_sampled.ids - # keep unmasked tokens - l_resps_list.append([ torch.where( masked[..., l], input_ids, resps[..., l] ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]) - # get probability scores - l_scores.append([ - # conjugate to have worse scoring tokens picked for topk - 1.0 - - # only keep scores of tokens we are predicting (and ignore the tokens previously finalized) - torch.where( masked[..., l], torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked[..., l].shape, device=device) ) - # use unmodified logit scores for this, as it offers better stability - for scores, masked in zip( unfiltered_sampled.scores, is_masked ) - ]) - - resps_list = [] - scores = [] - - for batch_index in range(batch_size): - score = sum([ l_scores[level][batch_index] for level in range(self.n_resp_levels) ]) / self.n_resp_levels - resp = torch.stack([ l_resps_list[level][batch_index] for level in range(self.n_resp_levels) ], dim=-1) - - scores.append( score ) - resps_list.append( resp ) - - return resps_list - def forward_nar( self, task_list: list[Tensor] | None = None, - text_list: list[Tensor] | None = None, + phns_list: list[Tensor] | None = None, proms_list: list[Tensor] | None = None, resps_list: list[Tensor] | None = None, @@ -635,7 +454,7 @@ class AR_NAR(Base): tone_list: list[Tensor] | None = None, len_list: list[Tensor] | None = None, - raw_text_list: list[Tensor] | None = None, + text_list: list[Tensor] | None = None, disable_tqdm=False, use_lora=None, @@ -644,7 +463,7 @@ class AR_NAR(Base): # inference NAR level 0 if len_list is not None: resps_list = self.forward_nar_masked( - text_list=text_list, + phns_list=phns_list, 
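# the masked/demasking pass produces the RVQ level-0 sequence for the given len_list;
# the per-level loop further down then infers levels 1+ conditioned on that result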
proms_list=proms_list, resps_list=resps_list, task_list=task_list, @@ -655,12 +474,12 @@ class AR_NAR(Base): ) # deduce batch_size - if text_list: + if phns_list: + device = phns_list[0].device + batch_size = len(phns_list) + elif text_list: device = text_list[0].device batch_size = len(text_list) - elif raw_text_list: - device = raw_text_list[0].device - batch_size = len(raw_text_list) elif proms_list: device = proms_list[0].device batch_size = len(proms_list) @@ -701,7 +520,7 @@ class AR_NAR(Base): quant_levels = [ level for _ in range(batch_size) ] inputs = self.inputs( - text_list=text_list, + phns_list=phns_list, proms_list=proms_list, resps_list=prev_list, lang_list=lang_list, @@ -717,7 +536,7 @@ class AR_NAR(Base): if cfg_strength > 0: null_inputs = super().inputs( - text_list=null_text, + phns_list=null_text, proms_list=null_prom, resps_list=prev_list, lang_list=lang_list, @@ -748,8 +567,8 @@ class AR_NAR(Base): task_list: list[Tensor], + phns_list: list[Tensor] | None = None, text_list: list[Tensor] | None = None, - raw_text_list: list[Tensor] | None = None, proms_list: list[Tensor] | None = None, resps_list: list[Tensor] | None = None, lang_list: list[Tensor] | None = None, @@ -761,12 +580,12 @@ class AR_NAR(Base): **sampling_kwargs, ): # deduce batch_size - if text_list: + if phns_list: + device = phns_list[0].device + batch_size = len(phns_list) + elif text_list: device = text_list[0].device batch_size = len(text_list) - elif raw_text_list: - device = raw_text_list[0].device - batch_size = len(raw_text_list) elif proms_list: device = proms_list[0].device batch_size = len(proms_list) @@ -810,14 +629,14 @@ class AR_NAR(Base): inputs = self.inputs( task_list=task_list, - text_list=text_list, + phns_list=phns_list, proms_list=proms_list, resps_list=resps_list, lang_list=lang_list, tone_list=tone_list, len_list=len_list, - raw_text_list=raw_text_list, + text_list=text_list, quant_levels=quant_levels, ) @@ -895,7 +714,7 @@ class AR_NAR(Base): if prefix_context is not None: prefix_text, prefix_resps, _ = prefix_context # to-do: check if we actually need to drop the middle "" - text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, text_list ) ] + phns_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, phns_list ) ] # feeding this into the NAR-len should automatically handle things sequence_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_resps ] @@ -906,13 +725,13 @@ class AR_NAR(Base): iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm) for n in iterator: if batch_size == 1 and task_list[0] in ["phn", "un-phn"]: - text_list = [ sequence_list[i] if task in ["phn"] else text_list[i] for i, task in enumerate(task_list) ] - raw_text_list = [ sequence_list[i] if task in ["un-phn"] else raw_text_list[i] for i, task in enumerate(task_list) ] + phns_list = [ sequence_list[i] if task in ["phn"] else phns_list[i] for i, task in enumerate(task_list) ] + text_list = [ sequence_list[i] if task in ["un-phn"] else text_list[i] for i, task in enumerate(task_list) ] else: - if raw_text_list is not None: - raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ] - else: + if text_list is not None: text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ] + else: + phns_list = [ sequence_list[i] if task in text_task else phns_list[i] for i, task in 
enumerate(task_list) ] resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ] quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ] @@ -920,14 +739,14 @@ class AR_NAR(Base): inputs = self.inputs( task_list=task_list, - text_list=text_list, + phns_list=phns_list, proms_list=proms_list, resps_list=resps_list, lang_list=lang_list, tone_list=tone_list, len_list=len_list, - raw_text_list=raw_text_list, + text_list=text_list, quant_levels=quant_levels, ) @@ -942,7 +761,7 @@ class AR_NAR(Base): if cfg_strength > 0: null_inputs = super().inputs( - text_list=null_text, + phns_list=null_text, proms_list=null_prom, resps_list=resps_list, lang_list=lang_list, @@ -960,7 +779,7 @@ class AR_NAR(Base): sampled = super().sample( logits=logits, - prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ], + prev_list=[ resps_list[i] if task not in text_task else phns_list[i] for i, task in enumerate( task_list ) ], **(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}), ) @@ -981,7 +800,7 @@ class AR_NAR(Base): # first step, expand batch if batch_size == 1: batch_size = beam_width - text_list = text_list * beam_width + phns_list = phns_list * beam_width proms_list = proms_list * beam_width sequence_list = sequence_list * beam_width task_list = task_list * beam_width @@ -1046,175 +865,18 @@ class AR_NAR(Base): return sequence_list - def forward_ar_parallel( - self, - - task_list: list[Tensor], - - text_list: list[Tensor] | None = None, - raw_text_list: list[Tensor] | None = None, - proms_list: list[Tensor] | None = None, - resps_list: list[Tensor] | None = None, - lang_list: list[Tensor] | None = None, - tone_list: list[Tensor] | None = None, - len_list: list[Tensor] | None = None, - - disable_tqdm=False, - use_lora=None, - **sampling_kwargs, - ): - # deduce batch_size - if text_list: - device = text_list[0].device - batch_size = len(text_list) - elif raw_text_list: - device = raw_text_list[0].device - batch_size = len(raw_text_list) - elif proms_list: - device = proms_list[0].device - batch_size = len(proms_list) - elif resps_list: - device = resps_list[0].device - batch_size = len(resps_list) - - if cfg.lora is not None: - enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora ) - - # convert AR specific args - sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" ) - - temperature = sampling_kwargs.get("temperature", 1.0) - cfg_strength = sampling_kwargs.get("cfg_strength", 0.0) - cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7) - min_temperature = sampling_kwargs.get("min_temperature", -1.0) - max_duration = sampling_kwargs.get("max_duration", 500) - beam_width = sampling_kwargs.get("beam_width", 0) - entropix_sampling = sampling_kwargs.get("entropix_sampling", False) - refine_on_stop = sampling_kwargs.get("refine_on_stop", False) - input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False) - layer_skip = sampling_kwargs.get("layer_skip", False) - prefix_silence = sampling_kwargs.get("prefix_silence", 0.0) - mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0) - mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0) - - start_slice = [ 0 for _ in range(batch_size) ] - sequence_list = [ torch.zeros((0, 8), device=device).to(torch.int16) for _ in range(batch_size) ] - stopped = torch.zeros(batch_size, device=device).bool() - - audio_stop_token = self.stop_token - text_stop_token = 2 - - state = None - 
mirostat = [ - {"n": 1024, "tau": mirostat_tau, "eta": mirostat_eta, "max_surprise": mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0} - ] * batch_size if mirostat_tau > 0.0 else None - - scores = [ 1.0 ] * beam_width - metrics = [] - - null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ] - null_prom = [ None for _ in range(batch_size) ] - - # get next in sequence - iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm) - for n in iterator: - if raw_text_list is not None: - raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ] - else: - text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ] - resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ] - - quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ] - - inputs = self.inputs( - task_list=task_list, - - text_list=text_list, - proms_list=proms_list, - resps_list=resps_list, - - lang_list=lang_list, - tone_list=tone_list, - len_list=len_list, - raw_text_list=raw_text_list, - - quant_levels=quant_levels, - ) - - # to-do: find an elegant way to write this - output = super().forward( - inputs=inputs, - state=state, - #layer_skip_variables=sampling_layer_skip_variables, - output_attentions=entropix_sampling, - ) - - if cfg_strength > 0: - null_inputs = super().inputs( - text_list=null_text, - proms_list=null_prom, - resps_list=resps_list, - lang_list=lang_list, - tone_list=tone_list, - quant_levels=quant_levels, - ) - null_output = super().forward( - inputs=null_inputs, - quant_levels=quant_levels, - #layer_skip_variables=sampling_layer_skip_variables, - ) - logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] + 1 for resp in resps_list ] ) - - logits, state = output.logits, output.state - - l_resps_list = [ [] for _ in range(batch_size) ] - for l in range(self.n_resp_levels): - sampled = super().sample( - logits=[ logit[l] for logit in logits ], - prev_list=[ resp[..., l] for resp in resps_list ], - **(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}), - ) - - ids = sampled.ids - - # append tokens - for i, token in enumerate(ids): - if audio_stop_token in token: - stopped[i] = True - l_resps_list[i].append(token.to(device)) - - for i, l in enumerate(l_resps_list): - sequence_list[i] = torch.cat([sequence_list[i], torch.stack(l, dim=-1)]) - - # stop token found - # stopped |= r == stop_token - if stopped.all().item(): - iterator.close() - break - - for i, l in enumerate( sequence_list ): - index = (l == audio_stop_token).nonzero() - # kludge for when it doesnt actually hit a stop token but i cant be bothered to properly address it right now since it only came up in test training at the moment - try: - index = index[:, 0].min() - sequence_list[i] = sequence_list[i][:index] - except Exception as e: - pass - - return sequence_list - def forward( self, task_list: list[Tensor] | None = None, - text_list: list[Tensor] | None = None, + phns_list: list[Tensor] | None = None, proms_list: list[Tensor] | None = None, resps_list: list[Tensor] | None = None, lang_list: list[Tensor] | None = None, tone_list: list[Tensor] | None = None, len_list: list[Tensor] | None = None, - raw_text_list: list[Tensor] | None = None, + text_list: list[Tensor] | None = None, 
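# naming convention: phns_list carries phonemized token ids, while text_list carries raw text token ids (previously raw_text_list)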
training: bool | None = None, @@ -1224,12 +886,12 @@ class AR_NAR(Base): ): # deduce batch_size # deduce batch_size - if text_list: + if phns_list: + device = phns_list[0].device + batch_size = len(phns_list) + elif text_list: device = text_list[0].device batch_size = len(text_list) - elif raw_text_list: - device = raw_text_list[0].device - batch_size = len(raw_text_list) elif proms_list: device = proms_list[0].device batch_size = len(proms_list) @@ -1238,7 +900,7 @@ class AR_NAR(Base): batch_size = len(resps_list) # implicitly set for training - if training is None and text_list is not None and resps_list is not None: + if training is None and phns_list is not None and resps_list is not None: n_levels_set = {r.shape[-1] for r in resps_list} n_levels = next(iter(n_levels_set)) @@ -1249,118 +911,47 @@ class AR_NAR(Base): return self.forward_train( task_list=task_list, - text_list=text_list, + phns_list=phns_list, proms_list=proms_list, resps_list=resps_list, lang_list=lang_list, tone_list=tone_list, len_list=len_list, - raw_text_list=raw_text_list, + text_list=text_list, ) # is NAR - if (len_list is not None or resps_list is not None) and text_list is not None: - if self.version >= 7: - return self.forward_nar_masked_parallel( - task_list=task_list, - - text_list=text_list, - proms_list=proms_list, - resps_list=resps_list, - - lang_list=lang_list, - tone_list=tone_list, - len_list=len_list, - raw_text_list=raw_text_list, - - disable_tqdm=disable_tqdm, - use_lora=use_lora, - **sampling_kwargs, - ) - - # NAR demasking for all levels - """ - resps_lists = [ None for _ in range(batch_size) ] - for level in range(self.n_resp_levels): - resp_list = self.forward_nar_masked( - task_list=task_list, - - text_list=text_list, - proms_list=proms_list, - resps_list=resps_list, - - lang_list=lang_list, - tone_list=tone_list, - len_list=len_list, - raw_text_list=raw_text_list, - - disable_tqdm=disable_tqdm, - use_lora=use_lora, - quant_levels=[ level for _ in range(batch_size) ], - **sampling_kwargs, - ) - - for batch_index, resp in enumerate(resp_list): - if resps_lists[batch_index] is None: - resps_lists[batch_index] = [] - - resps_lists[batch_index].append( resp ) - - for batch_index, resps in enumerate(resps_lists): - resps_lists[batch_index] = torch.stack( resps, dim=-1 ) - - return resps_lists - """ - + if (len_list is not None or resps_list is not None) and phns_list is not None: return self.forward_nar( task_list=task_list, - text_list=text_list, + phns_list=phns_list, proms_list=proms_list, resps_list=resps_list, lang_list=lang_list, tone_list=tone_list, len_list=len_list, - raw_text_list=raw_text_list, + text_list=text_list, disable_tqdm=disable_tqdm, use_lora=use_lora, **sampling_kwargs, ) - if self.version >= 7: - if task_list is None or task_list[0] != "len": - return self.forward_ar_parallel( - task_list=task_list, - - text_list=text_list, - proms_list=proms_list, - resps_list=resps_list, - - lang_list=lang_list, - tone_list=tone_list, - len_list=len_list, - raw_text_list=raw_text_list, - - disable_tqdm=disable_tqdm, - use_lora=use_lora, - **sampling_kwargs, - ) - # is AR return self.forward_ar( task_list=task_list, - text_list=text_list, + phns_list=phns_list, proms_list=proms_list, resps_list=resps_list, lang_list=lang_list, tone_list=tone_list, len_list=len_list, - raw_text_list=raw_text_list, + text_list=text_list, disable_tqdm=disable_tqdm, use_lora=use_lora, @@ -1402,12 +993,11 @@ def example_usage(): text, audio = load_artifact(f"./data/qnt.{cfg.audio_backend_extension}") batch_size = 
cfg.hyperparameters.batch_size - text_list = [ text ] * batch_size + phns_list = [ text ] * batch_size proms_list = [ audio[:int(cfg.dataset.frames_per_second), :] ] * batch_size resps_list = [ audio[:int(cfg.dataset.frames_per_second * 4), :] ] * batch_size kwargs = { - 'n_text_tokens': cfg.model.text_tokens, 'n_audio_tokens': cfg.model.audio_tokens, 'd_model': 1024, # 256, # 1024, # 1536 @@ -1545,7 +1135,7 @@ def example_usage(): def sample_data(t=None): if isinstance(t, list): tasks = t - texts = [ text_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate( tasks ) ] + texts = [ phns_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate( tasks ) ] proms = [ proms_list[0].to(cfg.device) if task not in text_task else [ "stt" ] for i, task in enumerate( tasks ) ] resps = [ None if task not in text_task else resps_list[0].to(cfg.device) for i, task in enumerate( tasks ) ] @@ -1559,7 +1149,7 @@ def example_usage(): for i in range(batch_size): task = random.choice(available_tasks) if t is None else t - text = text_list[i].to(cfg.device) + text = phns_list[i].to(cfg.device) prom = proms_list[i].to(cfg.device) resp = resps_list[i].to(cfg.device) @@ -1580,16 +1170,16 @@ def example_usage(): def sample( name, steps=500, task=None ): engine.eval() - text_list, proms_list, resp_list, task_list = sample_data( task ) + phns_list, proms_list, resp_list, task_list = sample_data( task ) if task == "tts-nar": - len_list = engine( text_list=text_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 ) + len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 ) len_list = [ r.shape[0] for r in resp_list ] - resps_list = engine( text_list=text_list, proms_list=proms_list, len_list=len_list ) + resps_list = engine( phns_list=phns_list, proms_list=proms_list, len_list=len_list ) else: - resps_list = engine( text_list=text_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 ) + resps_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 ) if resps_list[0].dim() == 1 or resps_list[0].shape[-1] == 1: - resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 ) + resps_list = engine( phns_list=phns_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 ) for i, o in enumerate(resps_list): print( o.shape, o ) @@ -1604,7 +1194,7 @@ def example_usage(): texts, proms, resps, tasks = sample_data() stats = {"step": i} - stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True) + stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True) stats |= {"grad_norm": engine.get_global_grad_norm()} tqdm.write(f"{stats}") diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py new file mode 100644 index 0000000..fa049fa --- /dev/null +++ b/vall_e/models/ar_nar_v2.py @@ -0,0 +1,1069 @@ +""" +# an AR + NAR model that handles: +* inferencing the primary RVQ level in an autoregressive manner (AR) +* inferencing the remaining RVQ levels in parallel (NAR) + +This model can fully handle being trained as a unified model (AR + NAR) or separate models (AR | NAR). +It's recommended to train as a unified model, then "distill" knowledge of each tasks separately, just in case. 
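A rough sketch of the intended inference flow (mirroring example_usage() at the bottom of this file; illustrative rather than a stable API):

    # predict the output duration as digit tokens (task "len"), then demask all RVQ levels in parallel
    len_list = model( phns_list=phns_list, proms_list=proms_list, task_list=["len"] )
    resps_list = model( phns_list=phns_list, proms_list=proms_list, len_list=len_list )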
+""" +from .base_v2 import Base_V2, list_to_tensor, Categorical +from ..config import cfg + +import torch +from torch.nn.utils.rnn import pad_sequence + +import random +import math +import time +from einops import rearrange +from torch import Tensor +from tqdm import trange, tqdm + +import logging + +_logger = logging.getLogger(__name__) + +from ..emb.qnt import trim, get_silence +from ..utils import get_devices, setup_logging, timer, clamp, convert_kwargs + +from .lora import enable_lora +from ..samplers import cfg_logits + +text_task = [ "stt", "phn", "un-phn" ] + +class AR_NAR_V2(Base_V2): + # yikes + def forward_super(self, *args, **kwargs): + return super().forward(*args, **kwargs) + + # parse inputs for training + # a lot of this could be delegated back to the dataloader, but it's just easier to keep the task of the dataloader to provide sufficient data, and the model to process the data for training + def forward_train( + self, + task_list: list[Tensor] | None = None, + + phns_list: list[Tensor] | None = None, + proms_list: list[Tensor] | None = None, + resps_list: list[Tensor] | None = None, + + lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, + len_list: list[Tensor] | None = None, + text_list: list[Tensor] | None = None, + ): + # deduce batch_size + if phns_list: + device = phns_list[0].device + batch_size = len(phns_list) + elif text_list: + device = text_list[0].device + batch_size = len(text_list) + elif proms_list: + device = proms_list[0].device + batch_size = len(proms_list) + elif resps_list: + device = resps_list[0].device + batch_size = len(resps_list) + + # specifies how to sample probabilities of which RVQ levels to train against + rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal" + # determines which RVQ level to target per batch + quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ] + # rate to perform token dropout errors + token_dropout_error = self.config.experimental.token_dropout_error + # RVQ levels to apply token dropout on + token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels + # RVQ levels to apply masking training on + masking_train_rvq_levels = [0,self.n_resp_levels] # self.config.experimental.masking_train_rvq_levels + + if cfg.audio_backend == "nemo": + rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ] + + # CFG + cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0 + cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0 + cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0 + use_raw_text_p = self.config.experimental.use_raw_text_p if self.config is not None else 0.0 + # rate to train RVQ level AR-ly or NAR-ly + masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5 + masking_ratio = self.config.experimental.masking_ratio if self.config is not None else "random" + # force set mask training + if "len" not in self.capabilities: + masking_train_p = 0.0 + elif "ar" not in self.capabilities: + masking_train_p = 1.0 + # implicitly set it to all levels + if not token_dropout_rvq_levels: + token_dropout_rvq_levels = [0, self.resp_levels - 1] + if not token_dropout_rvq_levels: + token_dropout_rvq_levels = [0, 0] + + 
# allow passing a specific distribution of RVQ levels + rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else [] + if not rvq_levels_p: + lo, hi = quant_level_range[0], quant_level_range[1] + 1 + # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) + if rvq_levels_p == "equal": + rvq_levels_p = [ i for i in range( lo, hi ) ] + else: + # yuck + rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) + + # input RVQ levels + quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ] + # timestep levels (for TTS NAR) + timesteps = [ None for _ in range(batch_size) ] + + for i, task in enumerate( task_list ): + lo, hi = masking_train_rvq_levels[0], masking_train_rvq_levels[1] + if task in text_task: + quant_levels[i] = 0 # self.n_resp_levels - 1 + elif lo <= quant_levels[i] and quant_levels[i] <= hi and random.random() < masking_train_p: + # to-do: prioritize lower timesteps over later timesteps + # ...except that the masking rate is still tied to the cosine scheduling, which does this already + #r = random.random() + #p = math.acos(r) / (math.pi * 0.5) + #timesteps[i] = 1.0 - clamp(p, 0.0, 1.0) + timesteps[i] = random.random() + + # instead make it between [0.2, 0.8] + if masking_ratio == "rand": + timesteps[i] = (timesteps[i] * 0.6) + 0.2 + + # tensor to cat for RVQ level 0 + text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16) + text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16) + audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16) + + # final validations and stuff + for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list): + # cap quant_level if it exceeds its corresponding resp/prom + # this was needed for when my DAC-encoded audio was erroneously trimmed to 8 RVQ levels instead of 9 + if quant_level >= resps.shape[-1]: + quant_levels[i] = resps.shape[-1] - 1 + + # proms could be a Tensor, list[Tensor], or None + if isinstance( proms, torch.Tensor ): + if quant_level >= proms.shape[-1]: + quant_levels[i] = proms.shape[-1] - 1 + + elif isinstance( proms, list ): + for j, prom in enumerate( proms ): + if not isinstance( prom, torch.Tensor ): + continue + if quant_level >= prom.shape[-1]: + quant_levels[i] = prom.shape[-1] - 1 + + # apply token dropout error compensation + """ + if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]): + steps = resps.shape[0] + for l in range( quant_level ): + for t in range( steps ): + token = resps[t, l].item() + + if random.random() < token_dropout_error: + offset = 1 * ( 1 if random.random() < 0.5 else -1 ) + resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1 + """ + + # only apply stop token for RVQ level 0 + if timesteps[i] is None or (self.predict_causally): + # append stop tokens for AR + if task not in text_task: + resps_list[i] = torch.cat([ resps, audio_stop_sequence.repeat((1, resps.shape[-1])) ]) + + if task == "len": + quant_levels[i] = 0 + + # apply CFG (should probably only apply to NAR quant level 0) + if task not in text_task + ["len"]: + drop_text = False + drop_audio = False + swap_text = False + + if random.random() < cfg_prom_dropout_p: + drop_audio = True + + if random.random() < cfg_cond_dropout_p: + drop_audio = True + drop_text = True + + if random.random() < use_raw_text_p and text_list[i] is not None: + swap_text = True + + if drop_text: + 
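# dropping text means conditioning on just the start/stop pair ([1, 2]), i.e. the same null-text input the classifier-free guidance pass uses at inference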
phns_list[i] = text_start_stop_sequence + + if drop_audio: + proms_list[i] = None + + if swap_text and not drop_text: + phns_list[i] = None + + inputs = self.inputs( + phns_list=phns_list, + proms_list=proms_list, + resps_list=resps_list, + lang_list=lang_list, + tone_list=tone_list, + task_list=task_list, + text_list=text_list, + time_list=timesteps, + + quant_levels=quant_levels, + ) + + return super().forward( + inputs=inputs, + quant_levels=quant_levels, + ) + + # handles doing demasking inferencing in parallel to inference all tokens + # it works if the underlying model is trained properly (which is a pain) + def forward_nar_masked( + self, + + task_list: list[Tensor] | None = None, + + phns_list: list[Tensor] | None = None, + proms_list: list[Tensor] | None = None, + resps_list: list[Tensor] | None = None, + + lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, + len_list: list[Tensor] | None = None, + text_list: list[Tensor] | None = None, + + disable_tqdm=False, + use_lora=None, + **sampling_kwargs, + ): + device = phns_list[0].device + batch_size = len(phns_list) + + level = 0 + if cfg.lora is not None: + enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora ) + + # convert (N)AR specific args + sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" ) + + min_length = sampling_kwargs.pop("min_duration", 1) + max_length = sampling_kwargs.pop("max_duration", 500) + max_steps = sampling_kwargs.get("max_steps", 25) + refine_on_stop = sampling_kwargs.get("refine_on_stop", False) + entropix_sampling = sampling_kwargs.get("entropix_sampling", False) + annealed_sampling = sampling_kwargs.get("annealed_sampling", True) + + # greedy sampling is very, very much preferred, but using greedy logit scores later helps enough + temperature = sampling_kwargs.pop("temperature", 0.0) + minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5) + # this really helps keep audio coherent so far + cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength) + cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75) + start_noise = sampling_kwargs.get("denoise_start", 0.0) + end_noise = sampling_kwargs.get("denoise_end", 1.0) + remasking = sampling_kwargs.get("remasking", True) + max_steps = math.floor(max_steps * (end_noise - start_noise)) + + # to specify the initial mask used + vc_list = sampling_kwargs.pop("vc_list", None) + vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25) + vc_mask_p = sampling_kwargs.pop("vc_mask_p", 0.25) + + len_list = [ clamp(l, min_length, max_length) for l in len_list ] + + # force set CFG because too low / no CFG causes issues + original_cfg_strength = cfg_strength + cfg_strength = max( cfg_strength, minimum_cfg_strength ) + + prefix_context = sampling_kwargs.get("prefix_context", None) + # fill with masked tokens (even though they get masked anyways) + resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ] + # fill scores + scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ] + + quant_levels = [ level for _ in range(batch_size) ] + null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ] + null_prom = [ None for _ in range(batch_size) ] + + iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm) + for timestep in iterator: + # update previous list of tokens + prev_list = 
resps_list + # ramp down over time + annealing = 1.0 - timestep + # get noise level, per cosine scheduling + noise_p = math.cos( timestep * math.pi * 0.5 ) + # proportion of tokens to remask + remask_p = 1.0 / (max_steps * 2) if remasking else 0 + # pick the worst scoring tokens to mask off + masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ] + # normal masking + # mask off inputs + resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ] + # boolean mask + is_masked = [ resps == self.mask_token for resps in resps_list ] + # timestep inputs + time_list = [ timestep for _ in range(batch_size) ] + + sampling_temperature = temperature * annealing if annealed_sampling else temperature + sampling_cfg = cfg_strength * timestep if annealed_sampling else cfg_strength + + input_resps_list = resps_list + + # setup inputs + inputs = super().inputs( + phns_list=phns_list, + proms_list=proms_list, + resps_list=input_resps_list, + lang_list=lang_list, + tone_list=tone_list, + time_list=time_list, + quant_levels=quant_levels, + ) + output = super().forward( + inputs=inputs, + quant_levels=quant_levels, + ) + + logits = output.logits + if cfg_strength > 0: + null_inputs = super().inputs( + phns_list=null_text, + proms_list=null_prom, + resps_list=input_resps_list, + lang_list=lang_list, + tone_list=tone_list, + time_list=time_list, + quant_levels=quant_levels, + ) + null_output = super().forward( + inputs=null_inputs, + quant_levels=quant_levels, + ) + + logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] ) + + l_scores = [] + l_resps_list = [] + # cringe hack because we're able to sample multiple levels at once + for l in range(self.n_resp_levels): + # sample with sampler settings + filtered_sampled = super().sample( + logits=[ logit[l] for logit in logits ], + prev_list=[ resp[..., l] for resp in prev_list ], + quant_levels=quant_levels, + + temperature=sampling_temperature, + **sampling_kwargs, + ) + + # retrieves unfiltered logits + unfiltered_sampled = super().sample( + logits=[ logit[l] for logit in logits ], + prev_list=[ resp[..., l] for resp in prev_list ], + quant_levels=quant_levels, + + temperature=0.0, + **sampling_kwargs, + ) + + # get sampled tokens + sampled_ids = filtered_sampled.ids + # keep unmasked tokens + l_resps_list.append([ torch.where( masked[..., l], input_ids, resps[..., l] ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]) + # get probability scores + l_scores.append([ + # conjugate to have worse scoring tokens picked for topk + 1.0 - + # only keep scores of tokens we are predicting (and ignore the tokens previously finalized) + torch.where( masked[..., l], torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked[..., l].shape, device=device) ) + # use unmodified logit scores for this, as it offers better stability + for scores, masked in zip( unfiltered_sampled.scores, is_masked ) + ]) + + resps_list = [] + scores = [] + + for batch_index in range(batch_size): + score = sum([ l_scores[level][batch_index] for level in range(self.n_resp_levels) ]) / self.n_resp_levels + resp = torch.stack([ l_resps_list[level][batch_index] for level in range(self.n_resp_levels) ], dim=-1) + + 
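# score is (1 - sampling probability) averaged across levels, so the next step's top-k re-masks the least confident positions;
# resp carries the current (seq_len, n_resp_levels) estimate forward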
scores.append( score ) + resps_list.append( resp ) + + return resps_list + + def forward_ar_len( + self, + + task_list: list[Tensor], + + phns_list: list[Tensor] | None = None, + text_list: list[Tensor] | None = None, + proms_list: list[Tensor] | None = None, + resps_list: list[Tensor] | None = None, + lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, + len_list: list[Tensor] | None = None, + + disable_tqdm=False, + use_lora=None, + **sampling_kwargs, + ): + # deduce batch_size + if phns_list: + device = phns_list[0].device + batch_size = len(phns_list) + elif text_list: + device = text_list[0].device + batch_size = len(text_list) + elif proms_list: + device = proms_list[0].device + batch_size = len(proms_list) + elif resps_list: + device = resps_list[0].device + batch_size = len(resps_list) + + if cfg.lora is not None: + enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora ) + + # convert AR specific args + sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" ) + + temperature = sampling_kwargs.get("temperature", 1.0) + cfg_strength = sampling_kwargs.get("cfg_strength", 0.0) + cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7) + min_temperature = sampling_kwargs.get("min_temperature", -1.0) + max_duration = sampling_kwargs.get("max_duration", 500) + beam_width = sampling_kwargs.get("beam_width", 0) + entropix_sampling = sampling_kwargs.get("entropix_sampling", False) + refine_on_stop = sampling_kwargs.get("refine_on_stop", False) + input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False) + layer_skip = sampling_kwargs.get("layer_skip", False) + prefix_silence = sampling_kwargs.get("prefix_silence", 0.0) + mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0) + mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0) + + # inference len + sequence_list = [ torch.tensor([0], device=device,dtype=torch.int16) for _ in range(batch_size) ] + stopped = torch.zeros(batch_size, device=device).bool() + + stop_token = 10 + task_list = [ "len" for _ in range(batch_size) ] + quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ] + + iterator = trange(10, desc="AR", disable=disable_tqdm) + for n in iterator: + len_list = sequence_list + + inputs = self.inputs( + task_list=task_list, + + phns_list=phns_list, + proms_list=proms_list, + resps_list=resps_list, + + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + text_list=text_list, + + quant_levels=quant_levels, + ) + + output = super().forward( + inputs=inputs, + quant_levels=quant_levels, + ) + logits = output.logits + + r = [ logit[-1:].argmax(dim=1) for logit in logits ] + # sanitize + for i, token in enumerate(r): + if token > stop_token: + r[i][0] = stop_token + + # append tokens + for i, ri in enumerate(r): + if stop_token in ri: + stopped[i] = True + sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)]) + + # stop token found + stopped |= r == stop_token + if stopped.all().item(): + iterator.close() + break + + # convert tokens into int + return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ] + + def forward_ar( + self, + + task_list: list[Tensor], + + phns_list: list[Tensor] | None = None, + text_list: list[Tensor] | None = None, + proms_list: list[Tensor] | None = None, + resps_list: list[Tensor] | None = None, + lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, + len_list: list[Tensor] | None = None, + + disable_tqdm=False, + use_lora=None, + 
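# remaining sampler knobs (temperature, cfg_strength, beam_width, mirostat_tau/eta, ...) are read out of sampling_kwargs below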
**sampling_kwargs, + ): + # deduce batch_size + if phns_list: + device = phns_list[0].device + batch_size = len(phns_list) + elif text_list: + device = text_list[0].device + batch_size = len(text_list) + elif proms_list: + device = proms_list[0].device + batch_size = len(proms_list) + elif resps_list: + device = resps_list[0].device + batch_size = len(resps_list) + + if cfg.lora is not None: + enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora ) + + # convert AR specific args + sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" ) + + temperature = sampling_kwargs.get("temperature", 1.0) + cfg_strength = sampling_kwargs.get("cfg_strength", 0.0) + cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7) + min_temperature = sampling_kwargs.get("min_temperature", -1.0) + max_duration = sampling_kwargs.get("max_duration", 500) + beam_width = sampling_kwargs.get("beam_width", 0) + entropix_sampling = sampling_kwargs.get("entropix_sampling", False) + refine_on_stop = sampling_kwargs.get("refine_on_stop", False) + input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False) + layer_skip = sampling_kwargs.get("layer_skip", False) + prefix_silence = sampling_kwargs.get("prefix_silence", 0.0) + mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0) + mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0) + + start_slice = [ 0 for _ in range(batch_size) ] + sequence_list = [ torch.zeros((0, 8), device=device).to(torch.int16) for _ in range(batch_size) ] + stopped = torch.zeros(batch_size, device=device).bool() + + audio_stop_token = self.stop_token + text_stop_token = 2 + + state = None + mirostat = [ + {"n": 1024, "tau": mirostat_tau, "eta": mirostat_eta, "max_surprise": mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0} + ] * batch_size if mirostat_tau > 0.0 else None + + scores = [ 1.0 ] * beam_width + metrics = [] + + null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ] + null_prom = [ None for _ in range(batch_size) ] + + # get next in sequence + iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm) + for n in iterator: + if text_list is not None: + text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ] + else: + phns_list = [ sequence_list[i] if task in text_task else phns_list[i] for i, task in enumerate(task_list) ] + resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ] + + quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ] + + inputs = self.inputs( + task_list=task_list, + + phns_list=phns_list, + proms_list=proms_list, + resps_list=resps_list, + + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + text_list=text_list, + + quant_levels=quant_levels, + ) + + # to-do: find an elegant way to write this + output = super().forward( + inputs=inputs, + state=state, + #layer_skip_variables=sampling_layer_skip_variables, + output_attentions=entropix_sampling, + ) + + if cfg_strength > 0: + null_inputs = super().inputs( + phns_list=null_text, + proms_list=null_prom, + resps_list=resps_list, + lang_list=lang_list, + tone_list=tone_list, + quant_levels=quant_levels, + ) + null_output = super().forward( + inputs=null_inputs, + quant_levels=quant_levels, + #layer_skip_variables=sampling_layer_skip_variables, + ) + logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, 
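# classifier-free guidance against the null batch (start/stop-only text, no prompt), blended by cfg_strength and tempered by cfg_rescale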
rescale=cfg_rescale, lens=[ resp.shape[0] + 1 for resp in resps_list ] ) + + logits, state = output.logits, output.state + + l_resps_list = [ [] for _ in range(batch_size) ] + for l in range(self.n_resp_levels): + sampled = super().sample( + logits=[ logit[l] for logit in logits ], + #prev_list=[ resp[..., l] for resp in resps_list ], + **(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}), + ) + + ids = sampled.ids + + # append tokens + for i, token in enumerate(ids): + if audio_stop_token in token: + stopped[i] = True + l_resps_list[i].append(token.to(device)) + + for i, l in enumerate(l_resps_list): + sequence_list[i] = torch.cat([sequence_list[i], torch.stack(l, dim=-1)]) + + # stop token found + # stopped |= r == stop_token + if stopped.all().item(): + iterator.close() + break + + for i, l in enumerate( sequence_list ): + index = (l == audio_stop_token).nonzero() + # kludge for when it doesnt actually hit a stop token but i cant be bothered to properly address it right now since it only came up in test training at the moment + try: + index = index[:, 0].min() + sequence_list[i] = sequence_list[i][:index] + except Exception as e: + pass + + return sequence_list + + def forward( + self, + task_list: list[Tensor] | None = None, + + phns_list: list[Tensor] | None = None, + proms_list: list[Tensor] | None = None, + resps_list: list[Tensor] | None = None, + + lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, + len_list: list[Tensor] | None = None, + text_list: list[Tensor] | None = None, + + training: bool | None = None, + + disable_tqdm=False, + use_lora=None, + **sampling_kwargs, + ): + # deduce batch_size + # deduce batch_size + if phns_list: + device = phns_list[0].device + batch_size = len(phns_list) + elif text_list: + device = text_list[0].device + batch_size = len(text_list) + elif proms_list: + device = proms_list[0].device + batch_size = len(proms_list) + elif resps_list: + device = resps_list[0].device + batch_size = len(resps_list) + + # implicitly set for training + if training is None and phns_list is not None and resps_list is not None: + n_levels_set = {r.shape[-1] for r in resps_list} + n_levels = next(iter(n_levels_set)) + + training = n_levels == self.n_resp_levels + + # is training + if training: + return self.forward_train( + task_list=task_list, + + phns_list=phns_list, + proms_list=proms_list, + resps_list=resps_list, + + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + text_list=text_list, + ) + + # is NAR + if (len_list is not None or resps_list is not None) and phns_list is not None: + return self.forward_nar_masked( + task_list=task_list, + + phns_list=phns_list, + proms_list=proms_list, + resps_list=resps_list, + + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + text_list=text_list, + + disable_tqdm=disable_tqdm, + use_lora=use_lora, + **sampling_kwargs, + ) + + # NAR demasking for all levels + """ + resps_lists = [ None for _ in range(batch_size) ] + for level in range(self.n_resp_levels): + resp_list = self.forward_nar_masked( + task_list=task_list, + + phns_list=phns_list, + proms_list=proms_list, + resps_list=resps_list, + + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + text_list=text_list, + + disable_tqdm=disable_tqdm, + use_lora=use_lora, + quant_levels=[ level for _ in range(batch_size) ], + **sampling_kwargs, + ) + + for batch_index, resp in enumerate(resp_list): + if resps_lists[batch_index] is None: + resps_lists[batch_index] = [] + 
+ resps_lists[batch_index].append( resp ) + + for batch_index, resps in enumerate(resps_lists): + resps_lists[batch_index] = torch.stack( resps, dim=-1 ) + + return resps_lists + """ + + if task_list is not None and task_list[0] == "len": + return self.forward_ar_len( + task_list=task_list, + + phns_list=phns_list, + proms_list=proms_list, + resps_list=resps_list, + + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + text_list=text_list, + + disable_tqdm=disable_tqdm, + use_lora=use_lora, + **sampling_kwargs, + ) + + # is AR + return self.forward_ar( + task_list=task_list, + + phns_list=phns_list, + proms_list=proms_list, + resps_list=resps_list, + + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + text_list=text_list, + + disable_tqdm=disable_tqdm, + use_lora=use_lora, + **sampling_kwargs, + ) + + +def example_usage(): + #cfg.device = "cuda" + #cfg.trainer.backend = "local" + + from functools import partial + from einops import repeat + from tqdm import tqdm + + from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio + from ..data import _load_artifact + from ..engines import Engine, Engines + from ..utils import ml + from ..utils import setup_logging + + import numpy as np + import re + + # cfg.model.experimental.masking_train_p = 0.5 + cfg.hyperparameters.batch_size = 1 + cfg.hyperparameters.gradient_accumulation_steps = 1 + + setup_logging() + + def load_artifact( path ): + audio, metadata = _load_artifact(path, return_metadata=True) + + audio = audio.to(cfg.device) + text = torch.tensor( cfg.tokenizer.encode( metadata["phonemes"] ) ).to(dtype=torch.uint8, device=cfg.device) + + return text, audio + + text, audio = load_artifact(f"./data/qnt.{cfg.audio_backend_extension}") + batch_size = cfg.hyperparameters.batch_size + + phns_list = [ text ] * batch_size + proms_list = [ audio[:int(cfg.dataset.frames_per_second), :] ] * batch_size + resps_list = [ audio[:int(cfg.dataset.frames_per_second * 4), :] ] * batch_size + + kwargs = { + 'n_audio_tokens': cfg.model.audio_tokens, + + 'd_model': 1024, # 256, # 1024, # 1536 + 'n_heads': 16, # 4, # 16, # 24 + 'n_layers': 12, # 32 + 'n_experts': 1 if not cfg.model else cfg.model.experts, + + 'p_dropout': 0.1, + + 'l_padding': 8 if cfg.optimizations.fp8 else 0, + + 'config': cfg.model + } + + bos_id, space_id, eos_id = cfg.tokenizer.encode( " " ) + + available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else []) + + if cfg.model.experimental.masking_train_p == 0: + available_tasks = ["tts-ar"] + elif cfg.model.experimental.masking_train_p == 1: + available_tasks = ["tts-nar"] + + model = AR_NAR_V2(**kwargs).to(cfg.device) + steps = 250 // batch_size + + optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy" + scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else "" + learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None + + params = { + "params": model.parameters() + } + if cfg.optimizations.dadaptation: + # do not combine the two + if scheduler == "schedulefree": + scheduler = "" + + learning_rate = 1.0 + + if optimizer == "prodigy": + if learning_rate is None: + learning_rate = 1.0 + + optimizer = ml.Prodigy + elif optimizer == "adagrad": + if learning_rate is None: + learning_rate = 1.0e-2 + + optimizer = ml.Adagrad + elif optimizer == "adamw": + if learning_rate is None: + 
learning_rate = 1.0e-4 + + optimizer = ml.AdamW + elif optimizer == "sdg": + if learning_rate is None: + learning_rate = 1.0e-4 + + optimizer = ml.SGD + elif optimizer == "apollo": + if learning_rate is None: + learning_rate = 0.01 + + optimizer = ml.Apollo + params["params"] = [ + {'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'} + ] + elif optimizer == "muon": + optimizer = ml.Muon + + muon_params = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ] + adamw_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ] + adamw_params += [ param for name, param in model.named_parameters() if not name.startswith('model.') ] + + params["params"] = [ + { "params": muon_params, "muon": True }, + { "params": adamw_params, "muon": False, "betas": (0.95, 0.95), "eps": 1e-8 }, + ] + elif optimizer == "cosmos": + optimizer = ml.COSMOS + else: + raise ValueError(f"Unrecognized optimizer: {optimizer}") + + _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}") + + params["lr"] = learning_rate + optimizer = optimizer(**params) + + if scheduler == "schedulefree": + if isinstance(optimizer, ml.AdamW): + scheduler = ml.schedulefree.AdamWScheduleFree + elif isinstance(optimizer, ml.SGD): + scheduler = ml.schedulefree.SGDScheduleFree + else: + scheduler = None + + if scheduler is not None: + _logger.info(f"Scheduler: {scheduler}") + optimizer = scheduler( model.parameters(), lr = learning_rate ) + + if cfg.optimizations.replace and cfg.optimizations.linear: + model = ml.replace_linear( model ) + + if cfg.optimizations.replace and cfg.optimizations.embedding: + model = ml.replace_embedding( model ) + + """ + cfg.optimizations.model_offloading = { + "devices": ["cuda:0", "cpu"], + # "limits": [ 0.9, -1 ], + "assign": [[ f'layers.{i}.' for i in range(0,10) ], [ f'layers.{i}.' 
for i in range(11,12) ] + [ "model.norm" ]], + # "limits": [ 256 * (1024 ** 2), -1 ] + } + """ + + engine = Engine(model=model, optimizer=optimizer) + engines = Engines({"ar+nar": engine}) + engines.setup() + + """ + if cfg.optimizations.model_offloading: + model = ml.offload_model( model, policy=cfg.optimizations.model_offloading ) + """ + + """ + torch.save( { + 'module': model.state_dict() + }, f"./data/{cfg.model.arch_type}.pth" ) + """ + + _logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") + + @torch.no_grad() + def sample_data(t=None): + if isinstance(t, list): + tasks = t + texts = [ phns_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate( tasks ) ] + proms = [ proms_list[0].to(cfg.device) if task not in text_task else [ "stt" ] for i, task in enumerate( tasks ) ] + resps = [ None if task not in text_task else resps_list[0].to(cfg.device) for i, task in enumerate( tasks ) ] + + return texts, proms, resps, tasks + + texts = [] + proms = [] + resps = [] + tasks = [] + + for i in range(batch_size): + task = random.choice(available_tasks) if t is None else t + + text = phns_list[i].to(cfg.device) + prom = proms_list[i].to(cfg.device) + resp = resps_list[i].to(cfg.device) + + # do nothing + if task == "stt": + prom = [ task ] + else: + task = "tts" if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len" + + texts.append( text ) + proms.append( prom ) + resps.append( resp ) + tasks.append( task ) + + return texts, proms, resps, tasks + + @torch.inference_mode() + def sample( name, steps=500, task=None ): + engine.eval() + + phns_list, proms_list, resp_list, task_list = sample_data( task ) + + if task == "tts-nar": + # len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 ) + len_list = [ r.shape[0] for r in resp_list ] + resps_list = engine( phns_list=phns_list, proms_list=proms_list, len_list=len_list ) + else: + resps_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 ) + if resps_list[0].dim() == 1 or resps_list[0].shape[-1] == 1: + resps_list = engine( phns_list=phns_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 ) + + for i, o in enumerate(resps_list): + print( o.shape, o ) + _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.{task}.wav", device=cfg.device) + + unload_model() + + def train(): + engine.train() + t = trange(steps) + for i in t: + texts, proms, resps, tasks = sample_data() + + stats = {"step": i} + stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True) + stats |= {"grad_norm": engine.get_global_grad_norm()} + + tqdm.write(f"{stats}") + + """ + torch.save( { + 'module': model.state_dict() + }, f"./data/{cfg.model.arch_type}.pth" ) + """ + + task = available_tasks[0] + #sample("init", task=task) + + train() + + """ + if cfg.optimizations.compile: + model = ml.compile_model(model, backend=cfg.optimizations.compile) + """ + + for task in available_tasks: + sample("final", task=task) + + engines.quit() + +if __name__ == "__main__": + example_usage() \ No newline at end of file diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 750a10c..ffd0328 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -9,6 +9,8 @@ This should handle all the "low" level things 
such as: Additional functionality (preparing inputs, generating full audio) should be delegated to classes that inheret the base model """ +# to-do: clean this whole mess up + import math import torch import torch.nn.functional as F @@ -50,22 +52,22 @@ from ..utils.pattern import DelayedPatternProvider, VALLEPattern """ summed_embeddings_task = [ "stt" ] -special_tasks = [ "len", "stt", "phn", "un-phn" ] +special_tasks = [ "len", "stt", "phn", "text", "un-phn" ] non_tokened_names = ["task", "dropout_mask", "classifier_level"] task_outputs = { "tts": "resp", "ns": "resp", "sr": "resp", - "stt": "text", + "stt": "phn", "len": "len", - "phn": "text", - "un-phn": "raw_text", + "phn": "phn", + "un-phn": "text", } # yuck def _get_offsets(): return { - "text": (0, 256), + "phn": (0, 256), "quant_level": (256, 264), "lang": (264, 270), "task": (270, 279), @@ -126,15 +128,6 @@ def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ): def _interleave_sequence_flatten( input: list[torch.Tensor] ): return torch.concat( [ i.t() for i in input ] ).t().flatten() -# automagically parses a batch-list and returns it as a list -""" -class Embedding(ml.Embedding): - def forward(self, x_list: list[Tensor]) -> list[Tensor]: - if len(x_list) == 0: - return [] - return super().forward(torch.cat(x_list)).split([*map(len, x_list)]) -""" - # Deprecated implementation class MultiEmbedding(nn.Module): def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False): @@ -334,92 +327,6 @@ def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ): x[..., level] = torch.where( dropout_mask, lhs, rhs ) return x -# naively embeds each level of a codebook, then merges the embeddings with a Linear -class AudioEncoder(nn.Module): - def __init__( - self, - n_tokens: int, - n_levels: int, - token_dim: int, - ): - super().__init__() - self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)]) - self.proj = nn.Linear(8 * token_dim, 1 * token_dim) - - def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor: - # empty - if xi.shape[0] == 0: - dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0] - return torch.zeros((0, dim), device=xi.device, dtype=xi.dtype) - if dropout_mask is not None: - xi = _dropout_codes( xi, dropout_mask, dropout_token ) - - # old way - # this probably is a tried and true good way to go about it - x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ]) - - # encode by interleaving - # this "works" but I imagine it being excessive and doesn't seem to help the model all that much - """ - x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1) - x = x.view(x.shape[0], -1) - x = self.proj(x) - """ - - return x - -class AudioDecoder(nn.Module): - def __init__( - self, - d_model, - hidden_size, - vocab_size, - resp_levels, - ): - super().__init__() - - self.resp_levels = resp_levels - self.head = nn.Linear( d_model, vocab_size * resp_levels ) - - def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor: - # prior way up-projected then down-projected, but that's silly - x = self.head( x ) - - # interleave by reshaping / permuting - # at least I hope this does it properly, it checks out against my OCR classifier - batch_size, seq_len, dim = x.shape - x = x.view( batch_size, seq_len, self.resp_levels, -1 ) - x = x.permute( 0, 2, 1, 3 ) - - return x -""" - -""" -# naively tries to extract multiple codebooks in parallel based on the last hidden state from the 
model -# this doesn't work well -""" -class ClassifiersParallel(nn.Module): - def __init__( - self, - n_levels: int, # codebook levels - n_tokens: int, # output token count - token_dim: int, # dimensionality of the embedding - bias: bool = False, - ): - super().__init__() - self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for _ in range(n_levels)]) - - def forward(self, xi: Tensor, stack: bool = True ) -> Tensor: - dtype = xi.dtype - device = xi.device - - xi = [ proj( xi ) for l, proj in enumerate(self.proj) ] - xi = torch.stack( xi ) - xi = xi.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification ) - - return xi -""" - class Metrics(nn.Module): def __init__( self, @@ -508,9 +415,9 @@ class Base(nn.Module): def __init__( self, - n_text_tokens: int = 256, + n_phn_tokens: int = 256, n_audio_tokens: int = 1024, - n_raw_text_tokens: int = 8575, + n_text_tokens: int = 8575, d_model: int = 512, d_ffn: int = 4, @@ -531,9 +438,9 @@ class Base(nn.Module): self.teaching = False self.config = config - self.n_text_tokens = n_text_tokens + self.n_phn_tokens = n_phn_tokens self.n_audio_tokens = n_audio_tokens - self.n_raw_text_tokens = n_raw_text_tokens + self.n_text_tokens = n_text_tokens self.d_model = d_model self.n_heads = n_heads @@ -595,36 +502,30 @@ class Base(nn.Module): n_tones = self.config.tones if self.config is not None else 1 # pure AR - if self.version < 7: - if "nar" not in self.capabilities: - n_resp_tokens = n_audio_tokens + 1 - l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels - l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )] - l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels - # NAR-len model - elif "len" in self.capabilities: - # +1 to include the stop or mask token - n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 ) - if "ar" in self.capabilities: - l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens] - l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1] - l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0'] - else: - l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) - l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) - l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] - # AR+NAR model - else: - # +1 to include the stop or mask token - n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 ) - l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) - l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] - l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) - else: - n_resp_tokens = n_audio_tokens + 2 + if "nar" not in self.capabilities: + n_resp_tokens = n_audio_tokens + 1 l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels - l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )] - l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels + l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )] + l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels + # NAR-len model + elif "len" in self.capabilities: + # +1 to include the stop or mask token + n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 
) + if "ar" in self.capabilities: + l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens] + l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1] + l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0'] + else: + l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + # AR+NAR model + else: + # +1 to include the stop or mask token + n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 ) + l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1 @@ -632,7 +533,7 @@ class Base(nn.Module): # STT l_classifier_names += [ "stt" ] - l_classifier_tokens += [ n_text_tokens ] + l_classifier_tokens += [ n_phn_tokens ] # LEN if "len" in self.capabilities: @@ -641,8 +542,8 @@ class Base(nn.Module): # TEXT => PHN / PHN => TEXT if self.version >= 6: - l_classifier_tokens += [ n_raw_text_tokens ] - l_classifier_names = l_embedding_names + [ "raw_text" ] + l_classifier_tokens += [ n_text_tokens ] + l_classifier_names = l_embedding_names + [ "text" ] self.n_vocab = n_vocab self.unified_position_ids = unified_position_ids @@ -651,7 +552,7 @@ class Base(nn.Module): self.ignore_inputs_for_loss = ignore_inputs_for_loss self.noncausal_masks = noncausal_masks - self.text_emb = Embedding(n_text_tokens, d_model) + self.text_emb = Embedding(n_phn_tokens, d_model) self.raw_text_emb = None self.langs_emb = None self.tones_emb = None @@ -680,7 +581,7 @@ class Base(nn.Module): levels=self.n_resp_levels if self.version > 3 else None, ) self.audio_emb = None - elif self.version < 7: + else: self.proms_emb = AudioEmbedding( [n_audio_tokens] * self.n_resp_levels, d_model, sums=audio_embedding_sums == "prom" or audio_embedding_sums == True, @@ -691,10 +592,6 @@ class Base(nn.Module): l_embedding_names=l_embedding_names, ) self.audio_emb = None - else: - self.proms_emb = None - self.resps_emb = None - self.audio_emb = None if self.version >= 3: self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None @@ -716,34 +613,7 @@ class Base(nn.Module): # experimental NAR-only mode self.len_emb = Embedding(11, d_model) if self.version >= 6: - self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model) - - if self.version >= 7: - self.mask_token = self.stop_token + 1 - if monolithic_audio_encoder: - self.audio_emb = AudioEncoder( - n_tokens=n_audio_tokens + 2, # stop + masked token - n_levels=self.n_resp_levels, - token_dim=d_model, - ) - else: - self.proms_emb = AudioEncoder( - n_tokens=n_audio_tokens, - n_levels=self.n_resp_levels, - token_dim=d_model, - ) - self.resps_emb = AudioEncoder( - n_tokens=n_audio_tokens + 2, # stop + masked token - n_levels=self.n_resp_levels, - token_dim=d_model, - ) - - self.audio_decoder = AudioDecoder( - d_model, - d_model * 2, - (n_audio_tokens + 1), - self.n_resp_levels, - ) + self.raw_text_emb = Embedding(self.n_text_tokens, d_model) if attention_backend == "auto": attention_backend = "sdpa" @@ -1009,8 +879,8 @@ class 
Base(nn.Module): # takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation def inputs( self, + phns_list: list[Tensor] | None = None, text_list: list[Tensor] | None = None, - raw_text_list: list[Tensor] | None = None, proms_list: list[Tensor] | None = None, resps_list: list[Tensor] | None = None, @@ -1023,12 +893,12 @@ class Base(nn.Module): quant_levels: int | list[int] | Tensor | None = None ): - if text_list and text_list[0] is not None: + if phns_list and phns_list[0] is not None: + device = phns_list[0].device + batch_size = len(phns_list) + elif text_list and text_list[0] is not None: device = text_list[0].device batch_size = len(text_list) - elif raw_text_list and raw_text_list[0] is not None: - device = raw_text_list[0].device - batch_size = len(raw_text_list) elif proms_list and proms_list[0] is not None: device = proms_list[0].device batch_size = len(proms_list) @@ -1054,10 +924,10 @@ class Base(nn.Module): # prom /may/ include tokens inside to help guide things, per SpeechX if task_type in get_task_symmap() and task_type not in special_tasks: # insert the text prompt - if text_list is not None and text_list[i] is not None: + if phns_list is not None and phns_list[i] is not None: + inputs[i].append( ( "phn", phns_list[i] ) ) + elif text_list is not None and text_list[i] is not None: inputs[i].append( ( "text", text_list[i] ) ) - elif raw_text_list is not None and raw_text_list[i] is not None: - inputs[i].append( ( "raw_text", raw_text_list[i] ) ) # insert lang token if we're trained for it if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None: inputs[i].append( ( "lang", lang_list[i] ) ) @@ -1096,9 +966,6 @@ class Base(nn.Module): if resps_list is not None and resps_list[i] is not None: inputs[i].append( ( "resp", resps_list[i] ) ) - if self.version >= 7: - classifier_level = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}" - inputs[i].append( ("classifier_level", classifier_level) ) # Audio length prediction task # Sequence: @@ -1108,10 +975,10 @@ class Base(nn.Module): raise Exception(f"Requesting task `{task_type}` but corresponding embedding is not defined.") # insert the text prompt - if text_list is not None and text_list[i] is not None: + if phns_list is not None and phns_list[i] is not None: + inputs[i].append( ( "phn", phns_list[i] ) ) + elif text_list is not None and text_list[i] is not None: inputs[i].append( ( "text", text_list[i] ) ) - elif raw_text_list is not None and raw_text_list[i] is not None: - inputs[i].append( ( "raw_text", raw_text_list[i] ) ) # insert lang token if we're trained for it if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None: inputs[i].append( ( "lang", lang_list[i] ) ) @@ -1147,42 +1014,42 @@ class Base(nn.Module): if self.rvq_l_emb is not None: inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) # insert the output text prompt - if text_list is not None and text_list[i] is not None: - inputs[i].append( ( "text", text_list[i] ) ) + if phns_list is not None and phns_list[i] is not None: + inputs[i].append( ( "phn", phns_list[i] ) ) - inputs[i].append( ("classifier_level", "stt") ) + inputs[i].append( ("classifier_level", "phn") ) # Text phonemizing task - # Sequence: + # Sequence: elif task_type == "phn": # insert the text prompt - if raw_text_list is not None and raw_text_list[i] is not None: - inputs[i].append( ( "raw_text", raw_text_list[i] ) ) + if 
text_list is not None and text_list[i] is not None: + inputs[i].append( ( "text", text_list[i] ) ) # insert lang token if we're trained for it if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None: inputs[i].append( ( "lang", lang_list[i] ) ) if self.rvq_l_emb is not None: inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) # insert the text prompt - if text_list is not None and text_list[i] is not None: - inputs[i].append( ( "text", text_list[i] ) ) + if phns_list is not None and phns_list[i] is not None: + inputs[i].append( ( "phn", phns_list[i] ) ) - inputs[i].append( ("classifier_level", "stt") ) + inputs[i].append( ("classifier_level", "phn") ) # Text de-phonemizing task - # Sequence: + # Sequence: elif task_type == "un-phn": # insert the text prompt - if text_list is not None and text_list[i] is not None: - inputs[i].append( ( "text", text_list[i] ) ) + if phns_list is not None and phns_list[i] is not None: + inputs[i].append( ( "phn", phns_list[i] ) ) # insert lang token if we're trained for it if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None: inputs[i].append( ( "lang", lang_list[i] ) ) if self.rvq_l_emb is not None: inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) # insert the text prompt - if raw_text_list is not None and raw_text_list[i] is not None: - inputs[i].append( ( "raw_text", raw_text_list[i] ) ) + if text_list is not None and text_list[i] is not None: + inputs[i].append( ( "text", text_list[i] ) ) - inputs[i].append( ("classifier_level", "raw_text") ) + inputs[i].append( ("classifier_level", "text") ) else: raise Exception(f'Unrecognized task: {task_type}') return inputs @@ -1240,17 +1107,11 @@ class Base(nn.Module): input if quant_level == 0 else input[:, :quant_level] ) - if self.version < 7: - return self.proms_emb( - input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], - quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level - offset = 0, - ) - - if self.audio_emb is not None: - return self.audio_emb( input ) - - return self.proms_emb( input ) + return self.proms_emb( + input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], + quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level + offset = 0, + ) # yuck token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0 @@ -1293,11 +1154,11 @@ class Base(nn.Module): # *maybe* inject a token for specifying task type task_type = input continue - elif name == "text": + elif name == "phn": embedding = self.text_emb( input ) device = embedding.device - elif name == "raw_text" and self.raw_text_emb is not None: + elif name == "text" and self.raw_text_emb is not None: embedding = self.raw_text_emb( input ) device = embedding.device @@ -1316,13 +1177,8 @@ class Base(nn.Module): elif name == "tone" and self.tones_emb is not None: embedding = self.tones_emb( input ) elif name == "resp": - if self.version >= 7: - if self.audio_emb is not None: - embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token ) - else: - embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token ) # if training NAR-len RVQ level 0 - elif dropout_mask is not None: + if dropout_mask is not None: embedding = self.resps_emb( # if masked use masked 
token, else original token torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, quant_level] ), @@ -1494,13 +1350,10 @@ class Base(nn.Module): return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens - if self.version < 4 or (self.version >= 5 and self.version < 7 and self.config and self.config.experimental.audio_embedding_sums): + if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums): return torch.full_like(input[..., 0], self.ignore_index) - - if self.version < 7: - return input if input.dim() == 1 else input[:, quant_level] - return input + return input if input.dim() == 1 else input[:, quant_level] def _calc_loss( logit, sequence, causal = True ): # filter tokens that exceed the vocab size @@ -1579,18 +1432,11 @@ class Base(nn.Module): token = token[..., 0] elif name == "resp": # mask found, apply it - if self.version < 7: - token = input if input.dim() == 1 else input[:, quant_level] - - # mask found, apply it - if dropout_mask is not None: - token = torch.where( dropout_mask, token, self.ignore_index ) - else: - token = input - - # mask found, apply it - if dropout_mask is not None: - token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True ) + token = input if input.dim() == 1 else input[:, quant_level] + + # mask found, apply it + if dropout_mask is not None: + token = torch.where( dropout_mask, token, self.ignore_index ) # not a special input, inject as-is else: token = input @@ -1785,7 +1631,7 @@ class Base(nn.Module): # needs to be done here as we still have our raw inputs position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None classifier_levels = self.get_input( inputs, name="classifier_level" ) - causal_levels = [ "stt", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ] + causal_levels = [ "phn", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ] # right now limit to new versions because I need to retrain the model for noncausal masks... 
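# Editor's note: a minimal, self-contained sketch (toy values; the names here are
# illustrative, not from this repo) of how the per-batch causality flags computed just
# below are derived from each batch item's classifier level: only AR-style and
# text-like levels keep a causal mask once noncausal masks are enabled.
example_causal_levels = ["phn", "len"] + [f"AR:{i}:{i}" for i in range(8)]
example_classifier_levels = ["AR:0:0", "NAR:0:1", "phn"]  # one entry per batch item
example_noncausal_masks = True
example_is_causal = (
    [l in example_causal_levels for l in example_classifier_levels]
    if example_noncausal_masks
    else [True] * len(example_classifier_levels)
)
assert example_is_causal == [True, False, True]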
is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ] @@ -1802,36 +1648,14 @@ class Base(nn.Module): logits = output.logits hidden_states = output.hidden_states - # split between the two logit tasks, as audio logits become expanded - if self.version >= 7: - logits = [ logit for logit in logits ] - - audio_decoder_levels = [ f"AR:{i}:{i}" for i in range(self.n_resp_levels) ] + [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ] - - decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ] - classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ] - - if decoders_indices: - decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ]) - decoders_logits = self.audio_decoder( decoders_logits ) - for batch_index, logit in zip( decoders_indices, decoders_logits ): - logits[batch_index] = logit - - if classifiers_indices: - classifiers_levels = [ classifier_levels[batch_index] for batch_index in classifiers_indices ] - classifiers_logits = torch.stack([ logits[batch_index] for batch_index in classifiers_indices ]) - classifiers_logits = self.classifiers( classifiers_logits, levels = classifiers_levels ) - for batch_index, logit in zip( classifiers_indices, classifiers_logits ): - logits[batch_index] = logit - else: - # output projection layer - # the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways - if self.classifier is not None: - logits = self.classifier(logits) # * m - # to-do: piece-wise classification, now that there's a head for text - # although again, one single monolithic head would be preferable instead...... - elif self.classifiers is not None: - logits = self.classifiers(logits, levels = classifier_levels ) + # output projection layer + # the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways + if self.classifier is not None: + logits = self.classifier(logits) # * m + # to-do: piece-wise classification, now that there's a head for text + # although again, one single monolithic head would be preferable instead...... 
+ elif self.classifiers is not None: + logits = self.classifiers(logits, levels = classifier_levels ) # Remove padding logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ] @@ -2206,7 +2030,7 @@ if __name__ == "__main__": seq = seq[:rvq_l, :] if rvq_l > 0 else seq sep_embd = embds["sep"](zero) - phn_embd = embds["text"](phn) + phn_embd = embds["phn"](phn) rvq_l_embd = embds["rvq_l"](rvq_l) lang_embd = embds["lang"](lang) prom_embd = torch.zeros(prom.shape[-1], n_embd, device=device, dtype=dtype) diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py new file mode 100644 index 0000000..6d033d6 --- /dev/null +++ b/vall_e/models/base_v2.py @@ -0,0 +1,1147 @@ +import math +import torch +import torch.nn.functional as F +import random +import numpy as np +import re + +from time import perf_counter +from collections import namedtuple +from typing import Literal, overload, Optional, Tuple +from functools import partial +from einops import rearrange + +from torch import Tensor, einsum, nn +from torch.distributions import Categorical +from torch.nn.utils.rnn import pad_sequence +from torch.utils.checkpoint import checkpoint +from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision + +from .arch import * +from ..utils import ml, clamp +from ..samplers import * + +# yuck, kind of needed +from ..data import get_task_symmap + +import logging + +_logger = logging.getLogger(__name__) + +# these seem more elegant than a dict +Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states']) +Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy']) +LossStats = namedtuple('LossStats', ['loss', 'stats']) + +""" +from ..utils.pattern import DelayedPatternProvider, VALLEPattern +""" + +summed_embeddings_task = [ "stt" ] +special_tasks = [ "len", "stt", "phn", "text" ] +non_tokened_names = ["task", "dropout_mask", "classifier_level"] +task_outputs = { + "tts": "resp", + "ns": "resp", + "sr": "resp", + "stt": "phn", + "len": "len", + "phn": "phn", + "text": "text", +} + +def _dropout_mask( input, p ): + return (torch.rand(input.shape[0], device=input.device) < p) + +def _create_mask(l, device): + seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t) + stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1) + return (seq < stop).float() # (b t) + +def _join(x: tuple[Tensor], sep: Tensor): + ret = x[0] + for i in range(1, len(x)): + ret = torch.cat((ret, sep[None], x[i]), dim=0) + return ret + +def list_to_tensor(x_list: list[Tensor]): + l = list(map(len, x_list)) + x = pad_sequence(x_list, batch_first=True) + m = _create_mask(l, x_list[0].device) + + m = m.to(x).int() + return x, m + +def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ): + x = x.clone().detach() + levels = x.shape[-1] + for level in range( levels ): + lhs = dropout_token if not swapped else x[..., level] + rhs = x[..., level] if not swapped else dropout_token + x[..., level] = torch.where( dropout_mask, lhs, rhs ) + return x + +# naively embeds each level of a codebook, then merges the embeddings with a Linear +class AudioEncoder(nn.Module): + def __init__( + self, + n_tokens: int, + n_levels: int, + token_dim: int, + ): + super().__init__() + self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)]) + self.proj = nn.Linear(8 * token_dim, 1 * token_dim) + + def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor: + # empty + if xi.shape[0] == 0: + 
dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0] + return torch.zeros((0, dim), device=xi.device, dtype=xi.dtype) + if dropout_mask is not None: + xi = _dropout_codes( xi, dropout_mask, dropout_token ) + + # old way + # this probably is a tried and true good way to go about it + x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ]) + + # encode by interleaving + # this "works" but I imagine it being excessive and doesn't seem to help the model all that much + """ + x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1) + x = x.view(x.shape[0], -1) + x = self.proj(x) + """ + + return x + +class AudioDecoder(nn.Module): + def __init__( + self, + d_model, + vocab_size, + resp_levels, + ): + super().__init__() + + self.resp_levels = resp_levels + self.head = nn.Linear( d_model, vocab_size * resp_levels ) + + def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor: + # prior way up-projected then down-projected, but that's silly + x = self.head( x ) + + # interleave by reshaping / permuting + # at least I hope this does it properly, it checks out against my OCR classifier + batch_size, seq_len, dim = x.shape + x = x.view( batch_size, seq_len, self.resp_levels, -1 ) + x = x.permute( 0, 2, 1, 3 ) + + return x + +class AuxDecoder(nn.Module): + def __init__( + self, + d_model, + vocab_size, + name = None, + ): + super().__init__() + self.name = name + self.head = nn.Linear( d_model, vocab_size ) + + def forward(self, x: Tensor ) -> Tensor: + x = self.head( x ) + return x + +class Base_V2(nn.Module): + def loss_factor(self, k): + if self.config is None: + return 1.0 + return self.config.loss_factor(k) + + def _prune(self, l: Tensor, stop = None): + if stop is None: + stop = self.stop_token + + indices = (l == stop).nonzero() + + if len(indices) == 0: + return l + + return l[: indices.min().item()] + + def __init__( + self, + + n_phn_tokens: int = 256, + n_audio_tokens: int = 1024, + n_text_tokens: int = 8575, + + d_model: int = 512, + d_ffn: int = 4, + n_heads: int = 8, + n_layers: int = 12, + p_dropout: float = 0.1, + + n_experts: int = 1, + + l_padding: int = 0, + + training = True, + attention = None, + config = None, + ): + super().__init__() + + if not attention: + attention = config.attention if config is not None else "auto" + + attention_backend = attention + unified_position_ids = config.experimental.unified_position_ids if config is not None else True + noncausal_masks = config.experimental.noncausal_masks if config is not None else False + + max_position_embeddings = config.experimental.max_position_embeddings if config is not None else (75 * 60 * 5) + masking_ratio = config.experimental.masking_ratio if config is not None else False + ignore_inputs_for_loss = config.experimental.ignore_inputs_for_loss if config is not None else False + + resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True + predict_causally = config.experimental.predict_causally if config is not None else False + monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False + + n_vocab = 256 + n_tasks = config.tasks if config is not None else 8 + n_langs = config.langs if config is not None else 2 + n_tones = config.tones if config is not None else 1 + + if attention_backend == "auto": + attention_backend = "sdpa" + + hf_attention = attention_backend + HF_ATTENTIONS = ["eager", "sdpa", "flash_attention_2"] + + if attention_backend not in 
HF_ATTENTIONS: + hf_attention = None + if attention_backend not in AVAILABLE_ATTENTIONS: + raise ValueError(f"Requesting attention `{attention_backend}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}") + + self.training = training + self.teaching = False + self.config = config + + self.n_phn_tokens = n_phn_tokens + self.n_audio_tokens = n_audio_tokens + self.n_text_tokens = n_text_tokens + + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.n_experts = n_experts + + self.l_padding = l_padding + + self.ignore_index = -100 + + self.n_resp_levels = self.config.resp_levels if self.config else n_resp_levels + self.n_max_levels = self.config.max_levels if self.config else n_resp_levels + self.capabilities = self.config.capabilities if self.config else ["ar", "nar"] + self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True + + self.stop_token = self.n_audio_tokens + self.mask_token = self.stop_token + 1 + + self.causal = True + self.version = self.config.version if self.config is not None else 5 + self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0) + + self.arch_type = self.config.arch_type if self.config is not None else "llama" + + # check if requested arch is unavailable + if self.arch_type in ERROR_ARCHES: + raise ERROR_ARCHES[self.arch_type] + + # crunge + if self.config is not None and config.teacher: + self.teaching = True + self.training = False + + self.resp_parallel_training = resp_parallel_training + self.predict_causally = predict_causally + + self.unified_position_ids = unified_position_ids + self.inject_timestep_embedding = False # results in bad output + self.masking_ratio = masking_ratio + self.ignore_inputs_for_loss = ignore_inputs_for_loss + self.noncausal_masks = noncausal_masks + + self.sep = nn.Parameter(torch.randn(d_model)) + + self.phn_emb = ml.Embedding(n_phn_tokens, d_model) + self.text_emb = ml.Embedding(n_text_tokens, d_model) + self.langs_emb = ml.Embedding(n_langs, d_model) if n_langs > 0 else None + self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None + self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None + self.len_emb = ml.Embedding(11, d_model) + + self.audio_emb = None + self.proms_emb = None + self.resps_emb = None + + if monolithic_audio_encoder: + self.audio_emb = AudioEncoder( + n_tokens=n_audio_tokens + 2, # stop + masked token + n_levels=self.n_resp_levels, + token_dim=d_model, + ) + else: + self.proms_emb = AudioEncoder( + n_tokens=n_audio_tokens, + n_levels=self.n_resp_levels, + token_dim=d_model, + ) + self.resps_emb = AudioEncoder( + n_tokens=n_audio_tokens + 2, # stop + masked token + n_levels=self.n_resp_levels, + token_dim=d_model, + ) + + self.audio_decoder = AudioDecoder( + d_model, + (n_audio_tokens + 1), + self.n_resp_levels, + ) + self.len_decoder = AuxDecoder( d_model, 11 ) + self.text_decoder = AuxDecoder( d_model, n_phn_tokens ) + self.raw_text_decoder = AuxDecoder( d_model, n_text_tokens ) + + # override any requested padding size + if attention_backend == "flash_attn_v100": + self.l_padding = 32 + elif attention_backend == "fused_attn": + self.l_padding = 128 + + if self.arch_type in ["llama"]: + self.model = LlamaModel_Adapted(LlamaConfig( + vocab_size=n_vocab, + hidden_size=d_model, + max_position_embeddings=max_position_embeddings, + intermediate_size=d_model*d_ffn, + num_hidden_layers=n_layers, + num_attention_heads=n_heads, + 
attention_dropout=p_dropout if training else 0.0, + num_key_value_heads=n_heads, + hidden_act="gelu", + is_encoder_decoder=False, + is_decoder=True, + attn_implementation=hf_attention, + #gradient_checkpointing=self.gradient_checkpointing, + )) + + self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend ) + + if self.gradient_checkpointing and not self.model.gradient_checkpointing: + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( + use_reentrant=False + )) + else: + raise RuntimeError(f'Unknown arch specified: {self.arch_type}') + + if hasattr( self.model, "embeddings" ): + del self.model.embeddings + + def _forward( + self, + inputs, + mask = None, + is_causal = None, + position_ids = None, + + state = None, + + output_attentions = False, + output_hidden_states = False, + ): + x = inputs + m = mask #.squeeze(-1).int() + + aux_loss = None + attentions = None + hidden_states = None + + # HF transformer derived model + if self.arch_type in ["llama"]: + kwargs = dict( + inputs_embeds=x, + attention_mask=m, + past_key_values=state, + position_ids=position_ids, + use_cache=False, # not self.training, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + is_causal=is_causal, + ) + + if self.n_experts > 1 and self.training: + kwargs["output_router_logits"] = True + + output = self.model(**kwargs) + x = output["last_hidden_state"] + + # to-do: figure out why KV caching doesn't work + #if not self.training: + if state is not None: + state = output["past_key_values"] + + if output_attentions: + attentions = output["attentions"] + + if output_hidden_states: + hidden_states = output["hidden_states"] + + if self.n_experts > 1 and self.training: + router_logits = output["router_logits"] + aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok, m ) + + # process it into a format that I like + if output_hidden_states: + # hidden_states is actually layers + 1, as hidden_states[0] == embedding........... 
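# Editor's note: a rough, self-contained sketch of the hidden-state post-processing that
# follows, assuming a HuggingFace-style output whose `hidden_states` tuple holds
# n_layers + 1 entries (index 0 being the input embeddings); the final layer's state is
# left untouched since the model already normalized it. Variable names are illustrative only.
import torch
import torch.nn as nn

example_n_layers, example_seq, example_dim = 4, 3, 8
example_norm = nn.LayerNorm(example_dim)  # stand-in for the model's final norm
example_states = [torch.randn(1, example_seq, example_dim) for _ in range(example_n_layers + 1)]

example_states = example_states[1:]  # drop the embedding output
example_states = [
    s if i == example_n_layers - 1 else example_norm(s)
    for i, s in enumerate(example_states)
]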
+ hidden_states = [ state for state in hidden_states[1:] ] + # apply normalization to these states (to-do: check if this matters) + # but skip the last state, as it already is normalized + hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ] + + return Logits(x, state, inputs, aux_loss, attentions, hidden_states) + + # takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation + def inputs( + self, + phns_list: list[Tensor] | None = None, + text_list: list[Tensor] | None = None, + + proms_list: list[Tensor] | None = None, + resps_list: list[Tensor] | None = None, + + lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, + len_list: list[Tensor] | None = None, + task_list: list[str] | None = None, + time_list: list[Tensor] | None = None, + + quant_levels: int | list[int] | Tensor | None = None + ): + if phns_list and phns_list[0] is not None: + device = phns_list[0].device + batch_size = len(phns_list) + elif text_list and text_list[0] is not None: + device = text_list[0].device + batch_size = len(text_list) + elif proms_list and proms_list[0] is not None: + device = proms_list[0].device + batch_size = len(proms_list) + elif resps_list and resps_list[0] is not None: + device = resps_list[0].device + batch_size = len(resps_list) + + inputs = [ [] for _ in range(batch_size) ] + for i in range(batch_size): + quant_level = quant_levels[i] if quant_levels is not None else 0 + task_type = task_list[i] if task_list is not None else "tts" + timestep = time_list[i] if time_list is not None else None + classifier_level = None + + # insert task type as a string + inputs[i].append( ( "task", task_type ) ) + + # to-do: maybe not split the below blocks up + # might be beneficial in the event I need to use a difference sequence, such as STT tasks + + # Base-line TTS task + # Sequence: + # prom /may/ include tokens inside to help guide things, per SpeechX + if task_type in get_task_symmap() and task_type not in special_tasks: + # insert the phn prompt + if phns_list is not None and phns_list[i] is not None: + inputs[i].append( ( "phn", phns_list[i] ) ) + elif text_list is not None and text_list[i] is not None: + inputs[i].append( ( "text", text_list[i] ) ) + # insert lang token if we're trained for it + if lang_list is not None and lang_list[i] is not None: + inputs[i].append( ( "lang", lang_list[i] ) ) + # insert input audio prompt + if proms_list is not None and proms_list[i] is not None: + inputs[i].append( ( "prom", proms_list[i] ) ) + # insert tone token if we're trained for it + if tone_list is not None and tone_list[i] is not None: + inputs[i].append( ( "tone", tone_list[i] ) ) + # insert timestep token + if timestep is not None: + p = self.masking_ratio + + # store dropout mask (if training, as this gets used later to mask the input embeddings if provided) + if self.training: + dropout_mask = _dropout_mask( resps_list[i], p ) + inputs[i].append( ("dropout_mask", dropout_mask ) ) + # insert the current output response + if resps_list is not None and resps_list[i] is not None: + inputs[i].append( ( "resp", resps_list[i] ) ) + + classifier_level = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}" + + inputs[i].append( ("classifier_level", classifier_level) ) + # Audio length prediction task + # Sequence: + elif task_type == "len": + # throw an error so we don't silently train without this + if self.len_emb is None: + 
raise Exception(f"Requesting task `{task_type}` but corresponding embedding is not defined.") + + # insert the phn prompt + if phns_list is not None and phns_list[i] is not None: + inputs[i].append( ( "phn", phns_list[i] ) ) + elif text_list is not None and text_list[i] is not None: + inputs[i].append( ( "text", text_list[i] ) ) + # insert lang token if we're trained for it + if lang_list is not None and lang_list[i] is not None: + inputs[i].append( ( "lang", lang_list[i] ) ) + # insert input audio prompt + if proms_list is not None and proms_list[i] is not None: + inputs[i].append( ( "prom", proms_list[i] ) ) + # insert tone token if we're trained for it + if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None: + inputs[i].append( ( "tone", tone_list[i] ) ) + + # insert output length tokens (if it exists) + if len_list is not None and len_list[i] is not None: + inputs[i].append( ( "len", len_list[i] ) ) + # "encode" length to tokens for 0-9 + stop + elif resps_list is not None and resps_list[i] is not None: + # yes this could be encoded better + inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) ) + + inputs[i].append( ("classifier_level", "len") ) + # Speech-to-Text prediction task + # Sequence: + elif task_type == "stt": + # insert the input response + if resps_list is not None and resps_list[i] is not None: + inputs[i].append( ( "resp", resps_list[i] ) ) + # insert lang token if we're trained for it + if lang_list is not None and lang_list[i] is not None: + inputs[i].append( ( "lang", lang_list[i] ) ) + # insert the output phn prompt + if phns_list is not None and phns_list[i] is not None: + inputs[i].append( ( "phn", phns_list[i] ) ) + + inputs[i].append( ("classifier_level", "phn") ) + # Text phonemizing task + # Sequence: + elif task_type == "phn": + # insert the phn prompt + if text_list is not None and text_list[i] is not None: + inputs[i].append( ( "text", text_list[i] ) ) + # insert lang token if we're trained for it + if lang_list is not None and lang_list[i] is not None: + inputs[i].append( ( "lang", lang_list[i] ) ) + # insert the phn prompt + if phns_list is not None and phns_list[i] is not None: + inputs[i].append( ( "phn", phns_list[i] ) ) + + inputs[i].append( ("classifier_level", "phn") ) + # Text de-phonemizing task + # Sequence: + elif task_type == "text": + # insert the phn prompt + if phns_list is not None and phns_list[i] is not None: + inputs[i].append( ( "phn", phns_list[i] ) ) + # insert lang token if we're trained for it + if lang_list is not None and lang_list[i] is not None: + inputs[i].append( ( "lang", lang_list[i] ) ) + # insert the phn prompt + if text_list is not None and text_list[i] is not None: + inputs[i].append( ( "text", text_list[i] ) ) + + inputs[i].append( ("classifier_level", "text") ) + else: + raise Exception(f'Unrecognized task: {task_type}') + return inputs + + def inputs_to_embeddings( + self, + inputs: list, + quant_levels: int | list[int] | Tensor | None = None + ): + # handles tasks where the prompt has task tokens injected in the middle + def prompt_input_to_embedding( input, quant_level ): + if isinstance(input, str): + return self.tasks_emb( torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) ) + + if self.audio_emb is not None: + return self.audio_emb( input ) + + return self.proms_emb( input ) + + x_list = [] + for batch_index, batch_input in enumerate(inputs): + batch = [] + 
quant_level = quant_levels[batch_index] if quant_levels is not None else 0 + + task_type = "tts" + input_prom = None + classifier_level = None + dropout_mask = None + timestep = None + + # pre-iterate + for name, input in batch_input: + if name == "classifier_level": + classifier_level = input + elif name == "dropout_mask": + dropout_mask = input + elif name == "timestep": + timestep = input + + for name, input in batch_input: + # technically can provide a map for input_name => embedding, but some embedding requires additional processing + embedding = None + + # is already an embedding + if name == "task": + # noop + # *maybe* inject a token for specifying task type + task_type = input + continue + elif name == "phn": + embedding = self.phn_emb( input ) + + device = embedding.device + elif name == "text" and self.text_emb is not None: + embedding = self.text_emb( input ) + + device = embedding.device + elif name == "quant_level" and self.rvq_l_emb is not None: + embedding = self.rvq_l_emb( input ) + elif name == "lang" and self.langs_emb is not None: + embedding = self.langs_emb( input ) + elif name == "prom": + proms = [ input ] if isinstance(input, torch.Tensor) else input + embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] ) + elif name == "tone" and self.tones_emb is not None: + embedding = self.tones_emb( input ) + elif name == "resp": + if self.audio_emb is not None: + embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token ) + else: + embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token ) + elif name == "timestep" and self.time_emb is not None: + embedding = self.time_emb( input ) + elif name == "len" and self.len_emb is not None: + embedding = self.len_emb( input ) + else: + # should probably raise an exception so things aren't processed silently + continue + + batch.append(embedding) + + x_list.append( _join( batch, self.sep ) ) + + return x_list + + # get an attribute from a given input list + def get_input( + self, + inputs, + name, + at=None, + ): + find_all = at is None + res = [] if at is None else None + + for batch_index, batch_input in enumerate(inputs): + if not find_all and batch_index != at: + continue + + for n, input in batch_input: + if n != name: + continue + if not find_all: + return input + res.append( input ) + + return res + + # creates position ids from a given input list + # if not unified_position_ids, then each input segment will have its own sequence + def inputs_to_position_ids( + self, + inputs: list, + mask: Tensor, + ): + device = mask.device + + # shamelessly grabbed from modeling_llama.py + ids = mask.long().cumsum(-1) - 1 + ids.masked_fill_( mask == 0, 1 ) + + # there's a better way + if not self.unified_position_ids: + x_list = [] + + def get_input_token_length( name, input, task ): + # task token + if isinstance(input, str): + return 1 + + # list of tokens + if not isinstance(input, torch.Tensor): + return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + + # ending input will not have a separator later + return input.shape[0] + + for batch_index, batch_input in enumerate(inputs): + # pre-iterate + task = "tts" + for name, input in batch_input: + if name == "task": + task = input + + batch = torch.cat( [ + torch.tensor([*range(get_input_token_length(name, input, task) + (1 if name != task_outputs.get(task, name) else 0))], device=device, dtype=torch.int32) + for name, input in batch_input if 
name not in non_tokened_names + ] ) + + delta = ids[batch_index].shape[0] - batch.shape[0] + if delta > 0: + batch = torch.cat( [ batch, torch.tensor([1] * delta, device=device, dtype=torch.int32) ] ) + + x_list.append( batch ) + + ids = torch.stack( x_list ) + + return ids.to(device=device, dtype=torch.int32) + + def calc_loss( + self, + inputs: list, + logits, + + quant_levels: list[int] | None = None, + compute_hard_loss = True, + compute_acc = True, + ): + loss = {} + stats = {} + + device = logits[0].device + batch_size = len(logits) + classifier_levels = self.get_input( inputs, "classifier_level" ) + + # handles tasks where the prompt has task tokens injected in the middle + def prompt_input_to_token( input, quant_level ): + if isinstance(input, str): + return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) + + return input + + def _calc_loss( logit, sequence, causal = True ): + # filter tokens that exceed the vocab size + sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence ) + # drop if all tokens are ignored + if all(sequence == self.ignore_index): + return None, None + + # shift if causal + if causal or self.predict_causally: + l = self.causal_size + logit = logit[..., :-l, :] # shift the target so that token n... + sequence = sequence[..., l:] # ...predicts token n + 1 + + nll = None + metrics = None + if compute_hard_loss: + nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index ) + + if compute_acc: + accuracy_metric = MulticlassAccuracy( + logit.shape[-1], + top_k = min(logit.shape[0], 10), + average="micro", + multidim_average="global", + ignore_index = -100 + ).to(logit.device) + metrics = accuracy_metric( logit, sequence ) + return nll, metrics + + for batch_index, batch in enumerate(inputs): + quant_level = quant_levels[batch_index] + target = [] + causal = True + task_type = "tts" + dropout_mask = None + classifier_level = None + output_len = 0 + + for name, input in batch: + if name == "task": + task_type = input + elif name == "dropout_mask": + dropout_mask = input + elif name == "classifier_level": + classifier_level = input + + # autoregressive, causal + if classifier_level.startswith("AR:"): + causal = True + # nonautoregressive, parallel + elif classifier_level.startswith("NAR:"): + causal = False + + it = 0 + for name, input in batch: + token = None + ignored = False + + # non-tokened tasks + if name in non_tokened_names: + continue + # prom can either be a tensor itself or a list of tensors and strings + if name == "prom": + # expand to list if not a list + proms = [ input ] if isinstance(input, torch.Tensor) else input + # iterate over the list to inject their tokens + token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) + + if logits[batch_index].dim() < 3 and token.dim() >= 2: + token = token[..., 0] + elif name == "resp": + token = input + + # mask found, apply it + if dropout_mask is not None: + token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True ) + # not a special input, inject as-is + else: + token = input + + if not isinstance(token, torch.Tensor): + continue + + if token.is_floating_point(): + ignored = True + + # grab range of our logits for later + seq_len = token.shape[0] + start, end = it, it+seq_len + it += seq_len + 1 # +1 to incorporate the separator + + # deduce if a name for a task is an input or output + if name != task_outputs.get(task_type, name): + if self.ignore_inputs_for_loss: + 
ignored = True + else: + output_len = seq_len + + if ignored: + # pruned + if self.config.loss_factors: + continue + # fill with ignored out tensor + token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16) + + # perform loss calculation on the individual piece + if self.config.loss_factors: + loss_factor = self.loss_factor(name) + + if loss_factor == 0.0: + continue + + if logits[batch_index].dim() < 3: + nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal ) + + if name == "resp": + name = f'{name}[{quant_level}]' + elif not self.resp_parallel_training: + # cringe way to deduce "requested" level + level = quant_level + for i in range( self.n_resp_levels ): + if classifier_level.endswith(f':{i}:{i}'): + level = i + break + if name == "resp": + name = f'{name}[{level}]' + sequence = token if token.dim() <= 1 else token[:, level] + nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal ) + else: + nlls = [] + accs = [] + + for level, logit in enumerate( logits[batch_index] ): + sequence = token if token.dim() <= 1 else token[:, level] + nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal ) + + if name == "resp": + if nll is not None: + if f'{name}[{level}].nll' not in loss: + loss[f'{name}[{level}].nll'] = [] + loss[f"{name}[{level}].nll"].append( nll * loss_factor ) + + if metrics is not None: + if f'{name}[{level}].acc' not in stats: + stats[f'{name}[{level}].acc'] = [] + stats[f"{name}[{level}].acc"].append( metrics ) + + nll = None + metrics = None + else: + if nll: + nlls.append( nll ) + if metrics: + accs.append( metrics ) + else: + if nlls: + nll = sum(nlls) / len(nlls) + if accs: + metrics = sum(accs) / len(accs) + + if nll is not None: + if f'{name}.nll' not in loss: + loss[f'{name}.nll'] = [] + loss[f"{name}.nll"].append( nll * loss_factor ) + + if metrics is not None: + if f'{name}.acc' not in stats: + stats[f'{name}.acc'] = [] + stats[f"{name}.acc"].append( metrics ) + # add to list + else: + target.append( token ) + + # perform loss calculation on the entire sequence + if not self.config.loss_factors: + if logits[batch_index].dim() < 3: + sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) + nll, metrics = _calc_loss( logits[batch_index], sequence, causal ) + elif not self.resp_parallel_training: + # cringe way to deduce "requested" level + level = 0 + for i in range( self.n_resp_levels ): + if classifier_level.endswith(f':{i}:{i}'): + level = i + break + + sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ] + sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) ) + nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal ) + else: + nlls = [] + accs = [] + + for level, logit in enumerate( logits[batch_index] ): + sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ] + sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) ) + nll, metrics = _calc_loss( logit, sequence, causal ) + + if nll: + nlls.append( nll ) + if metrics: + accs.append( metrics ) + + if nlls: + nll = sum(nlls) / len(nlls) + if accs: + metrics = sum(accs) / len(accs) + + if nll is not None: + if 'nll' not in loss: + loss['nll'] = [] + loss["nll"].append( nll ) + + if metrics is not None: + if 'acc' not in stats: + stats['acc'] = [] + stats["acc"].append( metrics ) + + # average + loss = { name: sum( loss[name] ) / len( loss[name] ) for name in
loss.keys() } + stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() } + + return LossStats(loss, stats) + + def forward( + self, + inputs: list, + + quant_levels: list[int] | None = None, + state: dict | list | None = None, + + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + # derive quant levels from inputs if not provided + if quant_levels is None: + quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ] + + # inputs don't have quant levels added, pure AR + if len(quant_levels) != len(inputs): + quant_levels = [ 0 for _ in range(len(inputs)) ] + + x_list = self.inputs_to_embeddings( inputs, quant_levels ) + + x, mask = list_to_tensor(x_list) + + training = self.training + teaching = self.teaching + device = x.device + batch_size = len(x_list) + + # pad our input and mask, but retain the original length by doing it after + if self.l_padding and x.shape[1] % self.l_padding != 0: + # pad input + shape = list(x.shape) + shape[1] = self.l_padding - shape[1] % self.l_padding + + padding = torch.zeros(shape, dtype=x.dtype, device=x.device) + x = torch.cat([x, padding], dim=1) + + # pad mask + shape[2] = 1 + padding = torch.zeros(shape[:2], dtype=x.dtype, device=x.device) + mask = torch.cat([mask, padding], dim=1) + + m = mask.unsqueeze(dim=-1) + + # needs to be done here as we still have our raw inputs + position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None + classifier_levels = self.get_input( inputs, name="classifier_level" ) + causal_levels = [ "len", "phn", "text" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ] + + # right now limit to new versions because I need to retrain the model for noncausal masks... + is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ] + + output = self._forward( + inputs=x, + mask=mask, + state=state, + is_causal=is_causal, + position_ids=position_ids, + output_attentions = output_attentions, + ) + + logits = [ logit for logit in output.logits ] + hidden_states = output.hidden_states + + + grouped_logits = {} + + for batch_index in range( batch_size ): + classifier_level = classifier_levels[batch_index] + if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"): + classifier_level = "audio" + + if classifier_level not in ["audio", "phn", "text", "len"]: + continue + + if classifier_level not in grouped_logits: + grouped_logits[classifier_level] = [] + + grouped_logits[classifier_level].append(batch_index) + + for classifier_level, decoders_indices in grouped_logits.items(): + if classifier_level == "audio": + head = self.audio_decoder + elif classifier_level == "phn": + head = self.phn_decoder + elif classifier_level == "text": + head = self.text_decoder + elif classifier_level == "len": + head = self.len_decoder + + decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ]) + decoders_logits = head( decoders_logits ) + for batch_index, logit in zip( decoders_indices, decoders_logits ): + logits[batch_index] = logit + + # Remove padding + logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ] + + if not training: + loss = None + stats = None + + self.loss = None + self.stats = None + + # compute loss if the target is given + else: + loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels ) + + # include any additional losses (for example: MoE router) + 
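forward() above routes each sample's hidden states through one of several output heads (audio / phn / text / len), batching samples that share a head into a single call before scattering the logits back. A small sketch of that routing; apply_grouped_heads, the head names, and the dimensions are made up for illustration and are not part of the patch.

import torch
import torch.nn as nn

def apply_grouped_heads(hidden: torch.Tensor, head_names: list[str], heads: nn.ModuleDict) -> list[torch.Tensor]:
	# hidden: [batch, seq, dim]; head_names[i] names the head for sample i
	outputs = [None] * len(head_names)
	groups = {}
	for index, name in enumerate(head_names):
		groups.setdefault(name, []).append(index)
	for name, indices in groups.items():
		projected = heads[name](hidden[indices])  # one forward pass per head
		for index, logit in zip(indices, projected):
			outputs[index] = logit
	return outputs

d_model = 8
heads = nn.ModuleDict({"audio": nn.Linear(d_model, 1024), "phn": nn.Linear(d_model, 256)})
logits = apply_grouped_heads(torch.randn(3, 5, d_model), ["audio", "phn", "audio"], heads)
print([l.shape for l in logits])  # [5, 1024], [5, 256], [5, 1024]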
if output.loss is not None: + loss["aux_loss"] = output.loss + + self.loss = loss + self.stats = stats + + # rewrap, because we're modifying the logits here + return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states) + + def sample( + self, + logits: list[Tensor], # logit scores + prev_list: list[Tensor] | None = None, # previously sampled tokens (used to deduce sequence lengths) + **sampling_kwargs, + ): + # yikes + temperature = sampling_kwargs.get("temperature", 1.0) + min_temperature = sampling_kwargs.get("min_temperature", -1.0) + top_k = sampling_kwargs.get("top_k", -100) + top_p = sampling_kwargs.get("top_p", 1.0) + min_p = sampling_kwargs.get("min_p", 0.0) + # repetition penalty parameters + repetition_penalty = sampling_kwargs.get("repetition_penalty", 1.0) + repetition_penalty_decay = sampling_kwargs.get("repetition_penalty_decay", 0.0) + # length penalty parameters + length_penalty = sampling_kwargs.get("length_penalty", 0.0) + # beam sampling parameters + beam_width = sampling_kwargs.get("beam_width", 0) + # mirostat sampling parameters + mirostat = sampling_kwargs.get("mirostat", None) + # DRY sampling parameters + dry_multiplier = sampling_kwargs.get("dry_multiplier", 0.0) + dry_base = sampling_kwargs.get("dry_base", 1.75) + dry_allowed_length = sampling_kwargs.get("dry_allowed_length", 2) + # + top_no = sampling_kwargs.get("top_no", 1.0) + # + attentions = sampling_kwargs.get("attentions", None) + + batch_size = len( logits ) + + if min_temperature < 0: + min_temperature = temperature + + scores = None + entropy = None + + if prev_list is not None: + seq_lens = map(len, prev_list) + logits = [ logit[-l:] for logit, l in zip(logits, seq_lens) ] + # (AR chunkwise) return the last chunkwise piece + elif self.causal: + seq_lens = [ logit.shape[0] - self.causal_size for logit in logits ] + logits = [ logit[-self.causal_size:] for logit in logits ] + + # argmax instead + if temperature <= 0.0: + res = [ logit.argmax(dim=-1) for logit in logits ] + else: + res = [ Categorical(logits=logit / temperature).sample() for logit in logits ] + + # calculate token probabilities + scores = [ + [ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ] + for logit, tokens in zip(logits, res) + ] + + return Sampled(res, logits, scores, entropy) \ No newline at end of file diff --git a/vall_e/train.py b/vall_e/train.py index e6d5a6e..07f13c7 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -27,26 +27,27 @@ _logger = logging.getLogger(__name__) mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu") def train_feeder(engine, batch, teacher=None): - engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ]) + engine.tokens_processed += sum([ text.shape[0] for text in batch["phns"] ]) engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ]) with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): - batch_size = len(batch["text"]) + batch_size = len(batch["phns"]) engine.current_batch_size = batch_size output = engine( - text_list=batch["text"], + phns_list=batch["phns"], proms_list=batch["proms"], resps_list=batch["resps"], lang_list=batch["lang"], tone_list=batch["tone"], task_list=batch["task"], - raw_text_list=batch["raw_text"], + text_list=batch["text"], training=True, ) # get soft targets from teacher + """ if teacher is not None: # extract inputs forwarded to model inputs = output.inputs @@ -99,6 +100,7 @@ def train_feeder(engine, batch, teacher=None): for k in engine.module.loss.keys(): engine.module.loss[k]
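sample() above falls back to greedy argmax when temperature <= 0 and otherwise draws from the temperature-scaled categorical, recording the probability the model assigned to each chosen token. A compact sketch of just that path; the helper name sample_tokens and the toy shapes are made up and not part of the patch.

import torch
import torch.nn.functional as F
from torch.distributions import Categorical

def sample_tokens(logits: torch.Tensor, temperature: float = 1.0):
	# logits: [seq, vocab]; temperature <= 0 means greedy decoding
	if temperature <= 0.0:
		tokens = logits.argmax(dim=-1)
	else:
		tokens = Categorical(logits=logits / temperature).sample()
	# per-token probability of the chosen ids, analogous to `scores` above
	probs = F.softmax(logits, dim=-1)
	scores = probs[torch.arange(logits.shape[0]), tokens]
	return tokens, scores

tokens, scores = sample_tokens(torch.randn(4, 10), temperature=0.9)
print(tokens.tolist(), scores.tolist())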
*= (1.0 - A) engine.module.loss[L] = torch.stack(soft_losses).sum() * A * (T ** 2) / batch_size + """ losses = engine.gather_attribute("loss") stat = engine.gather_attribute("stats") @@ -174,7 +176,7 @@ def run_eval(engines, eval_name, dl, args=None): for key in batch.keys(): batch[key] = batch[key][:cfg.evaluation.batch_size] - batch_size = len(batch["text"]) + batch_size = len(batch["phns"]) """ # to-do: eval for text tasks @@ -190,8 +192,8 @@ def run_eval(engines, eval_name, dl, args=None): # random prompts requested if args and args.eval_random_text_prompts and eval_name == "subtrain": - for i, _ in enumerate(batch["text"]): - batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device) + for i, _ in enumerate(batch["phns"]): + batch["phns"][i] = get_random_prompt(tokenized=True).to(device=cfg.device) batch["resps"][i] = None """ @@ -200,7 +202,7 @@ def run_eval(engines, eval_name, dl, args=None): engine = engines[name] base_kwargs = dict( - text_list=batch["text"], + phns_list=batch["phns"], proms_list=batch["proms"], lang_list=batch["lang"], task_list=batch["task"], @@ -242,22 +244,6 @@ def run_eval(engines, eval_name, dl, args=None): process( name, batch, resps_list ) - """ - # evaluate why it's so slow - if has_stt: - max_steps = max( [ text.shape[0] for text in batch["text"] ] ) - - kwargs["text_list"] = None - kwargs["task_list"] = [ "stt" for _ in range(batch_size) ] - kwargs["proms_list"] = [ ["stt"] for _ in range(batch_size) ] - kwargs["resps_list"] = batch["resps"] - - text_list = engine( **kwargs, max_steps=max_steps, sampling_temperature=0.0) - text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ] - - _logger.info(f"Validation Metrics (STT): {text_list}") - """ - stats = {k: sum(v) / len(v) for k, v in stats.items() if v} engines_stats = { f'{name}.{eval_name}': stats,
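With this patch the dataloader keys and the model keywords line up as phns -> phns_list (phonemized transcript) and text -> text_list (raw text tokens). A tiny sketch of that mapping; fake_engine and the tensor sizes are stand-ins for illustration, not part of the patch.

import torch

def fake_engine(phns_list=None, text_list=None, **kwargs):
	# stand-in for engine(...) just to show which batch key feeds which keyword
	return {"phn_tokens": sum(p.shape[0] for p in phns_list), "text_tokens": sum(t.shape[0] for t in text_list)}

batch = {
	"phns": [torch.randint(0, 256, (12,))],   # phonemized transcript -> phns_list
	"text": [torch.randint(0, 8575, (20,))],  # raw text tokens -> text_list
}
print(fake_engine(phns_list=batch["phns"], text_list=batch["text"]))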