From 263541229126a1cf3ab6dda703c639f2417b948a Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 1 Jan 2022 14:29:59 -0700 Subject: [PATCH] doh --- codes/models/gpt_voice/gpt_asr_hf2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py index ea4b5303..3a696fa6 100644 --- a/codes/models/gpt_voice/gpt_asr_hf2.py +++ b/codes/models/gpt_voice/gpt_asr_hf2.py @@ -288,7 +288,7 @@ class GptAsrHf2(nn.Module): mel_len = 0 else: mel_emb = self.mel_encoder(mel_inputs) - assert mel_emb.shape[1] <= self.max_mel_frames, f'{mel_emb.shape[1]} > {self.max_mel_frames}' + assert mel_emb.shape[-1] <= self.max_mel_frames, f'{mel_emb.shape[-1]} > {self.max_mel_frames}' mel_emb = mel_emb.permute(0,2,1).contiguous() mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) emb = torch.cat([mel_emb, text_emb], dim=1)