From 8250a79b2346e9237b8152ec95b08577be441a33 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 13 Feb 2023 13:48:31 +0000 Subject: [PATCH] Implemented kv_cache "fix" (from https://github.com/152334H/tortoise-tts-fast/commit/1f3c1b5f4a88428410b7e4432f8bc02950bda417); guess I should find out why it's crashing DirectML backend --- tortoise/models/autoregressive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py index bb71976..3f797c0 100755 --- a/tortoise/models/autoregressive.py +++ b/tortoise/models/autoregressive.py @@ -230,7 +230,7 @@ class LearnedPositionEmbeddings(nn.Module): return self.emb(torch.arange(0, sl, device=x.device)) def get_fixed_embedding(self, ind, dev): - return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) + return self.emb(torch.arange(0, ind, device=dev))[ind-1:ind] def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):