disable accuracy calc because it breaks with actual batched training even though it shouldn't
commit 687c71e028
parent d005e24953
@@ -162,7 +162,7 @@ class Model(LlmArchClass):
 		self.backbone.gradient_checkpointing = gradient_checkpointing
 
-		self.accuracy_metric = MulticlassAccuracy(
+		self.accuracy_metric = None if True else MulticlassAccuracy(
 			vocab_size,
 			top_k=10,
 			average="micro",
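The `None if True else …` rewrite hard-disables the metric while keeping the original constructor call in place for easy re-enabling. For reference, a minimal sketch of the metric on a single sequence, assuming `MulticlassAccuracy` here is `torchmetrics.classification.MulticlassAccuracy` (the positional `vocab_size` and the `top_k`/`average` keywords match that signature, but the import sits outside this hunk, so the library is an assumption):

import torch
from torchmetrics.classification import MulticlassAccuracy

vocab_size = 1024  # hypothetical; the real value comes from the model setup

# Micro-averaged top-10 accuracy over token predictions.
metric = MulticlassAccuracy(vocab_size, top_k=10, average="micro")

logits = torch.randn(12, vocab_size)        # one sequence: [seq_len, vocab]
targets = torch.randint(vocab_size, (12,))  # ground-truth token ids
acc = metric(logits, targets)               # 0-dim tensor in [0, 1]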
@@ -211,9 +211,10 @@ class Model(LlmArchClass):
 				nll = loss,
 			)
 
-			self.stats = dict(
-				acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
-			)
+			if self.accuracy_metric is not None:
+				self.stats = dict(
+					acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
+				)
 
 		else:
 			sep = 3
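The guarded stats block averages one metric call per sequence via `zip( logits, labels )`. The commit message suggests this per-sequence loop breaks once training is actually batched; a batch-safe alternative is to flatten the batch and time dimensions into a single call. A sketch, assuming torchmetrics and a hypothetical `batched_accuracy` helper that is not part of this commit (the `-100` ignore index is also an assumption):

import torch
from torchmetrics.classification import MulticlassAccuracy

def batched_accuracy(metric: MulticlassAccuracy,
                     logits: torch.Tensor,  # [batch, seq_len, vocab]
                     labels: torch.Tensor,  # [batch, seq_len]
                     ignore_index: int = -100) -> float:
    # Collapse batch and time so the metric sees [N, vocab] vs. [N].
    logits = logits.flatten(0, 1)
    labels = labels.flatten(0, 1)
    # Drop padded positions, assuming ignore_index marks them.
    keep = labels != ignore_index
    return metric(logits[keep], labels[keep]).item()

Usage would then be a single `batched_accuracy(self.accuracy_metric, logits, labels)` in place of the list-comprehension average.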
@@ -237,12 +238,13 @@ class Model(LlmArchClass):
 				resp = loss_resp,
 			)
 
-			self.stats = dict(
-				acc = dict(
-					text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
-					resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
+			if self.accuracy_metric is not None:
+				self.stats = dict(
+					acc = dict(
+						text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
+						resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
+					)
 				)
-			)
 
 		return output
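If the metric comes back later, the hard-coded `if True` toggle could be driven by configuration instead; a sketch with a hypothetical `cfg.compute_accuracy` flag (nothing in this diff shows such a flag exists):

# Hypothetical config-driven toggle replacing `None if True else ...`.
self.accuracy_metric = MulticlassAccuracy(
    vocab_size,
    top_k=10,
    average="micro",
) if cfg.compute_accuracy else None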