diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index 2e9374d..33dd9bb 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -12,6 +12,7 @@ def main():
 	parser = argparse.ArgumentParser("VALL-E TTS")
 	parser.add_argument("text")
 	parser.add_argument("references", type=path_list, default=None)
+	parser.add_argument("--text-language", type=str, default=None)
 	parser.add_argument("--language", type=str, default="en")
 	parser.add_argument("--task", type=str, default="tts")
 	parser.add_argument("--modality", type=str, default="auto")
@@ -114,6 +115,7 @@ def main():
 	output = tts.inference(
 		text=args.text,
 		references=args.references,
+		text_language=args.text_language,
 		language=args.language,
 		task=args.task,
 		modality=args.modality,
diff --git a/vall_e/config.py b/vall_e/config.py
index 836d29b..f972082 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -281,6 +281,7 @@ class ModelExperimentalSettings:
 	layerskip_e_scale: float = 0.2 # early-exit loss scalar value
 
 	teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
+	teacher_temperature: float = 1.0
 
 # I really need to clean this up
 @dataclass()
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 29face0..7df9534 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -190,6 +190,7 @@ class TTS():
 		self,
 		text,
 		references,
+		text_language=None,
 		language="en",
 		task="tts",
 		modality="auto",
@@ -202,6 +203,8 @@ class TTS():
 		use_lora=None,
 		**sampling_kwargs,
 	):
+		if not text_language:
+			text_language = language
 		lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
 
 		wavs = []
@@ -265,7 +268,7 @@ class TTS():
 			out_path = output_dir / f"{time.time()}.wav"
 
 			prom = self.encode_audio( references, trim_length=input_prompt_length ) if references else None
-			phns = self.encode_text( line, language=language )
+			phns = self.encode_text( line, language=text_language )
 			lang = self.encode_lang( language )
 
 			prom = to_device(prom, device=self.device, dtype=torch.int16)
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index ca67fee..7ae1b9a 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -443,6 +443,7 @@ class Base(nn.Module):
 		interleave = self.config.experimental.interleave if self.config is not None else False
 		noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
 		teacher_alpha = self.config.experimental.teacher_alpha if self.config is not None else 0.5
+		teacher_temperature = self.config.experimental.teacher_temperature if self.config is not None else 0.5
 
 		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
 		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
@@ -493,6 +494,7 @@ class Base(nn.Module):
 		self.ignore_inputs_for_loss = ignore_inputs_for_loss
 		self.noncausal_masks = noncausal_masks
 		self.teacher_alpha = teacher_alpha
+		self.teacher_temperature = teacher_temperature
 
 		# use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends
 		"""
@@ -1590,7 +1592,7 @@ class Base(nn.Module):
 			self.loss = None
 			self.stats = None
 		else:
-			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels, compute_hard_loss=training, compute_acc=training )
+			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
 
 			# compute it as an aux-loss
 			if self.layerskip:
@@ -1614,30 +1616,64 @@
 			self.training_steps += 1 # batch_size
 
 			# get soft targets from teacher
-			# it might be better to compute these once instead of per-engine, but realistically who is actually training multiple models
+			# required to do it in here because the batch is further processed within the model (because of per-model config)
 			if teacher is not None:
+				# grab the teacher's logits
 				with torch.no_grad():
 					teacher_output = teacher.forward_super(
 						inputs=inputs,
 						quant_levels=quant_levels,
 					)
 
-				soft_loss = [
-					F.kl_div(
-						F.log_softmax( student, dim=-1 ).unsqueeze(0),
-						F.softmax( teacher, dim=-1 ).unsqueeze(0),
-						reduction='batchmean'
-					)
-					for student, teacher in zip( logits, teacher_output.logits )
-				]
-				soft_loss = torch.stack([*soft_loss]).sum() / batch_size
+				# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
+				# we could recreate the target sequence with the ignore indices put in, but that's agony
+				output_lens = [ 0 for _ in range(batch_size) ]
+				for batch_index, batch in enumerate(inputs):
+					task_type = "tts"
+					for name, input in batch:
+						if name == "task":
+							task_type = input
+
+					for name, input in batch:
+						if name == task_outputs.get(task_type, name):
+							output_lens[batch_index] = input.shape[0]
+
+				# KD hyperparameters
+				T = self.teacher_temperature
+				A = self.teacher_alpha
+
+				# create probability distributions (literature says to have the students already log'd but not the teacher)
+				student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( logits, output_lens ) ]
+				teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
+
+				# filter out logits that are / would inf
+				# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
+				for batch_index, output_len in enumerate( output_lens ):
+					mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
+					mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
+
+					mask = mask_a | mask_b
+					student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
+					teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
+
+				#soft_losses = [ F.kl_div( student, teacher, reduction='mean' ) for student, teacher in zip( student_probs, teacher_probs ) ]
+				#soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
+				soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
+				soft_loss = torch.stack([*soft_losses]).sum() * (T ** 2) / batch_size
+
+				"""
+				# flatten to a single sequence of token-probabilities
+				# but this shouldn't actually work because some logits might be (..., 1024) and some might be (..., 1025)
+				student_probs = torch.concat( student_probs, dim = 0 )
+				teacher_probs = torch.concat( teacher_probs, dim = 0 )
+				soft_loss = F.mse_loss( student_probs, teacher_probs ) * (T ** 2) / batch_size
+				"""
 
 				# mix if not nan
 				if not torch.isnan(soft_loss).any():
-					alpha = self.teacher_alpha
-					loss['kl'] = alpha * soft_loss
+					loss['kl'] = soft_loss * A
 
 					for k in loss.keys():
-						loss[k] *= (1.0 - alpha)
+						loss[k] *= (1.0 - A)
 
 		# include any additional losses (for example: MoE router)
 		if output.loss is not None:
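Stripped of the per-task length bookkeeping, the distillation term added in models/base.py is a temperature-scaled soft loss between the student's log-probabilities and the teacher's probabilities. Below is a minimal sketch of that computation for a single student/teacher logit pair; the tensor names, shapes, and the standalone function are illustrative only, and the masking mirrors the -inf filtering in the patch:

```python
# Sketch only: temperature-scaled soft loss for one student/teacher logit pair.
# `student_logits` / `teacher_logits` are assumed to be (seq_len, n_tokens) tensors.
import torch
import torch.nn.functional as F

def soft_distillation_loss(student_logits, teacher_logits, T=1.0):
	# student distributions are log'd, the teacher's are not (as in the patch)
	student = F.log_softmax(student_logits / T, dim=-1)
	teacher = F.softmax(teacher_logits / T, dim=-1)

	# drop entries that are (or would become) -inf, i.e. tokens the classifier can never emit
	mask = (student == -float("inf")) | (teacher == 0.0)
	student = torch.masked_select(student, ~mask)
	teacher = torch.masked_select(teacher, ~mask)

	# the patch compares student log-probabilities against teacher probabilities with MSE
	# (the KL variants are left commented out), then rescales by T**2 as is standard for distillation
	return F.mse_loss(student, teacher) * (T ** 2)

if __name__ == "__main__":
	student_logits = torch.randn(8, 1025)
	teacher_logits = torch.randn(8, 1025)
	print(soft_distillation_loss(student_logits, teacher_logits, T=2.0))
```

The final objective then mixes this term with the hard losses via teacher_alpha: loss['kl'] is weighted by A while every other loss component is scaled by (1.0 - A).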
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 2e708d6..5cb2b10 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -104,6 +104,7 @@ def _non_blocking_input():
 def _make_infinite_epochs(dl):
 	start = dl.dataset.index()
 	total = dl.dataset.batches()
+	manual_update = False
 
 	while True:
 		if dl.dataset.index() == 0:
@@ -113,6 +114,10 @@ def _make_infinite_epochs(dl):
 			if start:
 				pbar.n = start
 				start = 0
+				manual_update = True
+			# for some reason this is required
+			if manual_update:
+				pbar.n += 1
 
 			yield from pbar
 
diff --git a/vall_e/webui.py b/vall_e/webui.py
index ca41755..e9402f2 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -204,6 +204,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--modality", type=str, default=kwargs["modality"])
 	parser.add_argument("--references", type=str, default=kwargs["reference"])
 	parser.add_argument("--language", type=str, default=kwargs["language"])
+	parser.add_argument("--text-language", type=str, default=kwargs["text-language"])
 	parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
 	parser.add_argument("--context-history", type=int, default=kwargs["context-history"])
 	parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
@@ -300,6 +301,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		wav, sr = tts.inference(
 			text=args.text,
 			language=args.language,
+			text_language=args.text_language,
 			task=args.task,
 			modality=args.modality.lower(),
 			references=args.references.split(";") if args.references is not None else [],
@@ -445,7 +447,8 @@ with ui:
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale (AR needs 1, NAR-len needs 3).")
 					layout["inference_tts"]["inputs"]["cfg-rescale"] = gr.Slider(value=0.75, minimum=0.0, maximum=1.0, step=0.05, label="CFG Rescale (Phi)", info="Factor when rescaling for Classifier Free Guidance (0 to disable).")
-					layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
+					layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language (Output)", value="en", info="Target language/accent to output.")
+					layout["inference_tts"]["inputs"]["text-language"] = gr.Dropdown(choices=get_languages(), label="Language (Text)", value="en", info="Language the input text is in.")
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="Splits the text into pieces.", value="sentences")
 					layout["inference_tts"]["inputs"]["context-history"] = gr.Slider(value=0, minimum=0, maximum=4, step=1, label="(Rolling) Context History", info="How many prior lines to serve as the context/prefix (0 to disable).")
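On the inference side, the new text_language argument only affects phonemization (it is what encode_text() receives), while language is still what gets encoded as the output language token via encode_lang(); when text_language is left unset it falls back to language. A hypothetical call through the Python API follows; the constructor arguments and the reference path are illustrative, not part of the patch:

```python
# Hypothetical usage: phonemize English text, but emit it with the "fr" language token.
from vall_e.inference import TTS

tts = TTS()  # assumes whatever config/checkpoint the CLI would normally load

wav, sr = tts.inference(
	text="Hello there.",
	references=["./reference.wav"],
	language="fr",       # output language/accent, passed to encode_lang()
	text_language="en",  # language used by encode_text() to phonemize the input
)
```

The same pair is exposed as --language / --text-language on the CLI and as the two language dropdowns in the web UI.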