From f97e8b0c7f537cef5174ac9acd26e40334e2b368 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 7 Dec 2024 09:52:51 -0600
Subject: [PATCH] ACTUALLY do KD-loss because of an oversight with
 masked_select outputting 1D tensors that get softmax'd in total

---
 vall_e/config.py |  2 +-
 vall_e/train.py  | 44 ++++++++++++++++++++++++++++----------------
 2 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index f520dde..b79d086 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -476,7 +476,7 @@ class Hyperparameters:
 	teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
 	teacher_temperature: float = 1.0
-	teacher_loss_fn: str = "kl" # kl | mse
+	teacher_loss_fn: str = "mse" # kl | mse, use either kl_div or mse_loss (most implementations use kl, some literature says you can use mse)
 
 @dataclass()
 class Evaluation:
diff --git a/vall_e/train.py b/vall_e/train.py
index 03ba1f9..0c2800a 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -60,14 +60,14 @@ def train_feeder(engine, batch, teacher=None):
 		# I don't know what to call the last one
 		if L not in ["kl", "mse"]:
-			L = "kd"
+			L = "kl"
 
 		# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
 		# we could recreate the target sequence with the ignore indices put in, but that's agony
-		if not engine.module.ignore_inputs_for_loss:
-			student_probs = [ F.log_softmax( student / T, dim=-1 ) for student in output.logits ]
-			teacher_probs = [ F.softmax( teacher / T, dim=-1 ) for teacher in teacher_output.logits ]
-		else:
+		student_logits = [ logit / T for logit in output.logits ]
+		teacher_logits = [ logit / T for logit in teacher_output.logits ]
+
+		if engine.module.ignore_inputs_for_loss or True:
 			task_outputs = {
 				"tts": "resp",
 				"stt": "text",
@@ -85,25 +85,37 @@
 					output_lens[batch_index] = input.shape[0]
 
 			# create probability distributions (literature says to have the students already log'd but not the teacher)
-			student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( output.logits, output_lens ) ]
-			teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
+			student_logits = [ logit[-l:] for logit, l in zip( student_logits, output_lens ) ]
+			teacher_logits = [ logit[-l:] for logit, l in zip( teacher_logits, output_lens ) ]
+
+		# split even further because tensor shapes may change
+		# losses can be done per-token anyways so it's fine
+		student_logits = [ [ l for l in logit ] for logit in student_logits ]
+		teacher_logits = [ [ l for l in logit ] for logit in teacher_logits ]
 
 		# filter out logits that are / would inf
 		# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
 		for batch_index in range( batch_size ):
-			mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
-			mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
+			for token_index in range( len( student_logits[batch_index] ) ):
+				filter = -float("inf") # for some unknown reason -inf is poisoning the logits
+				mask_a = student_logits[batch_index][token_index] == filter
+				mask_b = teacher_logits[batch_index][token_index] == filter
+				# remove them from both distributions to keep things synced
+				mask = mask_a | mask_b
 
-			mask = mask_a | mask_b
-
-			student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
-			teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
+				student_logits[batch_index][token_index] = torch.masked_select( student_logits[batch_index][token_index], ~mask )
+				teacher_logits[batch_index][token_index] = torch.masked_select( teacher_logits[batch_index][token_index], ~mask )
 
+		# kl-divergence operates on probability distributions
+		# teacher doesn't need to be in logspace but it makes things easier to do so and just pass log_target=True
 		if L == "kl":
-			soft_losses = [ F.kl_div( student, teacher, reduction='sum' ) for student, teacher in zip( student_probs, teacher_probs ) ]
+			student_probs = [ [ F.log_softmax( l, dim=-1 ) for l in logit ] for logit in student_logits ]
+			teacher_probs = [ [ F.log_softmax( l, dim=-1 ) for l in logit ] for logit in teacher_logits ]
+
+			soft_losses = [ sum([ F.kl_div( s, t, reduction='sum', log_target=True ) for s, t in zip( student, teacher ) ]) / len(student) for student, teacher in zip( student_probs, teacher_probs ) ]
+		# mse shouldn't operate on probability distributions
 		elif L == "mse":
-			soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
-		else:
-			soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
+			soft_losses = [ sum([ F.mse_loss( s, t ) for s, t in zip(student, teacher) ]) / len(student) for student, teacher in zip( student_logits, teacher_logits ) ]
 
 		for k in engine.module.loss.keys():
 			engine.module.loss[k] *= (1.0 - A)
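Note (illustrative only, not part of the patch): the oversight being fixed is that torch.masked_select flattens its output to a 1D tensor, so applying softmax after masking a whole [seq_len, vocab] logit tensor normalizes over every token at once instead of per token. The sketch below roughly mirrors the per-token KL path introduced above; the names kd_loss_per_token, student_logits, teacher_logits, and temperature are hypothetical stand-ins, not identifiers from vall_e/train.py.

import torch
import torch.nn.functional as F

def kd_loss_per_token(student_logits, teacher_logits, temperature=1.0):
    # student_logits / teacher_logits: [seq_len, vocab] tensors for one sequence
    losses = []
    for s, t in zip(student_logits, teacher_logits):
        s = s / temperature
        t = t / temperature
        # drop entries that are -inf in either tensor, using one combined mask so
        # both distributions stay aligned element-for-element; masked_select
        # returns a 1D tensor, so the softmax below only ever sees one token's vocab
        mask = (s == -float("inf")) | (t == -float("inf"))
        s = torch.masked_select(s, ~mask)
        t = torch.masked_select(t, ~mask)
        # F.kl_div expects its first argument in log-space; log_target=True
        # lets the teacher be passed in log-space as well
        losses.append(F.kl_div(F.log_softmax(s, dim=-1), F.log_softmax(t, dim=-1),
                               reduction="sum", log_target=True))
    return sum(losses) / len(losses)

# usage: average per-token KL divergence between a 6-token student and teacher over a 32-entry vocab
student = torch.randn(6, 32)
teacher = torch.randn(6, 32)
print(kd_loss_per_token(student, teacher, temperature=1.0))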