I don't know why this fixes an assert thrown but it does
This commit is contained in:
parent
f50d92ba6c
commit
943fe70c10
|
@ -1341,8 +1341,12 @@ class Base(nn.Module):
|
|||
stats = dict(acc = dict())
|
||||
|
||||
device = logits[0].device
|
||||
batch_size = len(logits)
|
||||
summed_embeddings_task = [ "stt" ]
|
||||
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||
#classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||
|
||||
tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
|
||||
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||
|
||||
# handles tasks where the prompt has task tokens injected in the middle
|
||||
def prompt_input_to_token( input, quant_level ):
|
||||
|
@ -1428,7 +1432,7 @@ class Base(nn.Module):
|
|||
# precision = self.precision_metric( inputs, target ),
|
||||
)
|
||||
else:
|
||||
# nll being natural log likelihood :)))) (I don't know why this completely escaped be originally with thinking it meant something else)
|
||||
# nll being natural log likelihood :)))) (I don't know why this completely escaped me originally with thinking it meant something else)
|
||||
loss = dict(
|
||||
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
|
||||
)
|
||||
|
|
|
@ -196,7 +196,7 @@ class NAR(Base):
|
|||
...
|
||||
|
||||
# apply CFG (should probably only apply to NAR quant level 0)
|
||||
if task not in text_task:
|
||||
if task not in text_task + ["len"]:
|
||||
drop_text = False
|
||||
drop_audio = False
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user