diff --git a/codes/models/gpt_voice/gpt_audio_segmentor.py b/codes/models/gpt_voice/gpt_audio_segmentor.py index 847a545d..5ae57e71 100644 --- a/codes/models/gpt_voice/gpt_audio_segmentor.py +++ b/codes/models/gpt_voice/gpt_audio_segmentor.py @@ -66,6 +66,9 @@ class GptSegmentor(nn.Module): self.stop_head = nn.Linear(model_dim, 1) def forward(self, mel_inputs, mel_lengths): + max_len = mel_lengths.max() # This can be done in the dataset layer, but it is easier to do here. + mel_inputs = mel_inputs[:, :, :max_len] + mel_emb = self.mel_encoder(mel_inputs) mel_lengths = mel_lengths // 4 # The encoder decimates the mel by a factor of 4. mel_emb = mel_emb.permute(0,2,1).contiguous()