From 687c71e028cb7b1de4b229150990ba218c6da285 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 4 Jun 2024 22:13:44 -0500
Subject: [PATCH] disable accuracy calc because it breaks with actual batched
 training even though it shouldn't

---
 vall_e/models/experimental.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py
index 3e4f755..9f8e37f 100644
--- a/vall_e/models/experimental.py
+++ b/vall_e/models/experimental.py
@@ -162,7 +162,7 @@ class Model(LlmArchClass):
 
 		self.backbone.gradient_checkpointing = gradient_checkpointing
 
-		self.accuracy_metric = MulticlassAccuracy(
+		self.accuracy_metric = None if True else MulticlassAccuracy(
 			vocab_size,
 			top_k=10,
 			average="micro",
@@ -211,9 +211,10 @@ class Model(LlmArchClass):
 				nll = loss,
 			)
 
-			self.stats = dict(
-				acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
-			)
+			if self.accuracy_metric is not None:
+				self.stats = dict(
+					acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
+				)
 		else:
 			sep = 3
 
@@ -237,12 +238,13 @@ class Model(LlmArchClass):
 				resp = loss_resp,
 			)
 
-			self.stats = dict(
-				acc = dict(
-					text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
-					resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
+			if self.accuracy_metric is not None:
+				self.stats = dict(
+					acc = dict(
+						text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
+						resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
+					)
 				)
-			)
 
 		return output
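
For reference, the pattern this patch introduces reduces to the guard sketched below: the accuracy metric is constructed only when enabled, and the stats calculation is skipped entirely when it is None. This is a minimal illustration, not the project's actual training loop; torchmetrics is assumed as the source of MulticlassAccuracy (consistent with the constructor arguments in the diff), and vocab_size, disable_accuracy, and the toy logits/labels are hypothetical stand-ins.

import torch
from torchmetrics.classification import MulticlassAccuracy

vocab_size = 32          # hypothetical stand-in for the real vocabulary size
disable_accuracy = True  # mirrors the patch's hard-coded `None if True else ...`

# Construct the metric only when enabled; None disables the stats path entirely.
accuracy_metric = None if disable_accuracy else MulticlassAccuracy(
	vocab_size,
	top_k=10,
	average="micro",
)

# Per-sequence logits and labels: logits[i] is (seq_len_i, vocab_size),
# labels[i] is (seq_len_i,), as an unbatched loop would yield them.
logits = [ torch.randn(8, vocab_size), torch.randn(5, vocab_size) ]
labels = [ torch.randint(0, vocab_size, (8,)), torch.randint(0, vocab_size, (5,)) ]

stats = {}
if accuracy_metric is not None:
	# Average the per-sequence top-10 micro accuracies, as the guarded code does.
	stats["acc"] = ( sum([ accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ]) / len( logits ) ).item()

With the guard in place, disabling the metric leaves stats empty rather than crashing during batched training; swapping the hard-coded `None if True else` for a real config flag would re-enable the calculation without touching the loss path.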