This commit is contained in:
mrq 2025-03-25 23:24:01 -05:00
parent 8641c87611
commit 09e9438941

View File

@ -1219,7 +1219,6 @@ class Base_V2(nn.Module):
if tasks[0] == "len":
# do duration prediction
logits_aux = self.len_decoder( output.logits )
print( logits_aux[0].shape, logits_aux[0] )
# it's more accurate this way
logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]