disable accuracy calc because it breaks with actual batched training even though it shouldn't

mrq 2024-06-04 22:13:44 -05:00
parent d005e24953
commit 687c71e028


@@ -162,7 +162,7 @@ class Model(LlmArchClass):
 		self.backbone.gradient_checkpointing = gradient_checkpointing
-		self.accuracy_metric = MulticlassAccuracy(
+		self.accuracy_metric = None if True else MulticlassAccuracy(
 			vocab_size,
 			top_k=10,
 			average="micro",
@@ -211,9 +211,10 @@ class Model(LlmArchClass):
 				nll = loss,
 			)
-			self.stats = dict(
-				acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
-			)
+			if self.accuracy_metric is not None:
+				self.stats = dict(
+					acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
+				)
 		else:
 			sep = 3
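
The guard only skips the per-sample accuracy loop when the metric is disabled; the commit message does not pin down why the loop breaks under batched training, so the sketch below just reproduces the guarded control flow with made-up shapes and sizes:

import torch
from torchmetrics.classification import MulticlassAccuracy

vocab_size = 32
accuracy_metric = None  # disabled, as in this commit
# accuracy_metric = MulticlassAccuracy(vocab_size, top_k=10, average="micro")

# One (logits, labels) pair per batch sample; sequence lengths may differ.
logits = [torch.randn(17, vocab_size), torch.randn(23, vocab_size)]
labels = [torch.randint(0, vocab_size, (17,)), torch.randint(0, vocab_size, (23,))]

stats = {}
if accuracy_metric is not None:
	# Mirror the diff: average the per-sample metric over the batch.
	stats = dict(
		acc = (sum([ accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ]) / len( logits )).item()
	)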
@@ -237,12 +238,13 @@
 				resp = loss_resp,
 			)
-			self.stats = dict(
-				acc = dict(
-					text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
-					resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
-				)
-			)
+			if self.accuracy_metric is not None:
+				self.stats = dict(
+					acc = dict(
+						text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
+						resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
+					)
+				)
 		return output