From 1cd24f33810b2d679accbce2d8cc355a3e6fe874 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 4 Mar 2025 14:53:02 -0600
Subject: [PATCH] a birdie tells me i should probably use a different optimizer (also preliminary support for native sparse attention but I don't know if I'll use it)

---
 vall_e/engines/__init__.py               |  2 +
 vall_e/models/ar_nar_v2.py               |  2 +-
 vall_e/models/arch/attention/__init__.py | 61 ++++++++++++++++++++++++
 vall_e/models/arch/llama.py              | 35 ++++++++-----
 vall_e/utils/ml.py                       |  1 +
 5 files changed, 87 insertions(+), 14 deletions(-)

diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 150b717..5ec6569 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -143,6 +143,8 @@ def load_engines(training=True, **model_kwargs):
 				"update_proj_gap": 1,
 				"proj_type": "std",
 			})
+		elif cfg.hyperparameters.optimizer.lower() == "adafactor":
+			optimizer_class = ml.Adafactor
 		elif cfg.hyperparameters.optimizer.lower() == "adagrad":
 			optimizer_class = ml.Adagrad
 		elif cfg.hyperparameters.optimizer.lower() == "muon":
diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index ab6b44d..7b8915b 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -876,7 +876,7 @@ def example_usage():
 	available_tasks = ["tts-nar"]
 
 	model = AR_NAR_V2(**kwargs).to(cfg.device)
-	steps = 250 // batch_size
+	steps = 500 // batch_size
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
diff --git a/vall_e/models/arch/attention/__init__.py b/vall_e/models/arch/attention/__init__.py
index b4e0ca3..161fd7c 100644
--- a/vall_e/models/arch/attention/__init__.py
+++ b/vall_e/models/arch/attention/__init__.py
@@ -1,3 +1,8 @@
+"""
+Contains the zoo of alternative attention mechanisms
+To-do: align it better with nu-modelling_llama.py's attention selection mechanism
+"""
+
 import logging
 
 import torch
@@ -36,6 +41,62 @@ try:
 except Exception as e:
 	_logger.warning(f"Error while querying for `fused_attn` support: {str(e)}")
 
+# https://github.com/lucidrains/native-sparse-attention-pytorch/
+try:
+	from native_sparse_attention_pytorch.native_sparse_attention import SparseAttention
+	from native_sparse_attention_pytorch.compress_networks import GroupedMLP
+
+	# patiently waiting for specifying attention masks both for padded sequences and non-causal ones
+	# it might not be necessary since ar+nar-len-llama-8 was able to be "repaired" from the NAR being trained with a causal mask initially
+	class NativeSparseAttention(SparseAttention):
+		def __init__(self, config, layer_idx):
+			dim = config.hidden_size
+			heads = config.num_attention_heads
+			dim_head = getattr(config, "head_dim", dim // heads)
+			kv_heads = config.num_key_value_heads
+
+			# to-do: figure out these settings best for VALL-E
+			compress_block_size = 16
+			sliding_window_size = 64 # really don't want sliding attention due to the nature of the sequence
+			selection_block_size = 16
+			num_selected_blocks = 4
+			num_compressed_mem_kv = 1
+
+			compress_mlp = GroupedMLP(
+				dim_head = dim_head,
+				compress_block_size = compress_block_size,
+				heads = heads,
+			)
+
+			self.config = config
+			self.layer_idx = layer_idx
+
+			super().__init__(
+				dim = dim,
+				dim_head = dim_head,
+				heads = heads,
+				kv_heads = kv_heads,
+
+				sliding_window_size = sliding_window_size,
+				compress_block_size = compress_block_size,
+				selection_block_size = selection_block_size,
+				num_selected_blocks = num_selected_blocks,
+				num_compressed_mem_kv = num_compressed_mem_kv,
+
+				norm = False, # pre/post norm is done here already
+				use_diff_topk = True,
+				use_triton_kernel = False,
+				interpolated_importance_score = False,
+				query_heads_share_selected_kv = True, # if set to True, importance score is averaged across query heads to select top-n buckets of kv per kv head - but can be set to False for each query head within a group to look at different sets of kv buckets. will be more memory and compute of course
+				compress_mlp = compress_mlp,
+				compress_mlp_expand_factor = 4.,
+				strategy_combine_mlp = None
+			)
+
+	AVAILABLE_ATTENTIONS.append("sparse")
+except Exception as e:
+	_logger.warning(f"Error while querying for `SparseAttention` support: {str(e)}")
+	pass
 is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count()))
 is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))
 
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index a8f577d..863cbec 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -451,9 +451,13 @@ class DecoderLayer(nn.Module):
 
 	def __init__(self, config, layer_idx):
 		super().__init__()
+		self.config = config
 		self.hidden_size = config.hidden_size
 
-		self.self_attn = Attention(config=config, layer_idx=layer_idx)
+		if config.attn_mode == "sparse":
+			self.self_attn = NativeSparseAttention(config=config, layer_idx=layer_idx)
+		else:
+			self.self_attn = Attention(config=config, layer_idx=layer_idx)
 
 		self.mlp = MLP(config)
 		self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -481,18 +485,23 @@ class DecoderLayer(nn.Module):
 			is_causal = is_causal[0]
 
 		# Self Attention
-		hidden_states, self_attn_weights, present_key_value = self.self_attn(
-			hidden_states=hidden_states,
-			attention_mask=attention_mask,
-			is_causal=is_causal,
-			position_ids=position_ids,
-			past_key_value=past_key_value,
-			output_attentions=output_attentions,
-			use_cache=use_cache,
-			cache_position=cache_position,
-			position_embeddings=position_embeddings,
-			**kwargs,
-		)
+		if self.config.attn_mode == "sparse":
+			hidden_states = self.self_attn(
+				hidden_states
+			)
+		else:
+			hidden_states, self_attn_weights, present_key_value = self.self_attn(
+				hidden_states=hidden_states,
+				attention_mask=attention_mask,
+				is_causal=is_causal,
+				position_ids=position_ids,
+				past_key_value=past_key_value,
+				output_attentions=output_attentions,
+				use_cache=use_cache,
+				cache_position=cache_position,
+				position_embeddings=position_embeddings,
+				**kwargs,
+			)
 		hidden_states = residual + hidden_states
 
 		# Fully Connected
diff --git a/vall_e/utils/ml.py b/vall_e/utils/ml.py
index 39a430f..d17aa9a 100755
--- a/vall_e/utils/ml.py
+++ b/vall_e/utils/ml.py
@@ -18,6 +18,7 @@ Adam = torch.optim.Adam
 AdamW = torch.optim.AdamW
 SGD = torch.optim.SGD
 Adagrad = torch.optim.Adagrad
+Adafactor = torch.optim.Adafactor
 
 OneCycleLR = torch.optim.lr_scheduler.OneCycleLR
 CosineAnnealingLR = torch.optim.lr_scheduler.CosineAnnealingLR
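
Usage notes (illustrative, not from the commit itself): the new "adafactor" branch
maps to torch.optim.Adafactor, which first shipped in PyTorch 2.5, so it assumes a
sufficiently recent install. The sparse branch in DecoderLayer.forward passes a bare
hidden-state tensor and assigns the bare tensor it gets back, so attention weights
and KV caching are not produced when config.attn_mode == "sparse". A minimal
standalone smoke test for the wrapper might look like the sketch below; it assumes
native-sparse-attention-pytorch is installed and that a LlamaConfig-like object only
has to supply the four attributes read in __init__ (the config values and shapes
here are made up for illustration).

# illustrative smoke test for NativeSparseAttention; values are placeholders
import torch
from types import SimpleNamespace

from vall_e.models.arch.attention import NativeSparseAttention

config = SimpleNamespace(
	hidden_size = 1024,
	num_attention_heads = 16,
	head_dim = 64,
	num_key_value_heads = 4,
)
attn = NativeSparseAttention(config, layer_idx=0)

# same calling convention as the sparse branch in DecoderLayer.forward:
# hidden states in, hidden states out, no attention weights and no KV cache
x = torch.randn(1, 256, config.hidden_size)	# (batch, seq_len, hidden_size)
y = attn(x)
assert y.shape == x.shape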