ACTUALLY do KD-loss because of an oversight with masked_select outputting 1D tensors that get softmax'd in total

This commit is contained in:
mrq 2024-12-07 09:52:51 -06:00
parent 34a66e1052
commit f97e8b0c7f
2 changed files with 29 additions and 17 deletions

View File

@ -476,7 +476,7 @@ class Hyperparameters:
teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
teacher_temperature: float = 1.0
teacher_loss_fn: str = "kl" # kl | mse
teacher_loss_fn: str = "mse" # kl | mse, use either kl_div or mse_loss (most implementations use kl, some literature says you can use mse)
@dataclass()
class Evaluation:

View File

@ -60,14 +60,14 @@ def train_feeder(engine, batch, teacher=None):
# I don't know what to call the last one
if L not in ["kl", "mse"]:
L = "kd"
L = "kl"
# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
# we could recreate the target sequence with the ignore indices put in, but that's agony
if not engine.module.ignore_inputs_for_loss:
student_probs = [ F.log_softmax( student / T, dim=-1 ) for student in output.logits ]
teacher_probs = [ F.softmax( teacher / T, dim=-1 ) for teacher in teacher_output.logits ]
else:
student_logits = [ logit / T for logit in output.logits ]
teacher_logits = [ logit / T for logit in teacher_output.logits ]
if engine.module.ignore_inputs_for_loss or True:
task_outputs = {
"tts": "resp",
"stt": "text",
@ -85,25 +85,37 @@ def train_feeder(engine, batch, teacher=None):
output_lens[batch_index] = input.shape[0]
# create probability distributions (literature says to have the students already log'd but not the teacher)
student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( output.logits, output_lens ) ]
teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
student_logits = [ logit[-l:] for logit, l in zip( student_logits, output_lens ) ]
teacher_logits = [ logit[-l:] for logit, l in zip( teacher_logits, output_lens ) ]
# split even further because tensor shapes may change
# losses can be done per-token anyways so it's fine
student_logits = [ [ l for l in logit ] for logit in student_logits ]
teacher_logits = [ [ l for l in logit ] for logit in teacher_logits ]
# filter out logits that are / would inf
# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
for batch_index in range( batch_size ):
mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
for token_index in range( len( student_logits[batch_index] ) ):
filter = -float("inf") # for some unknown reason -inf is poisoning the logits
mask_a = student_logits[batch_index][token_index] == filter
mask_b = teacher_logits[batch_index][token_index] == filter
# remove them from both distributions to keep things synced
mask = mask_a | mask_b
mask = mask_a | mask_b
student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
student_logits[batch_index][token_index] = torch.masked_select( student_logits[batch_index][token_index], ~mask )
teacher_logits[batch_index][token_index] = torch.masked_select( teacher_logits[batch_index][token_index], ~mask )
# kl-divergence operates on probability distributions
# teacher doesn't need to be in logspace but it makes things easier to do so and just pass log_target=True
if L == "kl":
soft_losses = [ F.kl_div( student, teacher, reduction='sum' ) for student, teacher in zip( student_probs, teacher_probs ) ]
student_probs = [ [ F.log_softmax( l, dim=-1 ) for l in logit ] for logit in student_logits ]
teacher_probs = [ [ F.log_softmax( l, dim=-1 ) for l in logit ] for logit in teacher_logits ]
soft_losses = [ sum([ F.kl_div( s, t, reduction='sum', log_target=True ) for s, t in zip( student, teacher ) ]) / len(student) for student, teacher in zip( student_probs, teacher_probs ) ]
# mse shouldn't operate on probability distributions
elif L == "mse":
soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
else:
soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
soft_losses = [ sum([ F.mse_loss( s, t ) for s, t in zip(student, teacher) ]) / len(student) for student, teacher in zip( student_logits, teacher_logits ) ]
for k in engine.module.loss.keys():
engine.module.loss[k] *= (1.0 - A)