Use reduction='batchmean' for KL-divergence distillation loss (note: 'batchmean' divides by the first dimension, which here is seq_len rather than the actual batch size)

This commit is contained in:
mrq 2024-12-07 12:39:01 -06:00
parent 61ed662856
commit 218d0e29fd

View File

@ -88,7 +88,7 @@ def train_feeder(engine, batch, teacher=None):
student_probs = [ F.log_softmax( logit, dim=-1 ) for logit in student_logits ]
teacher_probs = [ F.log_softmax( logit, dim=-1 ) for logit in teacher_logits ]
soft_losses = [ F.kl_div( student, teacher, reduction='sum', log_target=True ) for student, teacher in zip( student_probs, teacher_probs ) ]
soft_losses = [ F.kl_div( student, teacher, reduction='batchmean', log_target=True ) for student, teacher in zip( student_probs, teacher_probs ) ]
elif L == "mse":
soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_logits, teacher_logits ) ]