diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py index abfd086e..8ffdf5de 100644 --- a/codes/models/gpt_voice/gpt_asr_hf2.py +++ b/codes/models/gpt_voice/gpt_asr_hf2.py @@ -279,7 +279,7 @@ class GptAsrHf2(nn.Module): fake_inputs[:, -cond_used:] = cond_text[:, :cond_used] gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_TOKEN, pad_token_id=0, eos_token_id=0, max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True) - return gen[:, self.max_mel_frames:] + return gen[:, mel_emb.shape[1]:] @register_model