forgot to reinclude mult by loss factors
This commit is contained in:
parent
b82f0d5c0c
commit
6c49ad06a3
|
@ -1038,21 +1038,21 @@ class Base(nn.Module):
|
||||||
if loss_factor_text > 0.0 and target_text_list:
|
if loss_factor_text > 0.0 and target_text_list:
|
||||||
target_text = torch.cat( target_text_list ).long()
|
target_text = torch.cat( target_text_list ).long()
|
||||||
inputs_text = torch.cat( logits_text )
|
inputs_text = torch.cat( logits_text )
|
||||||
self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index )
|
self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ) * loss_factor_text
|
||||||
self.stats["acc"]["text"] = self.accuracy_metric( inputs_text, target_text )
|
self.stats["acc"]["text"] = self.accuracy_metric( inputs_text, target_text )
|
||||||
|
|
||||||
loss_factor_prom = self.loss_factor("prom")
|
loss_factor_prom = self.loss_factor("prom")
|
||||||
if loss_factor_prom > 0.0 and target_prom_list:
|
if loss_factor_prom > 0.0 and target_prom_list:
|
||||||
target_prom = torch.cat( target_prom_list ).long()
|
target_prom = torch.cat( target_prom_list ).long()
|
||||||
inputs_prom = torch.cat( logits_prom )
|
inputs_prom = torch.cat( logits_prom )
|
||||||
self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index )
|
self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index ) * loss_factor_prom
|
||||||
self.stats["acc"]["prom"] = self.accuracy_metric( inputs_prom, target_prom )
|
self.stats["acc"]["prom"] = self.accuracy_metric( inputs_prom, target_prom )
|
||||||
|
|
||||||
loss_factor_resp = self.loss_factor("resp")
|
loss_factor_resp = self.loss_factor("resp")
|
||||||
if loss_factor_resp > 0.0 and target_resp_list:
|
if loss_factor_resp > 0.0 and target_resp_list:
|
||||||
target_resp = torch.cat( target_resp_list ).long()
|
target_resp = torch.cat( target_resp_list ).long()
|
||||||
inputs_resp = torch.cat( logits_resp )
|
inputs_resp = torch.cat( logits_resp )
|
||||||
self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index )
|
self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ) * loss_factor_resp
|
||||||
self.stats["acc"]["resp"] = self.accuracy_metric( inputs_resp, target_resp )
|
self.stats["acc"]["resp"] = self.accuracy_metric( inputs_resp, target_resp )
|
||||||
|
|
||||||
# to-do: compute loss per individual batch to scale per RVQ level
|
# to-do: compute loss per individual batch to scale per RVQ level
|
||||||
|
|
Loading…
Reference in New Issue
Block a user