This commit is contained in:
James Betker 2022-01-01 14:29:59 -07:00
parent d4a6298658
commit 2635412291

View File

@ -288,7 +288,7 @@ class GptAsrHf2(nn.Module):
mel_len = 0
else:
mel_emb = self.mel_encoder(mel_inputs)
assert mel_emb.shape[1] <= self.max_mel_frames, f'{mel_emb.shape[1]} > {self.max_mel_frames}'
assert mel_emb.shape[-1] <= self.max_mel_frames, f'{mel_emb.shape[-1]} > {self.max_mel_frames}'
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
emb = torch.cat([mel_emb, text_emb], dim=1)