diff --git a/README.md b/README.md
index e244769..af0a291 100755
--- a/README.md
+++ b/README.md
@@ -16,9 +16,6 @@ Besides a working PyTorch environment, the only hard requirement is [`espeak-ng`
 
 ## Install
 
-> [!NOTE]
-> There seems to be some form of regression in fancier attention mechanisms in some environments where you might need to explicitly set `attention` to `flash_attention_2` or `sdpa`.
-
 Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`.
 
 I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
@@ -30,7 +27,7 @@ I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
 
 My pre-trained weights can be acquired from [here](https://huggingface.co/ecker/vall-e).
 
-A script to setup a proper environment and download the weights can be invoked with `./scripts/setup.sh`
+A script to set up a proper environment and download the weights can be invoked with `./scripts/setup.sh`. This will automatically create a `venv`, then download the weights and config file to the right place.
 
 ## Train
diff --git a/scripts/setup-training.sh b/scripts/setup-training.sh
deleted file mode 100755
index 13a9089..0000000
--- a/scripts/setup-training.sh
+++ /dev/null
@@ -1,7 +0,0 @@
-#!/bin/bash
-`dirname $0`/setup.sh
-
-wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/resolve/main/data.tar.gz"
-wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/resolve/main/.cache.tar.gz"
-tar -xzf ./training/valle/data.tar.gz -C "./training/valle/" data.h5
-tar -xzf ./training/valle/.cache.tar.gz -C "./training/valle/"
\ No newline at end of file
diff --git a/vall_e/config.py b/vall_e/config.py
index b026f6c..194fd76 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -223,6 +223,10 @@ class ModelExperimentalSettings:
 	causal_size: int = 1 # experimental setting to see if I can just do parallel decoding in chunks instead of one-at-a-time without resorting to exotic solutions
 	# VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model
 	# however, introducing partial parallel decoding for the AR maybe maybe MAYBE might help try and unify the AR/NAR tasks better, MAYBE
+	# it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token
+
+	p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
+	# to-do: just incorporate this as a task instead
 
 # I really need to clean this up
 @dataclass()
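For context on the new `p_len_train` knob: a minimal sketch (not taken from the repo; every name besides `p_len_train` is illustrative) of how a per-sample coin flip against this probability can inject the `len` task during training, which is what the `nar.py` change at the end of this patch does with `p_len_task`.

```python
# Illustrative sketch only: with probability p_len_train, a training sample is
# assigned the "len" (duration prediction) task instead of the default "tts" task.
import random

def pick_task(p_len_train: float = 0.05) -> str:
	return "len" if random.random() < p_len_train else "tts"

task_list = [ pick_task(0.05) for _ in range(8) ]  # roughly 1 in 20 samples becomes a "len" task
```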
"config") and self.config: - return self.config.langs - return cfg.model.langs - - @property - def n_tones(self) -> int: - if hasattr(self, "config") and self.config: - return self.config.tones - return cfg.model.tones - - @property - def causal_size(self) -> int: - # 1 for the stop token - # governs how much to shift the logits by - # could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it - if hasattr(self, "config") and self.config: - return self.config.experimental.causal_size - return cfg.model.experimental.causal_size - - @property - def version(self) -> int: - if hasattr(self, "config") and self.config: - return self.config.version - return cfg.model.version - - def _prune(self, l: Tensor, stop = None): - if stop is None: - stop = self.stop_token - indices = (l == stop).nonzero() - if len(indices) == 0: - return l - return l[: indices.min().item()] - - @staticmethod - def _unsqueeze_list(x_list, axis=-1): - return [x.unsqueeze(dim=axis) for x in x_list] - def forward( self, text_list: list[Tensor], @@ -299,7 +232,7 @@ class AR_NAR(Base): # get next in sequence for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm): - resps_list = self._unsqueeze_list(sequence_list) + resps_list = [x.unsqueeze(dim=-1) for x in sequence_list] inputs = self.inputs( text_list=text_list, diff --git a/vall_e/models/arch/__init__.py b/vall_e/models/arch/__init__.py index a8233d2..4ce99e9 100755 --- a/vall_e/models/arch/__init__.py +++ b/vall_e/models/arch/__init__.py @@ -30,7 +30,7 @@ except Exception as e: pass try: - from .llama import LlamaModel, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Base, LlamaForCausalLM + from .llama import LlamaModel, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM AVAILABLE_ARCHES.append("llama") except Exception as e: ERROR_ARCHES["llama"] = e @@ -61,11 +61,4 @@ try: from .mamba_vasqu import Mamba2Model_HF, Mamba2Config_HF AVAILABLE_ARCHES.append("mamba2-hf") except Exception as e: - ERROR_ARCHES["mamba2-hf"] = e - -# desu should remove, perf was very lacking in comparison to regular bitnet -try: - from .mmfreelm import * - AVAILABLE_ARCHES.append("mmfreelm") -except Exception as e: - ERROR_ARCHES["mmfreelm"] = e \ No newline at end of file + ERROR_ARCHES["mamba2-hf"] = e \ No newline at end of file diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py index 9ce0360..59309fc 100644 --- a/vall_e/models/arch/llama.py +++ b/vall_e/models/arch/llama.py @@ -7,9 +7,18 @@ from torch import Tensor, nn from transformers.cache_utils import Cache from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM -from transformers.models.llama.modeling_llama import LlamaAttention as LlamaAttention_Base, apply_rotary_pos_emb +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv -AVAILABLE_ATTENTIONS = ["mem_efficient", "math"] +AVAILABLE_ATTENTIONS = ["sdpa"] + +if torch.backends.cuda.flash_sdp_enabled(): + AVAILABLE_ATTENTIONS.append("flash") + +if torch.backends.cuda.mem_efficient_sdp_enabled(): + AVAILABLE_ATTENTIONS.append("mem_efficient") + +if torch.backends.cuda.math_sdp_enabled(): + AVAILABLE_ATTENTIONS.append("math") try: from xformers.ops import LowerTriangularMask @@ -23,11 +32,11 @@ try: from transformers.utils import is_flash_attn_2_available if is_flash_attn_2_available(): - AVAILABLE_ATTENTIONS.append("flash") + 
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 9ce0360..59309fc 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -7,9 +7,18 @@ from torch import Tensor, nn
 from transformers.cache_utils import Cache
 
 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
-from transformers.models.llama.modeling_llama import LlamaAttention as LlamaAttention_Base, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
 
-AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
+AVAILABLE_ATTENTIONS = ["sdpa"]
+
+if torch.backends.cuda.flash_sdp_enabled():
+	AVAILABLE_ATTENTIONS.append("flash")
+
+if torch.backends.cuda.mem_efficient_sdp_enabled():
+	AVAILABLE_ATTENTIONS.append("mem_efficient")
+
+if torch.backends.cuda.math_sdp_enabled():
+	AVAILABLE_ATTENTIONS.append("math")
 
 try:
 	from xformers.ops import LowerTriangularMask
@@ -23,11 +32,11 @@ try:
 	from transformers.utils import is_flash_attn_2_available
 
 	if is_flash_attn_2_available():
-		AVAILABLE_ATTENTIONS.append("flash")
+		AVAILABLE_ATTENTIONS.append("flash_attention_2")
 except Exception as e:
 	print("Error while querying for `flash_attn_2` support", e)
 
-class LlamaAttention(LlamaAttention_Base):
+class LlamaAttention_Adapted(LlamaAttention):
 	def __init__(self, *args, **kwargs):
 		if 'mode' in kwargs:
 			self.mode = kwargs['mode']
@@ -35,8 +44,101 @@ class LlamaAttention(LlamaAttention_Base):
 		else:
 			self.mode = "math"
 
+		if self.mode == "math":
+			self.mode = torch.nn.attention.SDPBackend.MATH
+		elif self.mode == "mem_efficient":
+			self.mode = torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
+		elif self.mode == "flash":
+			self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
+		elif self.mode == "cudnn":
+			self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
+
 		super().__init__(*args, **kwargs)
 
+	# Adapted from LlamaAttention.forward
+	def forward(
+		self,
+		hidden_states: torch.Tensor,
+		attention_mask: Optional[torch.Tensor] = None,
+		position_ids: Optional[torch.LongTensor] = None,
+		past_key_value: Optional[Cache] = None,
+		output_attentions: bool = False,
+		use_cache: bool = False,
+		cache_position: Optional[torch.LongTensor] = None,
+		position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+		**kwargs,
+	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+		if output_attentions:
+			# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+			return super().forward(
+				hidden_states=hidden_states,
+				attention_mask=attention_mask,
+				position_ids=position_ids,
+				past_key_value=past_key_value,
+				output_attentions=output_attentions,
+				use_cache=use_cache,
+				cache_position=cache_position,
+				position_embeddings=position_embeddings,
+			)
+
+		bsz, q_len, _ = hidden_states.size()
+
+		query_states = self.q_proj(hidden_states)
+		key_states = self.k_proj(hidden_states)
+		value_states = self.v_proj(hidden_states)
+
+		query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+		key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+		value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+		if position_embeddings is None:
+			cos, sin = self.rotary_emb(value_states, position_ids)
+		else:
+			cos, sin = position_embeddings
+		query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+		if past_key_value is not None:
+			# sin and cos are specific to RoPE models; cache_position needed for the static cache
+			cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+			key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+		key_states = repeat_kv(key_states, self.num_key_value_groups)
+		value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+		causal_mask = attention_mask
+		if attention_mask is not None:
+			causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+		# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+		# Reference: https://github.com/pytorch/pytorch/issues/112577.
+		if query_states.device.type == "cuda" and causal_mask is not None:
+			query_states = query_states.contiguous()
+			key_states = key_states.contiguous()
+			value_states = value_states.contiguous()
+
+		# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+		# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+		is_causal = True if causal_mask is None and q_len > 1 else False
+
+		#with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
+		with torch.nn.attention.sdpa_kernel(self.mode):
+			attn_output = torch.nn.functional.scaled_dot_product_attention(
+				query_states,
+				key_states,
+				value_states,
+				attn_mask=causal_mask,
+				dropout_p=self.attention_dropout if self.training else 0.0,
+				is_causal=is_causal,
+			)
+
+		attn_output = attn_output.transpose(1, 2).contiguous()
+		attn_output = attn_output.view(bsz, q_len, -1)
+
+		attn_output = self.o_proj(attn_output)
+
+		return attn_output, None, past_key_value
+
+	"""
 	def forward(
 		self,
 		hidden_states: torch.Tensor,
@@ -88,4 +190,5 @@ class LlamaAttention(LlamaAttention_Base):
 		attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 		attn_output = self.o_proj(attn_output)
 
-		return attn_output, attn_weights, past_key_value
\ No newline at end of file
+		return attn_output, attn_weights, past_key_value
+	"""
\ No newline at end of file
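A standalone sketch of the mechanism `LlamaAttention_Adapted` leans on above: pinning `scaled_dot_product_attention` to a single backend with `torch.nn.attention.sdpa_kernel`. This assumes a PyTorch version that ships `torch.nn.attention` (2.3 or newer); the tensor shapes are arbitrary and not from the repo.

```python
# Pin SDPA to one backend for a block of code, mirroring the adapted forward above.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Swap in FLASH_ATTENTION, EFFICIENT_ATTENTION, or CUDNN_ATTENTION as desired,
# provided the backend is actually available on the device.
with sdpa_kernel(SDPBackend.MATH):
	out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 16, 64])
```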
diff --git a/vall_e/models/arch/mmfreelm.py b/vall_e/models/arch/mmfreelm.py
deleted file mode 100644
index 86d2fbe..0000000
--- a/vall_e/models/arch/mmfreelm.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# https://github.com/ridgerchu/matmulfreellm
-
-import torch
-import torch.nn.functional as F
-
-from mmfreelm.models import HGRNBitConfig, HGRNBitModel
\ No newline at end of file
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index a694bdc..a14c850 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -297,60 +297,22 @@ class Metrics(nn.Module):
 	)
 
 class Base(nn.Module):
-	# to-do: clean up this property mess
-	@property
-	def causal(self) -> bool:
-		raise NotImplementedError
-
-	@property
-	def n_resp_levels(self) -> int:
-		raise NotImplementedError
-
-	@property
-	def n_max_levels(self) -> int:
-		raise NotImplementedError
-
-	@property
-	def n_langs(self) -> int:
-		raise NotImplementedError
-
-	@property
-	def n_tasks(self) -> int:
-		raise NotImplementedError
-
-	@property
-	def n_tones(self) -> int:
-		raise NotImplementedError
-
-	@property
-	def causal_size(self) -> int:
-		raise NotImplementedError
-
-	@property
-	def version(self) -> int:
-		return 2
-
-	@property
-	def capabilities(self) -> list[str]:
-		raise NotImplementedError
-
-	@property
-	def stop_token(self):
-		if "len" in self.capabilities:
-			return 0
-		if not self.causal:
-			raise ValueError("Not using stop token!")
-		return self.n_audio_tokens
-
-	@property
-	def ignore_index(self):
-		return -100
-
 	def loss_factor(self, k):
 		if self.config is None:
 			return 1.0
 		return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
 
+	def _prune(self, l: Tensor, stop = None):
+		if stop is None:
+			stop = self.stop_token
+
+		indices = (l == stop).nonzero()
+
+		if len(indices) == 0:
+			return l
+
+		return l[: indices.min().item()]
+
 	# these probably need to live in an interleaved model, as pattern-ing is targeted for a sole AR model
 	"""
 	def codes_to_pattern(self, codes):
@@ -404,7 +366,6 @@ class Base(nn.Module):
 		super().__init__()
 		self.training = training
 		self.config = config
-		self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
 
 		self.n_text_tokens = n_text_tokens
 		self.n_audio_tokens = n_audio_tokens
@@ -416,19 +377,35 @@ class Base(nn.Module):
 		self.l_padding = l_padding
 
-		arch_type = self.config.arch_type if self.config is not None else "llama"
+		self.ignore_index = -100
 
-		self.arch_type = arch_type
+		self.n_resp_levels = self.config.resp_levels if self.config else n_resp_levels
+		self.n_max_levels = self.config.max_levels if self.config else n_resp_levels
+		self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
+		self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
+
+		self.stop_token = self.n_audio_tokens # id 1024
+		self.causal = "ar" in self.capabilities or "len" in self.capabilities
+		self.version = self.config.version if self.config is not None else 5
+		self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if "ar" in self.capabilities else 0)
+
+		self.arch_type = self.config.arch_type if self.config is not None else "llama"
 
 		# check if requested arch is unavailable
 		if self.arch_type in ERROR_ARCHES:
 			raise ERROR_ARCHES[self.arch_type]
+
+		attention_backend = self.config.attention if self.config is not None else "auto"
 
 		audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
 		split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
 		tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
 		audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
 		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
 
+		n_tasks = self.config.tasks if self.config is not None else 8
+		n_langs = self.config.langs if self.config is not None else 2
+		n_tones = self.config.tones if self.config is not None else 1
+
 		if "len" not in self.capabilities:
 			# +1 to include the stop token
 			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
@@ -457,7 +434,7 @@ class Base(nn.Module):
 		self.dropout_token = nn.Parameter(torch.zeros(d_model)) # zeros sounds nicer than randn for a special value
 
 		if self.version == 1: # legacy
-			n_audio_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
+			n_audio_tokens += (n_tasks - 1) # old models have the task tokens in the prom
 			self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
 		elif self.version < 5:
@@ -485,11 +462,11 @@ class Base(nn.Module):
 
 		# useless since I actually removed using these with the input processing overhaul...
 		if self.version >= 3:
-			self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
-			self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None
+			self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
+			self.tasks_emb = Embedding(n_tasks, d_model) if n_tasks > 0 else None
 		# never actually got added... I kept forgetting to classify all my audio for speaker's tone
 		if self.version >= 4:
-			self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None
+			self.tones_emb = Embedding(n_tones, d_model) if n_tones > 0 else None
 
 		# mamba requires this if a model does both AR and NAR tasks
 		# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
@@ -501,31 +478,28 @@ class Base(nn.Module):
 		self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
 
 		# there seems to have been a regression where anything touching the wrapped LlamaAttention class breaks
-		"""
-		# ick, there has to be a better way
-		if self.config.attention == "auto":
-			if "flash" in AVAILABLE_ATTENTIONS:
-				self.config.attention = "flash"
-			elif "xformers" in AVAILABLE_ATTENTIONS:
-				self.config.attention = "xformers"
+
+		if attention_backend == "auto":
+			if "flash_attention_2" in AVAILABLE_ATTENTIONS:
+				attention_backend = "flash_attention_2"
+			elif "flash" in AVAILABLE_ATTENTIONS:
+				attention_backend = "flash"
+			elif "mem_efficient" in AVAILABLE_ATTENTIONS:
+				attention_backend = "mem_efficient"
+			elif "math" in AVAILABLE_ATTENTIONS:
+				attention_backend = "math"
 			else:
-				self.config.attention = "sdpa"
+				attention_backend = "sdpa"
 
-		hf_attention = self.config.attention if self.config is not None else None
-
-		if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
-			hf_attention = None
-			if self.config.attention not in AVAILABLE_ATTENTIONS:
-				raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
-		"""
-
-		if self.config.attention == "auto":
-			if "flash" in AVAILABLE_ATTENTIONS:
-				self.config.attention = "flash_attention_2"
-			else:
-				self.config.attention = "sdpa"
+		if attention_backend == "xformers":
+			attention_backend = "mem_efficient"
 
-		hf_attention = self.config.attention if self.config is not None else None
+		hf_attention = attention_backend
+
+		if attention_backend in ["xformers", "mem_efficient", "math", "flash", "cudnn"]:
+			hf_attention = None
+			if attention_backend not in AVAILABLE_ATTENTIONS:
+				raise ValueError(f"Requesting attention `{attention_backend}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
 
 		if self.arch_type == "transformer":
 			self.sin_emb = SinusoidalEmbedding(d_model)
@@ -654,18 +628,6 @@ class Base(nn.Module):
 			))
 			self.model = RetNetDecoder(RetNetConfig(**kwargs))
-
-			# do some funny stuff for LoRA training
-			"""
-			if self.gradient_checkpointing:
-				def make_inputs_require_grads(module, input, output):
-					for i, t in enumerate(input):
-						if not isinstance(t, torch.Tensor):
-							continue
-						t.requires_grad_(True)
-
-				self.model.register_forward_hook(make_inputs_require_grads)
-			"""
 		elif self.arch_type == "retnet-hf":
 			kwargs = dict(
 				vocab_size=n_resp_tokens,
@@ -757,10 +719,8 @@ class Base(nn.Module):
 			if hasattr( self.model, "embeddings" ):
 				del self.model.embeddings
 
-		"""
-		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
-			self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
-		"""
+		if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
+			self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
 
 		if not split_classifiers:
 			self.classifier = nn.Linear(d_model, n_resp_tokens)
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index c5af2d2..67efc03 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -21,71 +21,6 @@ from tqdm import trange
 from ..emb.qnt import trim
 
 class NAR(Base):
-	@property
-	def capabilities(self) -> list[str]:
-		if hasattr(self, "config") and self.config:
-			return self.config.capabilities
-		return cfg.model.capabilities
-
-	@property
-	def causal(self):
-		return "len" in self.capabilities
-
-	@property
-	def n_resp_levels(self) -> int:
-		if hasattr(self, "config") and self.config:
-			return self.config.resp_levels
-		return cfg.model.resp_levels
-
-	@property
-	def n_max_levels(self) -> int:
-		if hasattr(self, "config") and self.config:
-			return self.config.max_levels
-		return cfg.model.max_levels
-
-	@property
-	def n_tasks(self) -> int:
-		if hasattr(self, "config") and self.config:
-			return self.config.tasks
-		return cfg.model.tasks
-
-	@property
-	def n_langs(self) -> int:
-		if hasattr(self, "config") and self.config:
-			return self.config.langs
-		return cfg.model.langs
-
-	@property
-	def n_tones(self) -> int:
-		if hasattr(self, "config") and self.config:
-			return self.config.tones
-		return cfg.model.tones
-
-	@property
-	def causal_size(self) -> int:
-		# 1 for the stop token
-		# governs how much to shift the logits by
-		# could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
-		return 1 # if self.causal else 0
-
-	@property
-	def version(self) -> int:
-		if hasattr(self, "config") and self.config:
-			return self.config.version
-		return cfg.model.version
-
-	def _prune(self, l: Tensor, stop = None):
-		if stop is None:
-			stop = self.stop_token
-		indices = (l == stop).nonzero()
-		if len(indices) == 0:
-			return l
-		return l[: indices.min().item()]
-
-	@staticmethod
-	def _unsqueeze_list(x_list, axis=-1):
-		return [x.unsqueeze(dim=axis) for x in x_list]
-
 	def forward(
 		self,
 		text_list: list[Tensor],
@@ -121,7 +56,7 @@ class NAR(Base):
 
 		# is training
 		if resps_list is not None:
-			p_len_task = 0.25
+			p_len_task = self.config.experimental.p_len_train if self.config is not None else 0.05
 
 			n_levels_set = {r.shape[-1] for r in resps_list}
 			n_levels = next(iter(n_levels_set))