mrq 2024-06-04 22:10:04 -05:00
parent 0f7f3ae754
commit d005e24953


@@ -229,8 +229,8 @@ class Model(LlmArchClass):
 logits_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( logits ) ]
 logits_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( logits ) ]
-loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) * self.hyper_config.loss_factor("text")
-loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) * self.hyper_config.loss_factor("resp")
+loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) / len(logits_text) * self.hyper_config.loss_factor("text")
+loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) / len(logits_resp) * self.hyper_config.loss_factor("resp")
 self.loss = dict(
 	text = loss_text,
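
The change divides each summed loss by its batch count, turning `loss_text` and `loss_resp` from batch sums into per-sequence means so their scale no longer grows with the number of sequences. A minimal sketch of the before/after behavior, using hypothetical logits/labels in place of `logits_text`/`labels_text` (shapes and vocab size are assumptions, not taken from the repo):

```python
import torch
import torch.nn.functional as F

# Hypothetical per-sequence logits/labels; each entry is (seq_len, vocab) / (seq_len,).
logits = [torch.randn(8, 32), torch.randn(12, 32)]
labels = [torch.randint(0, 32, (8,)), torch.randint(0, 32, (12,))]

# Before: summed over the batch, so the loss (and its gradient scale) grows with batch size.
loss_sum = sum(
    F.cross_entropy(logit[:-1, :], label[1:], ignore_index=-100)
    for logit, label in zip(logits, labels)
)

# After: dividing by the batch count gives a mean per-sequence loss,
# comparable across batch sizes before the loss_factor weighting is applied.
loss_mean = loss_sum / len(logits)
```

With the sum, doubling the number of sequences roughly doubles the reported loss; the division keeps the text and resp terms on a consistent scale regardless of batch size.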