disable accuracy calc because it breaks with actual batched training even though it shouldn't

commit 687c71e028
parent d005e24953
Author: mrq
Date:   2024-06-04 22:13:44 -05:00

@@ -162,7 +162,7 @@ class Model(LlmArchClass):
 		self.backbone.gradient_checkpointing = gradient_checkpointing
-		self.accuracy_metric = MulticlassAccuracy(
+		self.accuracy_metric = None if True else MulticlassAccuracy(
 			vocab_size,
 			top_k=10,
 			average="micro",
@@ -211,6 +211,7 @@ class Model(LlmArchClass):
 				nll = loss,
 			)
-			self.stats = dict(
-				acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
-			)
+			if self.accuracy_metric is not None:
+				self.stats = dict(
+					acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
+				)
@@ -237,6 +238,7 @@ class Model(LlmArchClass):
 				resp = loss_resp,
 			)
-			self.stats = dict(
-				acc = dict(
-					text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
+			if self.accuracy_metric is not None:
+				self.stats = dict(
+					acc = dict(
+						text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
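
Both of the later hunks wrap the stats computation in the same is-not-None guard, so re-enabling the metric only means flipping the hard-coded True back off. A hedged sketch of that kill-switch idiom, with a hypothetical ENABLE_ACCURACY flag standing in for the literal True in the commit:

import torch
from torchmetrics.classification import MulticlassAccuracy

ENABLE_ACCURACY = False  # hypothetical flag; the commit hard-codes the disable

accuracy_metric = None if not ENABLE_ACCURACY else MulticlassAccuracy(
	256,  # illustrative vocab size
	top_k=10,
	average="micro",
	ignore_index=-100,
)

stats = {}
if accuracy_metric is not None:
	# ragged "batch": sequences of different lengths, averaged per sequence
	logits = [torch.randn(5, 256), torch.randn(9, 256)]
	labels = [torch.randint(0, 256, (5,)), torch.randint(0, 256, (9,))]
	stats["acc"] = (sum(
		accuracy_metric(logit, target)
		for logit, target in zip(logits, labels)
	) / len(logits)).item()
print(stats)
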