trying (and failing) to nail a weird regression in fancier attentions

mrq 2024-07-29 19:53:37 -05:00
parent c2f5b916fc
commit 55b0121b1a
2 changed files with 6 additions and 3 deletions


@@ -16,6 +16,9 @@ Besides a working PyTorch environment, the only hard requirement is [`espeak-ng`
 ## Install
+> [!NOTE]
+> There seems to be some form of regression with fancier attention mechanisms in some environments; you may need to explicitly set `attention` to `flash_attention_2` or `sdpa`.
 Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`.
 I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
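
As a hedged illustration of the workaround described in the note, pinning the backend in a model config could look roughly like the snippet below. The `models` and `name` keys are assumptions about the YAML layout; only the `attention` field and its accepted values come from the note and the hunk further down.

```yaml
models:
  - name: "ar+nar"        # illustrative entry, not a required name
    attention: "sdpa"     # or "flash_attention_2"; bypasses the "auto" selection
```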


@@ -489,15 +489,15 @@ class Base(nn.Module):
 		self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
 		# ick, there has to be a better way
-		hf_attention = self.config.attention if self.config is not None else None
 		if self.config.attention == "auto":
 			if "flash" in AVAILABLE_ATTENTIONS:
 				self.config.attention = "flash"
 			elif "xformers" in AVAILABLE_ATTENTIONS:
 				self.config.attention = "xformers"
 			else:
-				self.config.attention = "mem_efficient"
+				self.config.attention = "sdpa"
+		hf_attention = self.config.attention if self.config is not None else None
 		if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
 			hf_attention = None
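
For reference, a minimal self-contained sketch of the selection pattern the hunk above implements: probe for optional attention kernels, resolve `"auto"` to the best one found, and fall back to PyTorch's built-in `sdpa` rather than `mem_efficient`. The `resolve_attention` helper and the import-time probing of `AVAILABLE_ATTENTIONS` are illustrative stand-ins, not the module's actual code.

```python
# Sketch only: the real module builds AVAILABLE_ATTENTIONS itself.
AVAILABLE_ATTENTIONS = []

try:
	import flash_attn  # flash-attention 2 kernels, optional dependency
	AVAILABLE_ATTENTIONS.append("flash")
except ImportError:
	pass

try:
	import xformers.ops  # memory-efficient attention kernels, optional dependency
	AVAILABLE_ATTENTIONS.append("xformers")
except ImportError:
	pass

def resolve_attention( requested = "auto" ):
	# anything other than "auto" is taken at face value
	if requested != "auto":
		return requested
	# prefer the fancier kernels when they are installed...
	for backend in ( "flash", "xformers" ):
		if backend in AVAILABLE_ATTENTIONS:
			return backend
	# ...and fall back to PyTorch's built-in scaled_dot_product_attention
	return "sdpa"

print( resolve_attention() )  # e.g. "flash" on a machine with flash-attn installed
```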