ugh
This commit is contained in:
parent
8641c87611
commit
09e9438941
|
@ -1219,7 +1219,6 @@ class Base_V2(nn.Module):
|
|||
if tasks[0] == "len":
|
||||
# do duration prediction
|
||||
logits_aux = self.len_decoder( output.logits )
|
||||
print( logits_aux[0].shape, logits_aux[0] )
|
||||
# it's more accurate this way
|
||||
logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user