wrapper attention class for other sdpa backends + xformers seems to have broken...

mrq 2024-08-03 15:12:11 -05:00
parent 9e1989be1b
commit 9564ecda43
3 changed files with 18 additions and 6 deletions


@@ -2,9 +2,9 @@
 python3 -m venv venv
 source ./venv/bin/activate
-pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
-pip3 install -e .
+#pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118 / cu124
+#pip3 install -e .
-mkdir -p ./training/valle/ckpt/ar+nar-retnet-8/
-wget -P ./training/valle/ckpt/ar+nar-retnet-8/ "https://huggingface.co/ecker/vall-e/resolve/main/ckpt/ar%2Bnar-retnet-8/fp32.pth"
-wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/raw/main/config.yaml"
+mkdir -p ./training/valle/ckpt/ar+nar-llama-8/
+wget -P ./training/valle/ckpt/ar+nar-llama-8/ "https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.pth"
+wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/resolve/main/models/config.llama.yaml"


@@ -80,7 +80,6 @@ class LlamaAttention(LlamaAttention_Base):
 			else:
 				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask(), p=dropout_rate)
 		else:
-			#torch.nn.attention.sdpa_kernel
 			with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
 				attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=dropout_rate)
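
For context, the torch.backends.cuda.sdp_kernel context manager above is what pins torch.nn.functional.scaled_dot_product_attention to the wrapper's chosen backend. A minimal standalone sketch of that mechanism, assuming PyTorch 2.x on a CUDA device (shapes are arbitrary); the torch.nn.attention.sdpa_kernel form is the non-deprecated equivalent in PyTorch >= 2.3:

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) -- arbitrary example shapes
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# older API (what the wrapper uses): toggle individual backends on/off around the call
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# newer API (PyTorch >= 2.3): name the allowed backend(s) explicitly
from torch.nn.attention import SDPBackend, sdpa_kernel
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)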


@@ -500,6 +500,8 @@ class Base(nn.Module):
 		# experimental NAR-only mode
 		self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
+		# there seems to have been a regression where anything touching the wrapped LlamaAttention class breaks
+		"""
 		# ick, there has to be a better way
 		if self.config.attention == "auto":
 			if "flash" in AVAILABLE_ATTENTIONS:
@@ -515,6 +517,15 @@ class Base(nn.Module):
 			hf_attention = None
 			if self.config.attention not in AVAILABLE_ATTENTIONS:
 				raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
+		"""
+		if self.config.attention == "auto":
+			if "flash" in AVAILABLE_ATTENTIONS:
+				self.config.attention = "flash_attention_2"
+			else:
+				self.config.attention = "sdpa"
+		hf_attention = self.config.attention if self.config is not None else None
 		if self.arch_type == "transformer":
 			self.sin_emb = SinusoidalEmbedding(d_model)
@@ -746,8 +757,10 @@ class Base(nn.Module):
 		if hasattr( self.model, "embeddings" ):
 			del self.model.embeddings
+		"""
 		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
 			self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
+		"""
 		if not split_classifiers:
 			self.classifier = nn.Linear(d_model, n_resp_tokens)
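
The net effect of the change is that backend selection is handed to transformers through attn_implementation instead of the ml.replace_attention() monkey-patch. A rough sketch of that wiring, assuming transformers >= 4.36 (where configs accept attn_implementation); the AVAILABLE_ATTENTIONS probe and the model dimensions below are placeholders, not the repo's actual values:

from transformers import LlamaConfig, LlamaModel

AVAILABLE_ATTENTIONS = ["sdpa"]  # placeholder; a real probe would add "flash" when flash-attn is installed

def resolve_attention(requested: str) -> str:
    # mirrors the diff's "auto" handling: prefer flash-attention-2, otherwise fall back to SDPA
    if requested == "auto":
        return "flash_attention_2" if "flash" in AVAILABLE_ATTENTIONS else "sdpa"
    return requested

config = LlamaConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=16,
    attn_implementation=resolve_attention("auto"),
)
model = LlamaModel(config)  # transformers picks the matching attention implementation internally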