ughh
This commit is contained in:
parent
8ac03aac8a
commit
5f98543d4d
|
@ -1251,7 +1251,7 @@ class Base_V2(nn.Module):
|
||||||
# do duration prediction
|
# do duration prediction
|
||||||
logits_aux = self.len_decoder( output.logits )
|
logits_aux = self.len_decoder( output.logits )
|
||||||
# only keep the input
|
# only keep the input
|
||||||
logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||||
|
|
||||||
logits = logits_aux
|
logits = logits_aux
|
||||||
|
|
||||||
|
@ -1261,7 +1261,7 @@ class Base_V2(nn.Module):
|
||||||
if self.len_parallel_training:
|
if self.len_parallel_training:
|
||||||
logits_aux = self.len_decoder( output.logits )
|
logits_aux = self.len_decoder( output.logits )
|
||||||
# only keep the input
|
# only keep the input
|
||||||
logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||||
else:
|
else:
|
||||||
logits_aux = None
|
logits_aux = None
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user