actually fixed knowledge distillation: errant -inf logits were causing problems and needed to be filtered out (also split the input text language from the output audio language, since it helps)
This commit is contained in:
parent 23d402bf01
commit 42fafbaaca
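For context, the gist of the fix: both distributions are temperature-scaled, any vocabulary entries with exactly zero probability (which become -inf once log'd) are masked out of both sides, and only then is the soft loss computed. A condensed, standalone sketch of what the loss hunk further down ends up doing; the function name and shapes here are illustrative, not the repo's API:

import torch
import torch.nn.functional as F

def soft_distillation_loss(student_logits, teacher_logits, T=1.0):
	# temperature-scaled distributions: student kept in log-space, teacher in probability-space
	student_logp = F.log_softmax(student_logits / T, dim=-1)
	teacher_p = F.softmax(teacher_logits / T, dim=-1)

	# entries that are (or would become) -inf once log'd poison the loss,
	# e.g. classes the classifier can still emit but that can never occur for this task
	mask = (student_logp == -float("inf")) | (teacher_p == 0.0)
	student_logp = torch.masked_select(student_logp, ~mask)
	teacher_p = torch.masked_select(teacher_p, ~mask)

	# the commit settles on MSE between the two (the KL variants are left commented out),
	# scaled by T**2 as in standard KD; the commit then sums these per-sequence losses
	# over the batch and divides by batch_size
	return F.mse_loss(student_logp, teacher_p) * (T ** 2)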
@@ -12,6 +12,7 @@ def main():
 	parser = argparse.ArgumentParser("VALL-E TTS")
 	parser.add_argument("text")
 	parser.add_argument("references", type=path_list, default=None)
+	parser.add_argument("--text-language", type=str, default=None)
 	parser.add_argument("--language", type=str, default="en")
 	parser.add_argument("--task", type=str, default="tts")
 	parser.add_argument("--modality", type=str, default="auto")
@@ -114,6 +115,7 @@ def main():
 	output = tts.inference(
 		text=args.text,
 		references=args.references,
+		text_language=args.text_language,
 		language=args.language,
 		task=args.task,
 		modality=args.modality,
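Splitting the two languages means the input text can be phonemized in one language while the language token fed to the model (the output accent) stays another. A hedged usage sketch; the import path and TTS constructor arguments are assumptions, only the inference() keywords come from this diff:

# sketch: read Japanese text with an English output language/accent token
from vall_e.inference import TTS  # assumed import path

tts = TTS(config="./training/config.yaml")  # placeholder constructor/args
wav, sr = tts.inference(
	text="こんにちは、世界",
	references=["./reference.wav"],  # speaker prompt, placeholder path
	text_language="ja",              # language used to process/phonemize the input text
	language="en",                   # language token fed to the model for the output audio
	task="tts",
	modality="auto",
)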
@@ -281,6 +281,7 @@ class ModelExperimentalSettings:
 	layerskip_e_scale: float = 0.2 # early-exit loss scalar value
 
 	teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
+	teacher_temperature: float = 1.0
 
 # I really need to clean this up
 @dataclass()
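teacher_temperature softens both the student's and the teacher's distributions before they are compared (see the Base(nn.Module) hunks further down), with higher values flattening the distributions. A quick illustration of the effect, independent of this codebase:

import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 1.0, 0.5])
print(F.softmax(logits, dim=-1))        # sharp: ~[0.93, 0.05, 0.03]
print(F.softmax(logits / 2.0, dim=-1))  # T=2 softens it: ~[0.72, 0.16, 0.12]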
@@ -190,6 +190,7 @@ class TTS():
 		self,
 		text,
 		references,
+		text_language=None,
 		language="en",
 		task="tts",
 		modality="auto",
@@ -202,6 +203,8 @@ class TTS():
 		use_lora=None,
 		**sampling_kwargs,
 	):
+		if not text_language:
+			text_language = language
 		lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
 
 		wavs = []
@@ -265,7 +268,7 @@ class TTS():
 			out_path = output_dir / f"{time.time()}.wav"
 
 			prom = self.encode_audio( references, trim_length=input_prompt_length ) if references else None
-			phns = self.encode_text( line, language=language )
+			phns = self.encode_text( line, language=text_language )
 			lang = self.encode_lang( language )
 
 			prom = to_device(prom, device=self.device, dtype=torch.int16)
@@ -443,6 +443,7 @@ class Base(nn.Module):
 		interleave = self.config.experimental.interleave if self.config is not None else False
 		noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
 		teacher_alpha = self.config.experimental.teacher_alpha if self.config is not None else 0.5
+		teacher_temperature = self.config.experimental.teacher_temperature if self.config is not None else 0.5
 
 		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
 		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
@@ -493,6 +494,7 @@ class Base(nn.Module):
 		self.ignore_inputs_for_loss = ignore_inputs_for_loss
 		self.noncausal_masks = noncausal_masks
 		self.teacher_alpha = teacher_alpha
+		self.teacher_temperature = teacher_temperature
 
 		# use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends
 		"""
@@ -1590,7 +1592,7 @@ class Base(nn.Module):
 			self.loss = None
 			self.stats = None
 		else:
-			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels, compute_hard_loss=training, compute_acc=training )
+			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
 
 			# compute it as an aux-loss
 			if self.layerskip:
@@ -1614,30 +1616,64 @@
 			self.training_steps += 1 # batch_size
 
 			# get soft targets from teacher
-			# it might be better to compute these once instead of per-engine, but realistically who is actually training multiple models
+			# required to do it in here because the batch is further processed within the model (because of per-model config)
 			if teacher is not None:
+				# grab the teacher's logits
 				with torch.no_grad():
 					teacher_output = teacher.forward_super(
 						inputs=inputs,
 						quant_levels=quant_levels,
 					)
 
-				soft_loss = [
-					F.kl_div(
-						F.log_softmax( student, dim=-1 ).unsqueeze(0),
-						F.softmax( teacher, dim=-1 ).unsqueeze(0),
-						reduction='batchmean'
-					)
-					for student, teacher in zip( logits, teacher_output.logits )
-				]
-				soft_loss = torch.stack([*soft_loss]).sum() / batch_size
+				# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
+				# we could recreate the target sequence with the ignore indices put in, but that's agony
+				output_lens = [ 0 for _ in range(batch_size) ]
+				for batch_index, batch in enumerate(inputs):
+					task_type = "tts"
+					for name, input in batch:
+						if name == "task":
+							task_type = input
+
+					for name, input in batch:
+						if name == task_outputs.get(task_type, name):
+							output_lens[batch_index] = input.shape[0]
+
+				# KD hyperparameters
+				T = self.teacher_temperature
+				A = self.teacher_alpha
+
+				# create probability distributions (literature says to have the students already log'd but not the teacher)
+				student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( logits, output_lens ) ]
+				teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
+
+				# filter out logits that are / would inf
+				# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
+				for batch_index, output_len in enumerate( output_lens ):
+					mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
+					mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
+
+					mask = mask_a | mask_b
+					student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
+					teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
+
+				#soft_losses = [ F.kl_div( student, teacher, reduction='mean' ) for student, teacher in zip( student_probs, teacher_probs ) ]
+				#soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
+				soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
+				soft_loss = torch.stack([*soft_losses]).sum() * (T ** 2) / batch_size
+
+				"""
+				# flatten to a single sequence of token-probabilities
+				# but this shouldn't actually work because some logits might be (..., 1024) and some might be (..., 1025)
+				student_probs = torch.concat( student_probs, dim = 0 )
+				teacher_probs = torch.concat( teacher_probs, dim = 0 )
+				soft_loss = F.mse_loss( student_probs, teacher_probs ) * (T ** 2) / batch_size
+				"""
 
 				# mix if not nan
 				if not torch.isnan(soft_loss).any():
-					alpha = self.teacher_alpha
-					loss['kl'] = alpha * soft_loss
+					loss['kl'] = soft_loss * A
 
 					for k in loss.keys():
-						loss[k] *= (1.0 - alpha)
+						loss[k] *= (1.0 - A)
 
 			# include any additional losses (for example: MoE router)
 			if output.loss is not None:
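One detail from the hunk above worth spelling out: student[-l:] keeps only the last l timesteps of each sequence's logits, i.e. the span that actually maps to the task's output tokens, so the text/prompt prefix is never distilled. A toy illustration of that slicing (shapes are made up):

import torch

seq_len, n_tokens = 10, 1025          # toy shapes
logits = torch.randn(seq_len, n_tokens)
output_len = 6                        # output-segment length for this batch item
output_logits = logits[-output_len:]  # shape (6, 1025): only this part feeds the soft loss
assert output_logits.shape == (output_len, n_tokens)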
@@ -104,6 +104,7 @@ def _non_blocking_input():
 def _make_infinite_epochs(dl):
 	start = dl.dataset.index()
 	total = dl.dataset.batches()
+	manual_update = False
 
 	while True:
 		if dl.dataset.index() == 0:
@@ -113,6 +114,10 @@ def _make_infinite_epochs(dl):
 			if start:
 				pbar.n = start
 				start = 0
+				manual_update = True
+			# for some reason this is required
+			if manual_update:
+				pbar.n += 1
 			yield from pbar
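The trainer tweak above seeks the progress bar to the resume index once and then, per the in-line comment, has to nudge pbar.n forward on later passes. As a stand-alone illustration of seeking a tqdm bar when resuming partway through an epoch (the manual-update style here is a simplification, and the dataset/numbers are stand-ins, not the repo's loader):

from tqdm import tqdm

items = list(range(100))
start = 40  # pretend training resumed 40 batches into the epoch

with tqdm(desc="Epoch progress", total=len(items)) as pbar:
	if start:
		pbar.n = start   # seek the displayed position to the resume point
		pbar.refresh()
	for batch in items[start:]:
		pbar.update(1)   # ends exactly at total, no off-by-one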
@@ -204,6 +204,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--modality", type=str, default=kwargs["modality"])
 	parser.add_argument("--references", type=str, default=kwargs["reference"])
 	parser.add_argument("--language", type=str, default=kwargs["language"])
+	parser.add_argument("--text-language", type=str, default=kwargs["text-language"])
 	parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
 	parser.add_argument("--context-history", type=int, default=kwargs["context-history"])
 	parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
|
@ -300,6 +301,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
wav, sr = tts.inference(
|
wav, sr = tts.inference(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
language=args.language,
|
language=args.language,
|
||||||
|
text_language=args.text_language,
|
||||||
task=args.task,
|
task=args.task,
|
||||||
modality=args.modality.lower(),
|
modality=args.modality.lower(),
|
||||||
references=args.references.split(";") if args.references is not None else [],
|
references=args.references.split(";") if args.references is not None else [],
|
||||||
|
@@ -445,7 +447,8 @@ with ui:
 			with gr.Row():
 				layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale (AR needs 1, NAR-len needs 3).")
 				layout["inference_tts"]["inputs"]["cfg-rescale"] = gr.Slider(value=0.75, minimum=0.0, maximum=1.0, step=0.05, label="CFG Rescale (Phi)", info="Factor when rescaling for Classifier Free Guidance (0 to disable).")
-				layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
+				layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language (Output)", value="en", info="Target language/accent to output.")
+				layout["inference_tts"]["inputs"]["text-language"] = gr.Dropdown(choices=get_languages(), label="Language (Text)", value="en", info="Language the input text is in.")
 			with gr.Row():
 				layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="Splits the text into pieces.", value="sentences")
 				layout["inference_tts"]["inputs"]["context-history"] = gr.Slider(value=0, minimum=0, maximum=4, step=1, label="(Rolling) Context History", info="How many prior lines to serve as the context/prefix (0 to disable).")