From 218d0e29fd693ce802dbd13e5281c77eb3ca0cb4 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 7 Dec 2024 12:39:01 -0600
Subject: [PATCH] ugh (batchmean actually expects batch=seq_len, and not the actual batch)

---
 vall_e/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vall_e/train.py b/vall_e/train.py
index 5c3a221..23ac137 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -88,7 +88,7 @@ def train_feeder(engine, batch, teacher=None):
         student_probs = [ F.log_softmax( logit, dim=-1 ) for logit in student_logits ]
         teacher_probs = [ F.log_softmax( logit, dim=-1 ) for logit in teacher_logits ]
-        soft_losses = [ F.kl_div( student, teacher, reduction='sum', log_target=True ) for student, teacher in zip( student_probs, teacher_probs ) ]
+        soft_losses = [ F.kl_div( student, teacher, reduction='batchmean', log_target=True ) for student, teacher in zip( student_probs, teacher_probs ) ]
     elif L == "mse":
         soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_logits, teacher_logits ) ]