forgot to re-include multiplication by loss factors

mrq 2024-05-27 20:40:21 -05:00
parent b82f0d5c0c
commit 6c49ad06a3


@@ -1038,21 +1038,21 @@ class Base(nn.Module):
 if loss_factor_text > 0.0 and target_text_list:
 	target_text = torch.cat( target_text_list ).long()
 	inputs_text = torch.cat( logits_text )
-	self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index )
+	self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ) * loss_factor_text
 	self.stats["acc"]["text"] = self.accuracy_metric( inputs_text, target_text )
 loss_factor_prom = self.loss_factor("prom")
 if loss_factor_prom > 0.0 and target_prom_list:
 	target_prom = torch.cat( target_prom_list ).long()
 	inputs_prom = torch.cat( logits_prom )
-	self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index )
+	self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index ) * loss_factor_prom
 	self.stats["acc"]["prom"] = self.accuracy_metric( inputs_prom, target_prom )
 loss_factor_resp = self.loss_factor("resp")
 if loss_factor_resp > 0.0 and target_resp_list:
 	target_resp = torch.cat( target_resp_list ).long()
 	inputs_resp = torch.cat( logits_resp )
-	self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index )
+	self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ) * loss_factor_resp
 	self.stats["acc"]["resp"] = self.accuracy_metric( inputs_resp, target_resp )
 # to-do: compute loss per individual batch to scale per RVQ level
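
For reference, the pattern restored by this commit is to scale each component's cross-entropy by its configured loss factor before it contributes to the total loss. Below is a minimal, self-contained sketch of that weighting; the helper and argument names (weighted_losses, logits_by_name, targets_by_name, loss_factors) are assumptions for illustration and are not part of the repository:

import torch
import torch.nn.functional as F

def weighted_losses(logits_by_name, targets_by_name, loss_factors, ignore_index=-100):
    # Hypothetical helper illustrating per-component loss weighting; not the repo's code.
    losses = {}
    for name, logits in logits_by_name.items():
        factor = loss_factors.get(name, 1.0)
        if factor <= 0.0 or name not in targets_by_name:
            continue  # skip components that are disabled or have no targets
        target = targets_by_name[name].long()
        # scale each cross-entropy term by its factor, mirroring the change above
        losses[name] = F.cross_entropy(logits, target, ignore_index=ignore_index) * factor
    return losses

# usage: the total training loss is the sum of the weighted components, e.g.
# total = sum(weighted_losses(logits, targets, {"text": 0.1, "prom": 1.0, "resp": 1.0}).values())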