NAR-len RVQ-0 was being trained causally.

mrq 2024-11-13 09:43:50 -06:00
parent 976ee87f6f
commit ad7cfffc00
3 changed files with 18 additions and 4 deletions
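
The gist of the fix: whether a sequence is trained causally used to be decided only after all targets were built, from the quantizer level, model capabilities, and task type alone, so the NAR-len pass at RVQ level 0 (identifiable by its dropout mask) satisfied the AR rule and had its loss computed causally. The diff below instead tracks a per-batch is_causal flag while targets are built, forcing it off whenever a dropout mask is present. A minimal sketch of the corrected decision, using the same names as the diff (resolve_causal itself is a hypothetical helper, not code from the repo):

def resolve_causal(capabilities, quant_level, task_type, dropout_mask):
	# original rule: RVQ-0 under an AR-capable model, a pure-AR model,
	# or the len/stt tasks are trained causally
	causal = (quant_level == 0 and "ar" in capabilities) \
		or ("nar" not in capabilities) \
		or (task_type in ["len", "stt"])
	# the fix: a dropout mask marks a NAR-len masked pass, which is never causal
	if dropout_mask is not None:
		causal = False
	return causal

# the buggy case: an AR+NAR model at quant level 0 matched the AR rule
assert resolve_causal(["ar", "nar"], 0, "tts", dropout_mask=None) is True
# with the fix, the presence of a dropout mask overrides it
assert resolve_causal(["ar", "nar"], 0, "tts", dropout_mask=[True]) is False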


@@ -1345,12 +1345,14 @@ class Base(nn.Module):
 		if not self.config.loss_factors:
 			target_list = []
 			task_list = []
+			is_causal = []
 			for batch_index, batch in enumerate(inputs):
 				quant_level = quant_levels[batch_index]
 				target = []
 				task_type = "tts"
+				causal = False
 				dropout_mask = None
 				for name, input in batch:
 					if name == "dropout_mask":
@@ -1364,13 +1366,14 @@ class Base(nn.Module):
 						proms = [ input ] if isinstance(input, torch.Tensor) else input
 						target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
 					elif name == "resp":
+						causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_type in ["len", "stt"])
 						# mask found, apply it
 						if dropout_mask is not None:
 							# if mask use original token, else ignore
+							causal = False
 							target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
 						elif self.interleave:
 							target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
 						elif task_type in summed_embeddings_task:
 							target.append( torch.full_like(input[..., 0], self.ignore_index) )
 						else:
@@ -1380,14 +1383,15 @@ class Base(nn.Module):
 					elif name in ["text", "quant_level", "lang", "tone", "len"]:
 						target.append( input )
+				is_causal.append( causal )
 				target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
 			batch_size = len(target_list)
-			# modify only for the AR so it can properly behave like a transformer
+			# modify only causal sequences so it can properly behave like a transformer
 			for i in range(batch_size):
 				quant_level = quant_levels[i]
 				task_name = task_list[i]
-				causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"])
+				causal = is_causal[i]
 				if causal:
 					l = self.causal_size
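
For sequences flagged causal, the targets are then shifted so position t is scored against token t + causal_size, i.e. the usual next-token objective; non-causal (NAR) sequences are left aligned as-is, which is exactly why the NAR-len pass must not take this branch. A sketch of the shift only, assuming logits shaped [..., seq, vocab] (an illustration of the idea, not the repo's exact slicing):

import torch

def shift_for_causal(logits: torch.Tensor, target: torch.Tensor, causal_size: int = 1):
	logits = logits[..., :-causal_size, :]	# the last causal_size positions have nothing left to predict
	target = target[..., causal_size:]	# the first causal_size tokens are never predicted
	return logits, target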


@@ -103,7 +103,7 @@ def _non_blocking_input():
 def _make_infinite_epochs(dl):
 	while True:
-		if dl_dataset.index() == 0:
+		if dl.dataset.index() == 0:
 			_logger.info("New epoch starts.")
 		# this number may jump from the dataloader sampling before the actual training step happens
 		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index(), total=len(dl.dataset))
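
The second file fixes what looks like a stale name from a refactor: dl_dataset is presumably undefined in this scope (a NameError on first call), while the dataset hangs off the dataloader as dl.dataset, matching the initial=dl.dataset.index() already used two lines below. A tiny self-contained demo of the fixed attribute access, with stub classes standing in for the project's real dataset and loader:

class _StubDataset:
	def index(self):
		return 0	# current position within the epoch, as in the diff

class _StubLoader:
	def __init__(self, dataset):
		self.dataset = dataset

dl = _StubLoader(_StubDataset())
assert dl.dataset.index() == 0	# dl.dataset resolves; a bare dl_dataset would raise NameError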


@@ -260,6 +260,15 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	gr.Info("Inferencing...")
+	# icky
+	modality = kwargs.get("modality")
+	if modality:
+		for name, engine in tts.engines.items():
+			if modality == "AR+NAR":
+				engine.hyper_config.capabilities = ["ar", "nar"]
+			elif modality == "NAR-len":
+				engine.hyper_config.capabilities = ["nar", "len"]
 	sampling_kwargs = dict(
 		max_steps=args.max_steps,
 		max_levels=args.max_levels,
@@ -455,6 +464,7 @@ with ui:
 		with gr.Row():
 			layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
 			layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
+			layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="AR+NAR", choices=["AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
 		with gr.Row():
 			layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
 			layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")