wrapper attention class for other sdpa backends + xformers seems to have broke...
commit 9564ecda43
parent 9e1989be1b
@@ -2,9 +2,9 @@
python3 -m venv venv
source ./venv/bin/activate
-pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
-pip3 install -e .
+#pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118 / cu124
+#pip3 install -e .

-mkdir -p ./training/valle/ckpt/ar+nar-retnet-8/
-wget -P ./training/valle/ckpt/ar+nar-retnet-8/ "https://huggingface.co/ecker/vall-e/resolve/main/ckpt/ar%2Bnar-retnet-8/fp32.pth"
-wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/raw/main/config.yaml"
+mkdir -p ./training/valle/ckpt/ar+nar-llama-8/
+wget -P ./training/valle/ckpt/ar+nar-llama-8/ "https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.pth"
+wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/resolve/main/models/config.llama.yaml"

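A quick way to confirm the new downloads landed where the instructions expect them is to load the checkpoint directly; the sketch below is only a sanity check against the paths used above (`wget -P` keeps the remote filename, so the config lands at ./training/valle/config.llama.yaml), not a step the repo itself documents.

# Sanity-check sketch (not part of the repo's documented setup): verify the downloaded
# ar+nar-llama-8 checkpoint and config exist and that the checkpoint deserializes.
import os
import torch

ckpt = "./training/valle/ckpt/ar+nar-llama-8/fp32.pth"   # path from the instructions above
cfg  = "./training/valle/config.llama.yaml"              # wget -P keeps the remote filename

assert os.path.exists(ckpt), f"missing checkpoint: {ckpt}"
assert os.path.exists(cfg), f"missing config: {cfg}"

state = torch.load(ckpt, map_location="cpu")
keys = list(state.keys()) if isinstance(state, dict) else []
print(f"loaded {ckpt} with {len(keys)} top-level entries")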
@@ -80,7 +80,6 @@ class LlamaAttention(LlamaAttention_Base):
			else:
				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask(), p=dropout_rate)
		else:
			#torch.nn.attention.sdpa_kernel
			with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
				attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=dropout_rate)

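The `#torch.nn.attention.sdpa_kernel` comment above points at the newer backend-selection API that replaces the deprecated `torch.backends.cuda.sdp_kernel` context manager. A minimal sketch of how the same mode strings could map onto it, assuming PyTorch 2.3+ (the function name and arguments here are illustrative, not the wrapper's actual signature):

# Sketch only: the same flash/math/mem_efficient selection via torch.nn.attention.sdpa_kernel
# (PyTorch 2.3+) instead of the deprecated torch.backends.cuda.sdp_kernel used in the hunk above.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

_BACKENDS = {
    "flash": SDPBackend.FLASH_ATTENTION,
    "math": SDPBackend.MATH,
    "mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
}

def sdpa_with_mode(q, k, v, mode="math", mask=None, dropout_p=0.0):
    # Restrict SDPA to the requested backend for the duration of this call.
    with sdpa_kernel([_BACKENDS[mode]]):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p)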
@@ -500,6 +500,8 @@ class Base(nn.Module):
		# experimental NAR-only mode
		self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None

+		# there seems to have been a regression where anything touching the wrapped LlamaAttention class breaks
+		"""
		# ick, there has to be a better way
		if self.config.attention == "auto":
			if "flash" in AVAILABLE_ATTENTIONS:

@@ -515,6 +517,15 @@ class Base(nn.Module):
				hf_attention = None
		if self.config.attention not in AVAILABLE_ATTENTIONS:
			raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
+		"""
+
+		if self.config.attention == "auto":
+			if "flash" in AVAILABLE_ATTENTIONS:
+				self.config.attention = "flash_attention_2"
+			else:
+				self.config.attention = "sdpa"
+
+		hf_attention = self.config.attention if self.config is not None else None

		if self.arch_type == "transformer":
			self.sin_emb = SinusoidalEmbedding(d_model)

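For reference, the `auto` fallback above only needs a coarse notion of which backends are importable. A hedged sketch of how an `AVAILABLE_ATTENTIONS`-style list could be populated (this is the general pattern, not the repo's actual probe):

# Sketch of how an AVAILABLE_ATTENTIONS-style list might be populated; the repo's
# actual probing logic may differ.
import importlib.util
import torch

AVAILABLE_ATTENTIONS = []

if importlib.util.find_spec("xformers") is not None:
    AVAILABLE_ATTENTIONS.append("xformers")

if importlib.util.find_spec("flash_attn") is not None:
    AVAILABLE_ATTENTIONS.append("flash")

# PyTorch's built-in scaled_dot_product_attention covers the sdpa/math/mem_efficient modes.
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
    AVAILABLE_ATTENTIONS.append("sdpa")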
@@ -746,8 +757,10 @@ class Base(nn.Module):
				if hasattr( self.model, "embeddings" ):
					del self.model.embeddings

+			"""
			if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
				self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
+			"""

		if not split_classifiers:
			self.classifier = nn.Linear(d_model, n_resp_tokens)

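The disabled `ml.replace_attention` call swaps the stock `LlamaAttention_Base` modules for the wrapper class. Since the wrapper subclasses the base and (per the hunks above) mainly overrides the attention dispatch, one generic way to perform that swap in place is to reassign `__class__` on matching submodules; the sketch below shows that pattern only, with a hypothetical `mode` attribute, and is not the repo's actual helper:

# Generic sketch of the module-swapping pattern behind a replace_attention-style helper;
# not the repo's ml.replace_attention. Assumes `klass` subclasses `target` and only changes
# forward(), so reassigning __class__ keeps the existing parameters and buffers intact.
import torch.nn as nn

def replace_attention_sketch(model: nn.Module, klass, target, mode: str) -> nn.Module:
    for name, module in model.named_modules():
        if isinstance(module, target) and not isinstance(module, klass):
            module.__class__ = klass   # in-place swap; weights are untouched
            module.mode = mode         # hypothetical attribute read by the wrapper's forward()
    return model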