Disable the accuracy calculation because it breaks with actual batched training, even though it shouldn't
This commit is contained in:
parent
d005e24953
commit
687c71e028
|
@ -162,7 +162,7 @@ class Model(LlmArchClass):
|
||||||
|
|
||||||
self.backbone.gradient_checkpointing = gradient_checkpointing
|
self.backbone.gradient_checkpointing = gradient_checkpointing
|
||||||
|
|
||||||
self.accuracy_metric = MulticlassAccuracy(
|
self.accuracy_metric = None if True else MulticlassAccuracy(
|
||||||
vocab_size,
|
vocab_size,
|
||||||
top_k=10,
|
top_k=10,
|
||||||
average="micro",
|
average="micro",
|
||||||
|
@ -211,9 +211,10 @@ class Model(LlmArchClass):
|
||||||
nll = loss,
|
nll = loss,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.stats = dict(
|
if self.accuracy_metric is not None:
|
||||||
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
|
self.stats = dict(
|
||||||
)
|
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
sep = 3
|
sep = 3
|
||||||
|
@ -237,12 +238,13 @@ class Model(LlmArchClass):
|
||||||
resp = loss_resp,
|
resp = loss_resp,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.stats = dict(
|
if self.accuracy_metric is not None:
|
||||||
acc = dict(
|
self.stats = dict(
|
||||||
text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
|
acc = dict(
|
||||||
resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
|
text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
|
||||||
|
resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user