From ad7cfffc000ac9a7544893dae9a4a90acaeb4e01 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 13 Nov 2024 09:43:50 -0600
Subject: [PATCH] NAR-len RVQ-0 was being trained causally.............

---
 vall_e/models/base.py   | 10 +++++++---
 vall_e/utils/trainer.py |  2 +-
 vall_e/webui.py         | 10 ++++++++++
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index fd5f767..0205199 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1345,12 +1345,14 @@ class Base(nn.Module):
 		if not self.config.loss_factors:
 			target_list = []
 			task_list = []
+			is_causal = []
 
 			for batch_index, batch in enumerate(inputs):
 				quant_level = quant_levels[batch_index]
 				target = []
 				task_type = "tts"
+				causal = False
 				dropout_mask = None
 
 				for name, input in batch:
 					if name == "dropout_mask":
@@ -1364,13 +1366,14 @@
 						proms = [ input ] if isinstance(input, torch.Tensor) else input
 						target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
 					elif name == "resp":
+						causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_type in ["len", "stt"])
 						# mask found, apply it
 						if dropout_mask is not None:
 							# if mask use original token, else ignore
+							causal = False
 							target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
 						elif self.interleave:
 							target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
-
 					elif task_type in summed_embeddings_task:
 						target.append( torch.full_like(input[..., 0], self.ignore_index) )
 					else:
@@ -1380,14 +1383,15 @@
 					elif name in ["text", "quant_level", "lang", "tone", "len"]:
 						target.append( input )
 
+				is_causal.append( causal )
 				target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
 
 			batch_size = len(target_list)
-			# modify only for the AR so it can properly behave like a transformer
+			# modify only causal sequences so it can properly behave like a transformer
 			for i in range(batch_size):
 				quant_level = quant_levels[i]
 				task_name = task_list[i]
-				causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"])
+				causal = is_causal[i]
 
 				if causal:
 					l = self.causal_size
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 58d1a6b..7f2862b 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -103,7 +103,7 @@ def _non_blocking_input():
 
 def _make_infinite_epochs(dl):
 	while True:
-		if dl_dataset.index() == 0:
+		if dl.dataset.index() == 0:
 			_logger.info("New epoch starts.")
 		# this number may jump from the dataloader sampling before the actual training step happens
 		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index(), total=len(dl.dataset))
diff --git a/vall_e/webui.py b/vall_e/webui.py
index e922e25..41e8c3a 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -260,6 +260,15 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 
 	gr.Info("Inferencing...")
 
+	# icky
+	modality = kwargs.get("modality")
+	if modality:
+		for name, engine in tts.engines.items():
+			if modality == "AR+NAR":
+				engine.hyper_config.capabilities = ["ar", "nar"]
+			elif modality == "NAR-len":
+				engine.hyper_config.capabilities = ["nar", "len"]
+
 	sampling_kwargs = dict(
 		max_steps=args.max_steps,
 		max_levels=args.max_levels,
@@ -455,6 +464,7 @@ with ui:
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
 					layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
+					layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="AR+NAR", choices=["AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
 					layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
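
A note on the base.py change above: causality used to be re-derived after the loss targets were built, from quant_level and the model capabilities alone, so a NAR-len item at RVQ level 0 (quant_level == 0 with a dropout mask) still took the causal target shift and was trained like an AR sequence. The patch instead records the flag per batch item while the target is constructed, and forces it off whenever a dropout mask applies. The sketch below is a minimal standalone illustration of that flow, not the repo's actual code: IGNORE_INDEX, CAUSAL_SIZE, and both helper functions are hypothetical stand-ins for self.ignore_index, self.causal_size, and the target-building loop in Base.

    # Minimal sketch, assuming (T,)-shaped token tensors and (T, V)-shaped logits.
    from typing import Optional, Tuple

    import torch

    IGNORE_INDEX = -100  # stand-in for self.ignore_index
    CAUSAL_SIZE = 1      # stand-in for self.causal_size

    def build_target(
        tokens: torch.Tensor,
        dropout_mask: Optional[torch.Tensor],
        quant_level: int,
    ) -> Tuple[torch.Tensor, bool]:
        # Decide causality while the target is built (the fix), rather than
        # re-deriving it later from quant_level alone (the bug).
        causal = quant_level == 0  # simplified; the patch also checks capabilities and task type
        if dropout_mask is not None:
            # NAR-len training: where masked, score against the original token,
            # elsewhere ignore -- and never apply the causal shift.
            causal = False
            target = torch.where(dropout_mask, tokens, torch.full_like(tokens, IGNORE_INDEX))
        else:
            target = tokens
        return target, causal

    def shift_if_causal(
        logits: torch.Tensor,
        target: torch.Tensor,
        causal: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Causal (AR) sequences behave like a plain transformer LM:
        # logits at position t are scored against the token at t + CAUSAL_SIZE.
        if causal:
            return logits[:-CAUSAL_SIZE], target[CAUSAL_SIZE:]
        return logits, target

With the flag carried alongside each target, a NAR-len item at RVQ level 0 keeps its targets aligned one-to-one with its logits, while AR items still get the next-token shift.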