I don't know why this fixes an assert being thrown, but it does

mrq 2024-11-09 19:04:13 -06:00
parent f50d92ba6c
commit 943fe70c10
2 changed files with 7 additions and 3 deletions


@@ -1341,8 +1341,12 @@ class Base(nn.Module):
 stats = dict(acc = dict())
 device = logits[0].device
+batch_size = len(logits)
 summed_embeddings_task = [ "stt" ]
-classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
+#classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
+tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
+classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
 # handles tasks where the prompt has task tokens injected in the middle
 def prompt_input_to_token( input, quant_level ):
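The functional change in this hunk is swapping the inline peek at inputs[i][0][-1] for an explicit per-item task lookup via self.get_input(inputs, "task", at=i). Below is a minimal standalone sketch of the resulting selection logic; the concrete contents of special_tasks and the example values are assumptions for illustration, not taken from the repo.

def pick_classifier_levels(tasks, quant_levels, special_tasks=("len", "stt")):
    # special tasks map to -1 (no per-level classifier head), everything else
    # keeps its quantizer level, mirroring the rewritten list comprehension
    return [-1 if task in special_tasks else level
            for task, level in zip(tasks, quant_levels)]

# example: the "stt" item falls back to -1, the audio items keep their levels
print(pick_classifier_levels(["tts", "stt", "tts"], [0, 0, 3]))  # [0, -1, 3]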
@@ -1428,7 +1432,7 @@ class Base(nn.Module):
 # precision = self.precision_metric( inputs, target ),
 )
 else:
-# nll being natural log likelihood :)))) (I don't know why this completely escaped be originally with thinking it meant something else)
+# nll being natural log likelihood :)))) (I don't know why this completely escaped me originally with thinking it meant something else)
 loss = dict(
 nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
 )
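For reference, the nll entry in this hunk is an ordinary token-level cross entropy computed per sequence and averaged over the batch. A self-contained illustration with made-up shapes follows; the real target_list/logits come from the model, and the ignore_index value here is an assumption standing in for self.ignore_index.

import torch
import torch.nn.functional as F

ignore_index = -100                                  # assumed; masks padded targets
logits = [torch.randn(7, 32), torch.randn(5, 32)]    # per-item [seq_len, n_classes]
target_list = [torch.randint(0, 32, (7,)), torch.randint(0, 32, (5,))]
batch_size = len(logits)

# per-sequence cross entropy, averaged over the batch, as in the diff above
nll = sum(F.cross_entropy(l, t, ignore_index=ignore_index)
          for t, l in zip(target_list, logits)) / batch_size
print(nll)   # scalar tensor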


@@ -196,7 +196,7 @@ class NAR(Base):
 ...
 # apply CFG (should probably only apply to NAR quant level 0)
-if task not in text_task:
+if task not in text_task + ["len"]:
 drop_text = False
 drop_audio = False
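The second file only widens the guard so that the "len" task is treated like the text tasks and never has its prompts considered for classifier-free-guidance dropping. A rough sketch of the shape of that logic; the dropout probability and the flag handling around the guard are assumptions for illustration only.

import random

def maybe_drop_for_cfg(task, text_task, p_drop=0.3):
    drop_text = drop_audio = False
    # text tasks and "len" are never candidates for CFG prompt dropping
    if task not in text_task + ["len"]:
        drop_text = random.random() < p_drop
        drop_audio = random.random() < p_drop
    return drop_text, drop_audio

print(maybe_drop_for_cfg("len", text_task=["stt"]))   # always (False, False)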