diff --git a/scripts/setup.sh b/scripts/setup.sh
index 48565f0..0961f77 100755
--- a/scripts/setup.sh
+++ b/scripts/setup.sh
@@ -2,9 +2,9 @@
 python3 -m venv venv
 source ./venv/bin/activate
-pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
-pip3 install -e .
+#pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118 / cu124
+#pip3 install -e .
 
-mkdir -p ./training/valle/ckpt/ar+nar-retnet-8/
-wget -P ./training/valle/ckpt/ar+nar-retnet-8/ "https://huggingface.co/ecker/vall-e/resolve/main/ckpt/ar%2Bnar-retnet-8/fp32.pth"
-wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/raw/main/config.yaml"
+mkdir -p ./training/valle/ckpt/ar+nar-llama-8/
+wget -P ./training/valle/ckpt/ar+nar-llama-8/ "https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.pth"
+wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/resolve/main/models/config.llama.yaml"
 
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 79ae047..9ce0360 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -80,7 +80,6 @@ class LlamaAttention(LlamaAttention_Base):
 			else:
 				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask(), p=dropout_rate)
 		else:
-			#torch.nn.attention.sdpa_kernel
 			with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
 				attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=dropout_rate)
 
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index b1c0fd5..a694bdc 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -500,6 +500,8 @@ class Base(nn.Module):
 		# experimental NAR-only mode
 		self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
 
+		# there seems to have been a regression where anything touching the wrapped LlamaAttention class breaks
+		"""
 		# ick, there has to be a better way
 		if self.config.attention == "auto":
 			if "flash" in AVAILABLE_ATTENTIONS:
@@ -515,6 +517,15 @@
 			hf_attention = None
 			if self.config.attention not in AVAILABLE_ATTENTIONS:
 				raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
+		"""
+
+		if self.config.attention == "auto":
+			if "flash" in AVAILABLE_ATTENTIONS:
+				self.config.attention = "flash_attention_2"
+			else:
+				self.config.attention = "sdpa"
+
+		hf_attention = self.config.attention if self.config is not None else None
 
 		if self.arch_type == "transformer":
 			self.sin_emb = SinusoidalEmbedding(d_model)
@@ -746,8 +757,10 @@ class Base(nn.Module):
 		if hasattr( self.model, "embeddings" ):
 			del self.model.embeddings
 
+		"""
 		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
 			self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
+		"""
 
 		if not split_classifiers:
 			self.classifier = nn.Linear(d_model, n_resp_tokens)
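
A minimal, self-contained sketch (not part of the patch, and not code from the repo) of the backend-selection pattern the retained SDPA path in llama.py relies on: the older torch.backends.cuda.sdp_kernel context manager toggles which scaled_dot_product_attention backend may run, while the comment removed above pointed at its newer replacement, torch.nn.attention.sdpa_kernel. The sdpa() helper name, the mode strings, and the tensor shapes below are illustrative assumptions.

# sketch_sdpa.py -- illustrative only; mirrors the pattern in llama.py above
import torch

def sdpa(query, key, value, mode="math", dropout_p=0.0, attn_mask=None):
	# Older context manager (deprecated in newer PyTorch releases in favor of
	# torch.nn.attention.sdpa_kernel); here it enables only the requested backend.
	with torch.backends.cuda.sdp_kernel(
		enable_flash=(mode == "flash"),
		enable_math=(mode == "math"),
		enable_mem_efficient=(mode == "mem_efficient"),
	):
		return torch.nn.functional.scaled_dot_product_attention(
			query, key, value, attn_mask=attn_mask, dropout_p=dropout_p
		)

if __name__ == "__main__":
	# (batch, heads, seq_len, head_dim); "flash"/"mem_efficient" need CUDA tensors.
	q = k = v = torch.randn(1, 8, 16, 64)
	print(sdpa(q, k, v, mode="math").shape)  # torch.Size([1, 8, 16, 64])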