This commit is contained in:
mrq 2025-03-10 21:18:57 -05:00
parent 8ac03aac8a
commit 5f98543d4d

View File

@ -1251,7 +1251,7 @@ class Base_V2(nn.Module):
# do duration prediction
logits_aux = self.len_decoder( output.logits )
# only keep the input
logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
logits = logits_aux
@ -1261,7 +1261,7 @@ class Base_V2(nn.Module):
if self.len_parallel_training:
logits_aux = self.len_decoder( output.logits )
# only keep the input
logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
else:
logits_aux = None