From 5f98543d4d930e533d18a6b315c0f424e4e6c135 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 10 Mar 2025 21:18:57 -0500 Subject: [PATCH] ughh --- vall_e/models/base_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index 51f1a42..863a557 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -1251,7 +1251,7 @@ class Base_V2(nn.Module): # do duration prediction logits_aux = self.len_decoder( output.logits ) # only keep the input - logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ] + logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ] logits = logits_aux @@ -1261,7 +1261,7 @@ class Base_V2(nn.Module): if self.len_parallel_training: logits_aux = self.len_decoder( output.logits ) # only keep the input - logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ] + logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ] else: logits_aux = None