I don't know why this fixes the assertion being thrown, but it does
This commit is contained in:
parent
f50d92ba6c
commit
943fe70c10
|
@ -1341,8 +1341,12 @@ class Base(nn.Module):
|
||||||
stats = dict(acc = dict())
|
stats = dict(acc = dict())
|
||||||
|
|
||||||
device = logits[0].device
|
device = logits[0].device
|
||||||
|
batch_size = len(logits)
|
||||||
summed_embeddings_task = [ "stt" ]
|
summed_embeddings_task = [ "stt" ]
|
||||||
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
|
#classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||||
|
|
||||||
|
tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
|
||||||
|
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||||
|
|
||||||
# handles tasks where the prompt has task tokens injected in the middle
|
# handles tasks where the prompt has task tokens injected in the middle
|
||||||
def prompt_input_to_token( input, quant_level ):
|
def prompt_input_to_token( input, quant_level ):
|
||||||
|
@ -1428,7 +1432,7 @@ class Base(nn.Module):
|
||||||
# precision = self.precision_metric( inputs, target ),
|
# precision = self.precision_metric( inputs, target ),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# nll being natural log likelihood :)))) (I don't know why this completely escaped be originally with thinking it meant something else)
|
# nll being negative log likelihood (I don't know why this completely escaped me originally with thinking it meant something else)
|
||||||
loss = dict(
|
loss = dict(
|
||||||
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
|
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
|
||||||
)
|
)
|
||||||
|
|
|
@ -196,7 +196,7 @@ class NAR(Base):
|
||||||
...
|
...
|
||||||
|
|
||||||
# apply CFG (should probably only apply to NAR quant level 0)
|
# apply CFG (should probably only apply to NAR quant level 0)
|
||||||
if task not in text_task:
|
if task not in text_task + ["len"]:
|
||||||
drop_text = False
|
drop_text = False
|
||||||
drop_audio = False
|
drop_audio = False
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user