From 943fe70c100394ebe7d2b1ef782b7237cbbcc1bd Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 9 Nov 2024 19:04:13 -0600 Subject: [PATCH] I don't know why this fixes an assert thrown but it does --- vall_e/models/base.py | 8 ++++++-- vall_e/models/nar.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 7c86d0b..1737b45 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1341,8 +1341,12 @@ class Base(nn.Module): stats = dict(acc = dict()) device = logits[0].device + batch_size = len(logits) summed_embeddings_task = [ "stt" ] - classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ] + #classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ] + + tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ] + classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ] # handles tasks where the prompt has task tokens injected in the middle def prompt_input_to_token( input, quant_level ): @@ -1428,7 +1432,7 @@ class Base(nn.Module): # precision = self.precision_metric( inputs, target ), ) else: - # nll being natural log likelihood :)))) (I don't know why this completely escaped be originally with thinking it meant something else) + # nll being natural log likelihood :)))) (I don't know why this completely escaped me originally with thinking it meant something else) loss = dict( nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size ) diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index c3c5690..8b014a9 100644 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -196,7 +196,7 @@ class NAR(Base): ... # apply CFG (should probably only apply to NAR quant level 0) - if task not in text_task: + if task not in text_task + ["len"]: drop_text = False drop_audio = False