actually fixed knowledge distillation: errant -inf logits were causing problems and needed to be filtered out (and split the text language from the output audio language, because it helps)

mrq 2024-12-06 21:55:20 -06:00
parent 23d402bf01
commit 42fafbaaca
6 changed files with 66 additions and 16 deletions

View File

@ -12,6 +12,7 @@ def main():
parser = argparse.ArgumentParser("VALL-E TTS")
parser.add_argument("text")
parser.add_argument("references", type=path_list, default=None)
parser.add_argument("--text-language", type=str, default=None)
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--modality", type=str, default="auto")
@ -114,6 +115,7 @@ def main():
output = tts.inference(
text=args.text,
references=args.references,
text_language=args.text_language,
language=args.language,
task=args.task,
modality=args.modality,
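
For illustration, a self-contained sketch (not part of the commit) of how the new flag is meant to behave: when `--text-language` is omitted it falls back to `--language`, mirroring the fallback added in TTS.inference below.

import argparse

# toy parser with just the two language flags from the hunk above
parser = argparse.ArgumentParser("VALL-E TTS (sketch)")
parser.add_argument("--text-language", type=str, default=None)  # language the input text is written in
parser.add_argument("--language", type=str, default="en")       # output audio language / accent
args = parser.parse_args(["--language", "en"])
# no --text-language given, so fall back to the output language
text_language = args.text_language or args.language
print(text_language)  # "en"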

View File

@ -281,6 +281,7 @@ class ModelExperimentalSettings:
layerskip_e_scale: float = 0.2 # early-exit loss scalar value
teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
teacher_temperature: float = 1.0
# I really need to clean this up
@dataclass()
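
For reference, a minimal sketch of the standard distillation objective these two knobs control (Hinton-style soft targets; the helper name and shapes here are illustrative, not the repo's):

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, hard_loss, alpha=0.5, temperature=1.0):
    T = temperature
    # soft loss: KL between temperature-scaled student and teacher distributions,
    # scaled by T^2 to keep gradient magnitudes comparable across temperatures
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T ** 2)
    # alpha mixes the teacher (soft) loss against the ground-truth (hard) loss
    return alpha * soft_loss + (1.0 - alpha) * hard_loss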

View File

@ -190,6 +190,7 @@ class TTS():
self,
text,
references,
text_language=None,
language="en",
task="tts",
modality="auto",
@ -202,6 +203,8 @@ class TTS():
use_lora=None,
**sampling_kwargs,
):
if not text_language:
text_language = language
lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
wavs = []
@ -265,7 +268,7 @@ class TTS():
out_path = output_dir / f"{time.time()}.wav"
prom = self.encode_audio( references, trim_length=input_prompt_length ) if references else None
phns = self.encode_text( line, language=language )
phns = self.encode_text( line, language=text_language )
lang = self.encode_lang( language )
prom = to_device(prom, device=self.device, dtype=torch.int16)
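
Usage-wise, the split lets the phonemizer language differ from the output language token, e.g. (a hypothetical call; `tts` is assumed to be an already-constructed TTS instance, and the argument names match the signature above):

# phonemize the text as Japanese, but keep the English output/accent token
wav, sr = tts.inference(
    text="こんにちは",
    references=["./reference.wav"],
    text_language="ja",  # feeds encode_text() / phonemization
    language="en",       # feeds encode_lang() / the output language token
)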

View File

@ -443,6 +443,7 @@ class Base(nn.Module):
interleave = self.config.experimental.interleave if self.config is not None else False
noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
teacher_alpha = self.config.experimental.teacher_alpha if self.config is not None else 0.5
teacher_temperature = self.config.experimental.teacher_temperature if self.config is not None else 0.5
masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
@ -493,6 +494,7 @@ class Base(nn.Module):
self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.noncausal_masks = noncausal_masks
self.teacher_alpha = teacher_alpha
self.teacher_temperature = teacher_temperature
# use internal attention mechanism for now because I don't have a better way to handle mixed causal/noncausal masks for other attention backends
"""
@ -1590,7 +1592,7 @@ class Base(nn.Module):
self.loss = None
self.stats = None
else:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels, compute_hard_loss=training, compute_acc=training )
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
# compute it as an aux-loss
if self.layerskip:
@ -1614,30 +1616,64 @@ class Base(nn.Module):
self.training_steps += 1 # batch_size
# get soft targets from teacher
# it might be better to compute these once instead of per-engine, but realistically who is actually training multiple models
# required to do it in here because the batch is further processed within the model (because of per-model config)
if teacher is not None:
# grab the teacher's logits
with torch.no_grad():
teacher_output = teacher.forward_super(
inputs=inputs,
quant_levels=quant_levels,
)
soft_loss = [
F.kl_div(
F.log_softmax( student, dim=-1 ).unsqueeze(0),
F.softmax( teacher, dim=-1 ).unsqueeze(0),
reduction='batchmean'
)
for student, teacher in zip( logits, teacher_output.logits )
]
soft_loss = torch.stack([*soft_loss]).sum() / batch_size
# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
# we could recreate the target sequence with the ignore indices put in, but that's agony
output_lens = [ 0 for _ in range(batch_size) ]
for batch_index, batch in enumerate(inputs):
task_type = "tts"
for name, input in batch:
if name == "task":
task_type = input
for name, input in batch:
if name == task_outputs.get(task_type, name):
output_lens[batch_index] = input.shape[0]
# KD hyperparameters
T = self.teacher_temperature
A = self.teacher_alpha
# create probability distributions (the literature says the student distribution should already be log'd, but not the teacher's)
student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( logits, output_lens ) ]
teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
# filter out logits that are, or would become, -inf
# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
for batch_index, output_len in enumerate( output_lens ):
mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
mask = mask_a | mask_b
student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
#soft_losses = [ F.kl_div( student, teacher, reduction='mean' ) for student, teacher in zip( student_probs, teacher_probs ) ]
#soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
soft_loss = torch.stack([*soft_losses]).sum() * (T ** 2) / batch_size
"""
# flatten to a single sequence of token-probabilities
# but this shouldn't actually work because some logits might be (..., 1024) and some might be (..., 1025)
student_probs = torch.concat( student_probs, dim = 0 )
teacher_probs = torch.concat( teacher_probs, dim = 0 )
soft_loss = F.mse_loss( student_probs, teacher_probs ) * (T ** 2) / batch_size
"""
# mix if not nan
if not torch.isnan(soft_loss).any():
alpha = self.teacher_alpha
loss['kl'] = alpha * soft_loss
loss['kl'] = soft_loss * A
for k in loss.keys():
loss[k] *= (1.0 - alpha)
loss[k] *= (1.0 - A)
# include any additional losses (for example: MoE router)
if output.loss is not None:
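
To see the fix in isolation: any class with zero probability turns into -inf once log'd, and a single -inf (or the NaN it produces downstream) poisons the reduced loss. Below is a standalone sketch of the masking approach above (toy shapes, temperature of 1; it mirrors the masked_select + MSE route the commit takes, not a definitive implementation):

import torch
import torch.nn.functional as F

T = 1.0  # teacher temperature

# toy logits; the last class is one the student can never emit
# (analogous to the stop token the NAR RVQ-0 classifier never predicts)
student_logits = torch.tensor([[ 2.0, 0.5, -float("inf") ]])
teacher_logits = torch.tensor([[ 1.5, 1.0, -float("inf") ]])

student = F.log_softmax( student_logits / T, dim=-1 )  # contains -inf
teacher = F.softmax( teacher_logits / T, dim=-1 )      # contains an exact 0.0

# drop entries that are, or would become, -inf once log'd
mask = ( student == -float("inf") ) | ( teacher == 0.0 )
student = torch.masked_select( student, ~mask )
teacher = torch.masked_select( teacher, ~mask )

soft_loss = F.mse_loss( student, teacher ) * (T ** 2)
print( soft_loss.isfinite() )  # tensor(True)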

View File

@ -104,6 +104,7 @@ def _non_blocking_input():
def _make_infinite_epochs(dl):
start = dl.dataset.index()
total = dl.dataset.batches()
manual_update = False
while True:
if dl.dataset.index() == 0:
@ -113,6 +114,10 @@ def _make_infinite_epochs(dl):
if start:
pbar.n = start
start = 0
manual_update = True
# for some reason this is required
if manual_update:
pbar.n += 1
yield from pbar
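
For comparison, a minimal standalone sketch of resuming a tqdm bar part-way through using its `initial` argument instead of assigning `pbar.n` and nudging it manually (whether that plays nicely with the dataset indexing here is an assumption):

from tqdm import tqdm

def make_epoch(items, start=0):
    # start the bar at `start` so a resumed run shows the right position
    with tqdm(items[start:], total=len(items), initial=start) as pbar:
        yield from pbar

for _ in make_epoch(list(range(100)), start=40):
    pass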

View File

@ -204,6 +204,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--modality", type=str, default=kwargs["modality"])
parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--language", type=str, default=kwargs["language"])
parser.add_argument("--text-language", type=str, default=kwargs["text-language"])
parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
parser.add_argument("--context-history", type=int, default=kwargs["context-history"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
@ -300,6 +301,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
wav, sr = tts.inference(
text=args.text,
language=args.language,
text_language=args.text_language,
task=args.task,
modality=args.modality.lower(),
references=args.references.split(";") if args.references is not None else [],
@ -445,7 +447,8 @@ with ui:
with gr.Row():
layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale (AR needs 1, NAR-len needs 3).")
layout["inference_tts"]["inputs"]["cfg-rescale"] = gr.Slider(value=0.75, minimum=0.0, maximum=1.0, step=0.05, label="CFG Rescale (Phi)", info="Factor when rescaling for Classifier Free Guidance (0 to disable).")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language (Output)", value="en", info="Target language/accent to output.")
layout["inference_tts"]["inputs"]["text-language"] = gr.Dropdown(choices=get_languages(), label="Language (Text)", value="en", info="Language the input text is in.")
with gr.Row():
layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="Splits the text into pieces.", value="sentences")
layout["inference_tts"]["inputs"]["context-history"] = gr.Slider(value=0, minimum=0, maximum=4, step=1, label="(Rolling) Context History", info="How many prior lines to serve as the context/prefix (0 to disable).")