ugh (batchmean actually expects batch=seq_len, and not the actual batch)
parent 61ed662856
commit 218d0e29fd
@@ -88,7 +88,7 @@ def train_feeder(engine, batch, teacher=None):
         student_probs = [ F.log_softmax( logit, dim=-1 ) for logit in student_logits ]
         teacher_probs = [ F.log_softmax( logit, dim=-1 ) for logit in teacher_logits ]

-        soft_losses = [ F.kl_div( student, teacher, reduction='sum', log_target=True ) for student, teacher in zip( student_probs, teacher_probs ) ]
+        soft_losses = [ F.kl_div( student, teacher, reduction='batchmean', log_target=True ) for student, teacher in zip( student_probs, teacher_probs ) ]
     elif L == "mse":
         soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_logits, teacher_logits ) ]
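For context on the change: with reduction='batchmean', F.kl_div divides the summed divergence by input.size(0). If each logit tensor here is a per-sequence (seq_len, vocab) tensor, as the commit message suggests, that first dimension is the sequence length rather than the dataloader batch size. A minimal sketch of the semantics (shapes and names are assumed for illustration, not taken from the repository):

import torch
import torch.nn.functional as F

# Assumed shapes for illustration: one sequence's logits, (seq_len, vocab).
seq_len, vocab = 4, 10
student_logits = torch.randn(seq_len, vocab)
teacher_logits = torch.randn(seq_len, vocab)

# log_target=True means both arguments are log-probabilities.
student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

summed = F.kl_div(student_log_probs, teacher_log_probs,
                  reduction='sum', log_target=True)
batchmean = F.kl_div(student_log_probs, teacher_log_probs,
                     reduction='batchmean', log_target=True)

# 'batchmean' is the sum divided by input.size(0) -- here seq_len,
# not the dataloader batch size, which is what the commit message notes.
assert torch.allclose(batchmean, summed / seq_len)

Note kl_div's argument order: the first argument is the student's log-probabilities and the second is the target (the teacher's), matching the zip order in the diff above.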