ughh
This commit is contained in:
parent
8ac03aac8a
commit
5f98543d4d
|
@ -1251,7 +1251,7 @@ class Base_V2(nn.Module):
|
|||
# do duration prediction
|
||||
logits_aux = self.len_decoder( output.logits )
|
||||
# only keep the input
|
||||
logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||
|
||||
logits = logits_aux
|
||||
|
||||
|
@ -1261,7 +1261,7 @@ class Base_V2(nn.Module):
|
|||
if self.len_parallel_training:
|
||||
logits_aux = self.len_decoder( output.logits )
|
||||
# only keep the input
|
||||
logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||
else:
|
||||
logits_aux = None
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user