ACTUALLY do KD-loss because of an oversight with masked_select outputting 1D tensors that get softmax'd in total
This commit is contained in:
parent
34a66e1052
commit
f97e8b0c7f
|
@ -476,7 +476,7 @@ class Hyperparameters:
|
|||
|
||||
teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
|
||||
teacher_temperature: float = 1.0
|
||||
teacher_loss_fn: str = "kl" # kl | mse
|
||||
teacher_loss_fn: str = "mse" # kl | mse, use either kl_div or mse_loss (most implementations use kl, some literature says you can use mse)
|
||||
|
||||
@dataclass()
|
||||
class Evaluation:
|
||||
|
|
|
@ -60,14 +60,14 @@ def train_feeder(engine, batch, teacher=None):
|
|||
|
||||
# I don't know what to call the last one
|
||||
if L not in ["kl", "mse"]:
|
||||
L = "kd"
|
||||
L = "kl"
|
||||
|
||||
# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
|
||||
# we could recreate the target sequence with the ignore indices put in, but that's agony
|
||||
if not engine.module.ignore_inputs_for_loss:
|
||||
student_probs = [ F.log_softmax( student / T, dim=-1 ) for student in output.logits ]
|
||||
teacher_probs = [ F.softmax( teacher / T, dim=-1 ) for teacher in teacher_output.logits ]
|
||||
else:
|
||||
student_logits = [ logit / T for logit in output.logits ]
|
||||
teacher_logits = [ logit / T for logit in teacher_output.logits ]
|
||||
|
||||
if engine.module.ignore_inputs_for_loss or True:
|
||||
task_outputs = {
|
||||
"tts": "resp",
|
||||
"stt": "text",
|
||||
|
@ -85,25 +85,37 @@ def train_feeder(engine, batch, teacher=None):
|
|||
output_lens[batch_index] = input.shape[0]
|
||||
|
||||
# create probability distributions (literature says to have the students already log'd but not the teacher)
|
||||
student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( output.logits, output_lens ) ]
|
||||
teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
|
||||
student_logits = [ logit[-l:] for logit, l in zip( student_logits, output_lens ) ]
|
||||
teacher_logits = [ logit[-l:] for logit, l in zip( teacher_logits, output_lens ) ]
|
||||
|
||||
# split even further because tensor shapes may change
|
||||
# losses can be done per-token anyways so it's fine
|
||||
student_logits = [ [ l for l in logit ] for logit in student_logits ]
|
||||
teacher_logits = [ [ l for l in logit ] for logit in teacher_logits ]
|
||||
|
||||
# filter out logits that are / would inf
|
||||
# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
|
||||
for batch_index in range( batch_size ):
|
||||
mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
|
||||
mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
|
||||
|
||||
for token_index in range( len( student_logits[batch_index] ) ):
|
||||
filter = -float("inf") # for some unknown reason -inf is poisoning the logits
|
||||
mask_a = student_logits[batch_index][token_index] == filter
|
||||
mask_b = teacher_logits[batch_index][token_index] == filter
|
||||
# remove them from both distributions to keep things synced
|
||||
mask = mask_a | mask_b
|
||||
student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
|
||||
teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
|
||||
|
||||
student_logits[batch_index][token_index] = torch.masked_select( student_logits[batch_index][token_index], ~mask )
|
||||
teacher_logits[batch_index][token_index] = torch.masked_select( teacher_logits[batch_index][token_index], ~mask )
|
||||
|
||||
# kl-divergence operates on probability distributions
|
||||
# teacher doesn't need to be in logspace but it makes things easier to do so and just pass log_target=True
|
||||
if L == "kl":
|
||||
soft_losses = [ F.kl_div( student, teacher, reduction='sum' ) for student, teacher in zip( student_probs, teacher_probs ) ]
|
||||
student_probs = [ [ F.log_softmax( l, dim=-1 ) for l in logit ] for logit in student_logits ]
|
||||
teacher_probs = [ [ F.log_softmax( l, dim=-1 ) for l in logit ] for logit in teacher_logits ]
|
||||
|
||||
soft_losses = [ sum([ F.kl_div( s, t, reduction='sum', log_target=True ) for s, t in zip( student, teacher ) ]) / len(student) for student, teacher in zip( student_probs, teacher_probs ) ]
|
||||
# mse shouldn't operate on probability distributions
|
||||
elif L == "mse":
|
||||
soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
|
||||
else:
|
||||
soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
|
||||
soft_losses = [ sum([ F.mse_loss( s, t ) for s, t in zip(student, teacher) ]) / len(student) for student, teacher in zip( student_logits, teacher_logits ) ]
|
||||
|
||||
for k in engine.module.loss.keys():
|
||||
engine.module.loss[k] *= (1.0 - A)
|
||||
|
|
Loading…
Reference in New Issue
Block a user