Implemented kv_cache "fix" (from 1f3c1b5f4a); guess I should find out why it's crashing DirectML backend

This commit is contained in:
mrq 2023-02-13 13:48:31 +00:00
parent 80eeef01fb
commit 8250a79b23

View File

@ -230,7 +230,7 @@ class LearnedPositionEmbeddings(nn.Module):
return self.emb(torch.arange(0, sl, device=x.device)) return self.emb(torch.arange(0, sl, device=x.device))
def get_fixed_embedding(self, ind, dev): def get_fixed_embedding(self, ind, dev):
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) return self.emb(torch.arange(0, ind, device=dev))[ind-1:ind]
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):