diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 321f654..4d83976 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1038,21 +1038,21 @@ class Base(nn.Module): if loss_factor_text > 0.0 and target_text_list: target_text = torch.cat( target_text_list ).long() inputs_text = torch.cat( logits_text ) - self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ) + self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ) * loss_factor_text self.stats["acc"]["text"] = self.accuracy_metric( inputs_text, target_text ) loss_factor_prom = self.loss_factor("prom") if loss_factor_prom > 0.0 and target_prom_list: target_prom = torch.cat( target_prom_list ).long() inputs_prom = torch.cat( logits_prom ) - self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index ) + self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index ) * loss_factor_prom self.stats["acc"]["prom"] = self.accuracy_metric( inputs_prom, target_prom ) loss_factor_resp = self.loss_factor("resp") if loss_factor_resp > 0.0 and target_resp_list: target_resp = torch.cat( target_resp_list ).long() inputs_resp = torch.cat( logits_resp ) - self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ) + self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ) * loss_factor_resp self.stats["acc"]["resp"] = self.accuracy_metric( inputs_resp, target_resp ) # to-do: compute loss per individual batch to scale per RVQ level