forgot to reinclude mult by loss factors
This commit is contained in:
parent
b82f0d5c0c
commit
6c49ad06a3
|
@ -1038,21 +1038,21 @@ class Base(nn.Module):
|
|||
if loss_factor_text > 0.0 and target_text_list:
|
||||
target_text = torch.cat( target_text_list ).long()
|
||||
inputs_text = torch.cat( logits_text )
|
||||
self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index )
|
||||
self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ) * loss_factor_text
|
||||
self.stats["acc"]["text"] = self.accuracy_metric( inputs_text, target_text )
|
||||
|
||||
loss_factor_prom = self.loss_factor("prom")
|
||||
if loss_factor_prom > 0.0 and target_prom_list:
|
||||
target_prom = torch.cat( target_prom_list ).long()
|
||||
inputs_prom = torch.cat( logits_prom )
|
||||
self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index )
|
||||
self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index ) * loss_factor_prom
|
||||
self.stats["acc"]["prom"] = self.accuracy_metric( inputs_prom, target_prom )
|
||||
|
||||
loss_factor_resp = self.loss_factor("resp")
|
||||
if loss_factor_resp > 0.0 and target_resp_list:
|
||||
target_resp = torch.cat( target_resp_list ).long()
|
||||
inputs_resp = torch.cat( logits_resp )
|
||||
self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index )
|
||||
self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ) * loss_factor_resp
|
||||
self.stats["acc"]["resp"] = self.accuracy_metric( inputs_resp, target_resp )
|
||||
|
||||
# to-do: compute loss per individual batch to scale per RVQ level
|
||||
|
|
Loading…
Reference in New Issue
Block a user