From d005e2495340fa63cda9c1b90d8796a0beb172bb Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 4 Jun 2024 22:10:04 -0500 Subject: [PATCH] oops --- vall_e/models/experimental.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py index 9b2c5fb..3e4f755 100644 --- a/vall_e/models/experimental.py +++ b/vall_e/models/experimental.py @@ -229,8 +229,8 @@ class Model(LlmArchClass): logits_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( logits ) ] logits_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( logits ) ] - loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) * self.hyper_config.loss_factor("text") - loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) * self.hyper_config.loss_factor("resp") + loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) / len(logits_text) * self.hyper_config.loss_factor("text") + loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) / len(logits_resp) * self.hyper_config.loss_factor("resp") self.loss = dict( text = loss_text,