diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py index bb71976..3f797c0 100755 --- a/tortoise/models/autoregressive.py +++ b/tortoise/models/autoregressive.py @@ -230,7 +230,7 @@ class LearnedPositionEmbeddings(nn.Module): return self.emb(torch.arange(0, sl, device=x.device)) def get_fixed_embedding(self, ind, dev): - return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) + return self.emb(torch.arange(0, ind, device=dev))[ind-1:ind] def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):