oops
This commit is contained in:
parent
0f7f3ae754
commit
d005e24953
|
@ -229,8 +229,8 @@ class Model(LlmArchClass):
|
|||
logits_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( logits ) ]
|
||||
logits_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( logits ) ]
|
||||
|
||||
loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) * self.hyper_config.loss_factor("text")
|
||||
loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) * self.hyper_config.loss_factor("resp")
|
||||
loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) / len(logits_text) * self.hyper_config.loss_factor("text")
|
||||
loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) / len(logits_resp) * self.hyper_config.loss_factor("resp")
|
||||
|
||||
self.loss = dict(
|
||||
text = loss_text,
|
||||
|
|
Loading…
Reference in New Issue
Block a user