doh
This commit is contained in:
parent
d4a6298658
commit
2635412291
|
@ -288,7 +288,7 @@ class GptAsrHf2(nn.Module):
|
|||
mel_len = 0
|
||||
else:
|
||||
mel_emb = self.mel_encoder(mel_inputs)
|
||||
assert mel_emb.shape[1] <= self.max_mel_frames, f'{mel_emb.shape[1]} > {self.max_mel_frames}'
|
||||
assert mel_emb.shape[-1] <= self.max_mel_frames, f'{mel_emb.shape[-1]} > {self.max_mel_frames}'
|
||||
mel_emb = mel_emb.permute(0,2,1).contiguous()
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
||||
emb = torch.cat([mel_emb, text_emb], dim=1)
|
||||
|
|
Loading…
Reference in New Issue
Block a user