diff --git a/README.md b/README.md
index fe5b240..d845ad7 100755
--- a/README.md
+++ b/README.md
@@ -16,6 +16,9 @@ Besides a working PyTorch environment, the only hard requirement is [`espeak-ng`
 
 ## Install
 
+> [!NOTE]
+> There seems to be some form of regression in fancier attention mechanisms in some environments where you might need to explicitly set `attention` to `flash_attention_2` or `sdpa`.
+
 Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`.
 
 I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 192a608..6dd223c 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -489,15 +489,15 @@ class Base(nn.Module):
 		self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
 
 		# ick, there has to be a better way
-		hf_attention = self.config.attention if self.config is not None else None
-
 		if self.config.attention == "auto":
 			if "flash" in AVAILABLE_ATTENTIONS:
 				self.config.attention = "flash"
 			elif "xformers" in AVAILABLE_ATTENTIONS:
 				self.config.attention = "xformers"
 			else:
-				self.config.attention = "mem_efficient"
+				self.config.attention = "sdpa"
+
+		hf_attention = self.config.attention if self.config is not None else None
 
 		if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
 			hf_attention = None
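
To illustrate how the README note and the patched fallback interact, here is a minimal, self-contained Python sketch. It is not the project's actual module: `AVAILABLE_ATTENTIONS` is stubbed and `resolve_attention` is a hypothetical helper that only mirrors the selection logic shown in the diff above.

```python
# Minimal sketch of the patched backend selection, assuming an environment
# where only PyTorch's SDPA is available (no flash-attn, no xformers).
AVAILABLE_ATTENTIONS = ["sdpa"]  # stubbed for illustration

def resolve_attention(attention: str) -> tuple[str, str | None]:
	"""Resolve "auto" to a concrete backend, then decide what (if anything)
	should be forwarded to HF models as the attention implementation."""
	if attention == "auto":
		if "flash" in AVAILABLE_ATTENTIONS:
			attention = "flash"
		elif "xformers" in AVAILABLE_ATTENTIONS:
			attention = "xformers"
		else:
			attention = "sdpa"  # new fallback; previously "mem_efficient"

	# backends handled outside of HF get no attention implementation passed through
	hf_attention = None if attention in ["xformers", "mem_efficient", "math", "flash"] else attention
	return attention, hf_attention

print(resolve_attention("auto"))               # ('sdpa', 'sdpa')
print(resolve_attention("flash_attention_2"))  # ('flash_attention_2', 'flash_attention_2')
```

Resolving `hf_attention` after the `auto` branch rather than before it, as the patch does, is what lets an auto-selected backend actually reach the HF side instead of being discarded; explicitly setting `attention` to `flash_attention_2` or `sdpa` (per the README note) simply bypasses the `auto` resolution entirely.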