NAR-len RVQ-0 was being trained causally
parent 976ee87f6f
commit ad7cfffc00
@@ -1345,12 +1345,14 @@ class Base(nn.Module):
         if not self.config.loss_factors:
             target_list = []
             task_list = []
+            is_causal = []

             for batch_index, batch in enumerate(inputs):
                 quant_level = quant_levels[batch_index]
                 target = []
                 task_type = "tts"

+                causal = False
                 dropout_mask = None
                 for name, input in batch:
                     if name == "dropout_mask":
@@ -1364,13 +1366,14 @@ class Base(nn.Module):
                         proms = [ input ] if isinstance(input, torch.Tensor) else input
                         target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
                     elif name == "resp":
+                        causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_type in ["len", "stt"])
                         # mask found, apply it
                         if dropout_mask is not None:
                             # if mask use original token, else ignore
+                            causal = False
                             target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
                         elif self.interleave:
                             target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )

                         elif task_type in summed_embeddings_task:
                             target.append( torch.full_like(input[..., 0], self.ignore_index) )
                         else:
@@ -1380,14 +1383,15 @@ class Base(nn.Module):
                     elif name in ["text", "quant_level", "lang", "tone", "len"]:
                         target.append( input )

+                is_causal.append( causal )
                 target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )

             batch_size = len(target_list)
-            # modify only for the AR so it can properly behave like a transformer
+            # modify only causal sequences so it can properly behave like a transformer
             for i in range(batch_size):
                 quant_level = quant_levels[i]
                 task_name = task_list[i]
-                causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"])
+                causal = is_causal[i]

                 if causal:
                     l = self.causal_size
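Why this matters: the old code recomputed `causal` after the target loop from `quant_level` and the task name alone, so it never saw the `dropout_mask`. For a model with the "ar" capability, any RVQ-0 sample, including NAR-len masked-demasking ones, was therefore flagged causal. A minimal sketch of the before/after behavior, using an illustrative `Sample` container rather than the repo's actual input structure:

```python
from dataclasses import dataclass

@dataclass
class Sample:
    quant_level: int
    task: str
    has_dropout_mask: bool  # True for NAR-len masked-demasking samples

def old_causal(sample: Sample, capabilities: list[str]) -> bool:
    # recomputed after target construction; blind to the dropout mask
    return (sample.quant_level == 0 and "ar" in capabilities) \
        or ("nar" not in capabilities) \
        or (sample.task in ["len", "stt"])

def new_causal(sample: Sample, capabilities: list[str]) -> bool:
    # decided per sample while targets are built; a dropout mask forces non-causal
    if sample.has_dropout_mask:
        return False
    return (sample.quant_level == 0 and "ar" in capabilities) \
        or ("nar" not in capabilities) \
        or (sample.task in ["len", "stt"])

masked_rvq0 = Sample(quant_level=0, task="tts", has_dropout_mask=True)
caps = ["ar", "nar"]
print(old_causal(masked_rvq0, caps))  # True  -- the bug: trained causally
print(new_causal(masked_rvq0, caps))  # False -- NAR-len RVQ-0 stays non-causal
```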
@@ -103,7 +103,7 @@ def _non_blocking_input():

 def _make_infinite_epochs(dl):
     while True:
-        if dl_dataset.index() == 0:
+        if dl.dataset.index() == 0:
             _logger.info("New epoch starts.")
         # this number may jump from the dataloader sampling before the actual training step happens
         yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index(), total=len(dl.dataset))
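The trainer change is a straightforward NameError fix: `dl_dataset` was never defined, so the epoch-boundary branch would crash the first time `index()` returned 0. A minimal reproduction of the fixed generator, with stub loader/dataset classes standing in for the real ones and tqdm/_logger dropped for brevity:

```python
# minimal sketch of the fixed generator; the stubs are illustrative, not the repo's types
class StubDataset:
    def __init__(self, items):
        self.items = items
        self._index = 0
    def index(self):
        return self._index
    def __len__(self):
        return len(self.items)
    def __iter__(self):
        for item in self.items:
            self._index = (self._index + 1) % len(self.items)
            yield item

class StubLoader:
    def __init__(self, dataset):
        self.dataset = dataset
    def __iter__(self):
        return iter(self.dataset)

def _make_infinite_epochs(dl):
    while True:
        if dl.dataset.index() == 0:  # was `dl_dataset.index()`: a NameError
            print("New epoch starts.")
        yield from dl

epochs = _make_infinite_epochs(StubLoader(StubDataset([1, 2, 3])))
print([next(epochs) for _ in range(7)])  # cycles: [1, 2, 3, 1, 2, 3, 1]
```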
@@ -260,6 +260,15 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):

     gr.Info("Inferencing...")

+    # icky
+    modality = kwargs.get("modality")
+    if modality:
+        for name, engine in tts.engines.items():
+            if modality == "AR+NAR":
+                engine.hyper_config.capabilities = ["ar", "nar"]
+            elif modality == "NAR-len":
+                engine.hyper_config.capabilities = ["nar", "len"]
+
     sampling_kwargs = dict(
         max_steps=args.max_steps,
         max_levels=args.max_levels,
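This override works because each engine's `hyper_config.capabilities` is what the model consults to pick a decoding path, so swapping the list per request retargets the same checkpoint. A standalone restatement of the `# icky` block as a lookup table; the `MODALITY_CAPABILITIES` dict and `apply_modality` helper are my naming, not the repo's:

```python
from types import SimpleNamespace

# restating the modality override as a lookup table (names are illustrative)
MODALITY_CAPABILITIES = {
    "AR+NAR": ["ar", "nar"],
    "NAR-len": ["nar", "len"],
}

def apply_modality(engines: dict, modality: str) -> None:
    """Point every loaded engine at the capability set for the requested modality."""
    capabilities = MODALITY_CAPABILITIES.get(modality)
    if capabilities is None:
        return  # unknown or unset modality: leave engines as configured
    for name, engine in engines.items():
        # copy so engines never alias one shared list
        engine.hyper_config.capabilities = list(capabilities)

# usage against a stub engine standing in for tts.engines
engine = SimpleNamespace(hyper_config=SimpleNamespace(capabilities=["ar", "nar"]))
apply_modality({"default": engine}, "NAR-len")
print(engine.hyper_config.capabilities)  # ['nar', 'len']
```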
@@ -455,6 +464,7 @@ with ui:
         with gr.Row():
             layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
             layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
+            layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="AR+NAR", choices=["AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
         with gr.Row():
             layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
             layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
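For reference, a self-contained sketch of how a dropdown like this reaches an inference callback in Gradio; the real webui gathers every widget under `layout["inference_tts"]["inputs"]` and routes their values into `do_inference_tts`'s kwargs, but the wiring below is simplified and the names are illustrative:

```python
import gradio as gr

def infer(text: str, modality: str) -> str:
    # stand-in for do_inference_tts: the dropdown's value arrives as a plain string
    return f"would synthesize {text!r} via the {modality} path"

with gr.Blocks() as demo:
    with gr.Row():
        text = gr.Textbox(label="Input Text")
        modality = gr.Dropdown(value="AR+NAR", choices=["AR+NAR", "NAR-len"], label="Modality")
    output = gr.Textbox(label="Output")
    gr.Button("Inference").click(infer, inputs=[text, modality], outputs=output)

demo.launch()  # serves the demo locally
```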