From 6c49ad06a3863e3e3b57ab6972d2aa5a11aa1c2c Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 27 May 2024 20:40:21 -0500
Subject: [PATCH] forgot to reinclude mult by loss factors

---
 vall_e/models/base.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 321f654..4d83976 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1038,21 +1038,21 @@ class Base(nn.Module):
 			if loss_factor_text > 0.0 and target_text_list:
 				target_text = torch.cat( target_text_list ).long()
 				inputs_text = torch.cat( logits_text )
-				self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index )
+				self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ) * loss_factor_text
 				self.stats["acc"]["text"] = self.accuracy_metric( inputs_text, target_text )
 
 			loss_factor_prom = self.loss_factor("prom")
 			if loss_factor_prom > 0.0 and target_prom_list:
 				target_prom = torch.cat( target_prom_list ).long()
 				inputs_prom = torch.cat( logits_prom )
-				self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index )
+				self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index ) * loss_factor_prom
 				self.stats["acc"]["prom"] = self.accuracy_metric( inputs_prom, target_prom )
 
 			loss_factor_resp = self.loss_factor("resp")
 			if loss_factor_resp > 0.0 and target_resp_list:
 				target_resp = torch.cat( target_resp_list ).long()
 				inputs_resp = torch.cat( logits_resp )
-				self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index )
+				self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ) * loss_factor_resp
 				self.stats["acc"]["resp"] = self.accuracy_metric( inputs_resp, target_resp )
 
 			# to-do: compute loss per individual batch to scale per RVQ level
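
The pattern this patch restores is weighting each per-target cross-entropy term by its configured loss factor before recording it. Below is a minimal standalone sketch of that pattern, not the project's actual `Base` class: the `loss_factors` dict, `IGNORE_INDEX` value, and `weighted_losses` helper are illustrative assumptions, with only the multiply-by-factor step taken from the diff above.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed; the patch uses self.ignore_index

# hypothetical per-target weights; vall_e reads these via self.loss_factor(name)
loss_factors = {"text": 1.0, "prom": 0.5, "resp": 1.0}

def weighted_losses(logits_by_target, targets_by_target):
	"""Compute factor-scaled cross-entropy per target type."""
	losses = {}
	for name, factor in loss_factors.items():
		if factor <= 0.0 or not targets_by_target.get(name):
			continue
		targets = torch.cat(targets_by_target[name]).long()
		inputs = torch.cat(logits_by_target[name])
		# the fix: scale the raw cross-entropy by the per-target factor
		losses[name] = F.cross_entropy(inputs, targets, ignore_index=IGNORE_INDEX) * factor
	return losses

# toy usage: two sequences per target type, vocab size 8
logits = {k: [torch.randn(5, 8), torch.randn(3, 8)] for k in loss_factors}
targets = {k: [torch.randint(0, 8, (5,)), torch.randint(0, 8, (3,))] for k in loss_factors}
print(weighted_losses(logits, targets))
```

Without the `* factor` scaling, a non-zero factor only gated whether a term was computed; with it, the factor also controls how strongly each target type contributes to the combined loss.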