ugh
This commit is contained in:
parent
8641c87611
commit
09e9438941
|
@ -1219,7 +1219,6 @@ class Base_V2(nn.Module):
|
||||||
if tasks[0] == "len":
|
if tasks[0] == "len":
|
||||||
# do duration prediction
|
# do duration prediction
|
||||||
logits_aux = self.len_decoder( output.logits )
|
logits_aux = self.len_decoder( output.logits )
|
||||||
print( logits_aux[0].shape, logits_aux[0] )
|
|
||||||
# it's more accurate this way
|
# it's more accurate this way
|
||||||
logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user