disable accuracy calc because it breaks with actual batched training even though it shouldn't

mrq 2024-06-04 22:13:44 -05:00
parent d005e24953
commit 687c71e028


@@ -162,7 +162,7 @@ class Model(LlmArchClass):
 			self.backbone.gradient_checkpointing = gradient_checkpointing
-		self.accuracy_metric = MulticlassAccuracy(
+		self.accuracy_metric = None if True else MulticlassAccuracy(
 			vocab_size,
 			top_k=10,
 			average="micro",
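For context on what the now-disabled metric was doing: the constructor call matches the signature of torchmetrics' MulticlassAccuracy (num_classes positional, then top_k and average), though the import is not visible in this diff, so treat the library choice as an assumption. A minimal sketch of building such a metric and scoring one sequence of logits against its targets, with a made-up vocab size purely for illustration:

import torch
from torchmetrics.classification import MulticlassAccuracy

vocab_size = 256  # hypothetical size, only for this sketch
metric = MulticlassAccuracy(vocab_size, top_k=10, average="micro")

# Per-sequence usage mirroring the comprehension in the diff:
# logits are [seq_len, vocab_size], targets are [seq_len] class indices.
logits = torch.randn(12, vocab_size)
target = torch.randint(0, vocab_size, (12,))
acc = metric(logits, target)  # scalar tensor in [0, 1]
print(acc.item())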
@@ -211,9 +211,10 @@
 				nll = loss,
 			)
-			self.stats = dict(
-				acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
-			)
+			if self.accuracy_metric is not None:
+				self.stats = dict(
+					acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
+				)
 		else:
 			sep = 3
@@ -237,12 +238,13 @@
 				resp = loss_resp,
 			)
-			self.stats = dict(
-				acc = dict(
-					text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
-					resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
-				)
-			)
+			if self.accuracy_metric is not None:
+				self.stats = dict(
+					acc = dict(
+						text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
+						resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
+					)
+				)
 		return output
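The commit message does not pin down why the per-sequence accuracy breaks under real batched training, so the following is only a sketch of the guard pattern the diff introduces, plus one hypothetical hardening step: skipping targets that are entirely padding. The helper name, the -100 padding value, and the masking logic are illustrative assumptions, not part of the repository.

import torch
from torchmetrics.classification import MulticlassAccuracy

def compute_token_accuracy(metric, logits, labels):
    # Mirrors the guarded block from the diff: do nothing when the metric
    # is disabled, otherwise average per-sequence accuracy, skipping any
    # sequence whose labels are entirely padding.
    if metric is None:
        return None
    scores = []
    for logit, target in zip(logits, labels):
        mask = target != -100  # assumed ignore/padding value
        if mask.any():
            scores.append(metric(logit[mask], target[mask]))
    return (sum(scores) / len(scores)).item() if scores else None

vocab_size = 256  # hypothetical
metric = MulticlassAccuracy(vocab_size, top_k=10, average="micro")
logits = [torch.randn(8, vocab_size), torch.randn(5, vocab_size)]
labels = [torch.randint(0, vocab_size, (8,)), torch.full((5,), -100)]
print(compute_token_accuracy(metric, logits, labels))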