a birdie tells me I should probably use a different optimizer (also: preliminary support for native sparse attention, but I don't know if I'll use it)

mrq 2025-03-04 14:53:02 -06:00
parent 0451f75e33
commit 1cd24f3381
5 changed files with 88 additions and 14 deletions

View File

@ -143,6 +143,8 @@ def load_engines(training=True, **model_kwargs):
"update_proj_gap": 1,
"proj_type": "std",
})
elif cfg.hyperparameters.optimizer.lower() == "adafactor":
optimizer_class = ml.Adafactor
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
optimizer_class = ml.Adagrad
elif cfg.hyperparameters.optimizer.lower() == "muon":
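The new branch just swaps in another optimizer class; a rough sketch of what it amounts to downstream (the construction call, the placeholder model, and the learning rate are assumptions, since that part of load_engines is not shown in this hunk):

	# "adafactor" in hyperparameters.optimizer now resolves to ml.Adafactor,
	# which aliases torch.optim.Adafactor (added in the last file of this commit)
	optimizer_class = ml.Adafactor
	# assumed construction; the param group and learning rate are placeholders
	optimizer = optimizer_class(model.parameters(), lr=1.0e-2)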

View File

@ -876,7 +876,7 @@ def example_usage():
 	available_tasks = ["tts-nar"]
 	model = AR_NAR_V2(**kwargs).to(cfg.device)
-	steps = 250 // batch_size
+	steps = 500 // batch_size
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

View File

@ -1,3 +1,8 @@
"""
Contains the zoo of alternative attention mechanisms
To-do: align it better with nu-modelling_llama.py's attention selection mechanism
"""
import logging
import torch
@ -36,6 +41,63 @@ try:
 except Exception as e:
 	_logger.warning(f"Error while querying for `fused_attn` support: {str(e)}")
+# https://github.com/lucidrains/native-sparse-attention-pytorch/
+try:
+	from native_sparse_attention_pytorch.native_sparse_attention import SparseAttention
+	from native_sparse_attention_pytorch.compress_networks import GroupedMLP
+	# patiently waiting for a way to specify attention masks, both for padded sequences and for non-causal attention
+	# it might not be necessary, since ar+nar-len-llama-8 was able to be "repaired" after the NAR was initially trained with a causal mask
+	class NativeSparseAttention(SparseAttention):
+		def __init__(self, config, layer_idx):
+			dim = config.hidden_size
+			heads = config.num_attention_heads
+			dim_head = getattr(config, "head_dim", dim // heads)
+			kv_heads = config.num_key_value_heads
+			# to-do: figure out which settings work best for VALL-E
+			compress_block_size = 16
+			sliding_window_size = 64 # really don't want sliding attention due to the nature of the sequence
+			selection_block_size = 16
+			num_selected_blocks = 4
+			num_compressed_mem_kv = 1
+			compress_mlp = GroupedMLP(
+				dim_head = dim_head,
+				compress_block_size = compress_block_size,
+				heads = heads,
+			)
+			self.config = config
+			self.layer_idx = layer_idx
+			super().__init__(
+				dim = dim,
+				dim_head = dim_head,
+				heads = heads,
+				kv_heads = kv_heads,
+				sliding_window_size = sliding_window_size,
+				compress_block_size = compress_block_size,
+				selection_block_size = selection_block_size,
+				num_selected_blocks = num_selected_blocks,
+				num_compressed_mem_kv = num_compressed_mem_kv,
+				norm = False, # pre/post norm is already done by the decoder layer
+				use_diff_topk = True,
+				use_triton_kernel = False,
+				interpolated_importance_score = False,
+				query_heads_share_selected_kv = True, # if True, the importance score is averaged across query heads to select the top-n kv buckets per kv head; set to False to let each query head within a group look at different sets of kv buckets, at the cost of more memory and compute
+				compress_mlp = compress_mlp,
+				compress_mlp_expand_factor = 4.,
+				strategy_combine_mlp = None
+			)
+	AVAILABLE_ATTENTIONS.append("sparse")
+except Exception as e:
+	_logger.warning(f"Error while querying for `SparseAttention` support: {str(e)}")
 
 is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count()))
 is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))
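If the import succeeds, the wrapper can be smoke-tested in isolation. A minimal sketch (illustrative only, not part of this commit: the SimpleNamespace stands in for the model's LlamaConfig-style config and only carries the attributes read in __init__, and the shapes are arbitrary but divisible by the block sizes):

	from types import SimpleNamespace
	import torch

	config = SimpleNamespace(
		hidden_size = 1024,
		num_attention_heads = 16,
		num_key_value_heads = 4,
	)
	attn = NativeSparseAttention(config, layer_idx=0)
	x = torch.randn(1, 256, config.hidden_size) # (batch, seq_len, dim)
	y = attn(x) # called with hidden states only, matching the decoder layer below
	assert y.shape == x.shape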

View File

@ -451,9 +451,13 @@ class DecoderLayer(nn.Module):
 	def __init__(self, config, layer_idx):
 		super().__init__()
+		self.config = config
 		self.hidden_size = config.hidden_size
-		self.self_attn = Attention(config=config, layer_idx=layer_idx)
+		if config.attn_mode == "sparse":
+			self.self_attn = NativeSparseAttention(config=config, layer_idx=layer_idx)
+		else:
+			self.self_attn = Attention(config=config, layer_idx=layer_idx)
 		self.mlp = MLP(config)
 		self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -481,18 +485,23 @@ class DecoderLayer(nn.Module):
 			is_causal = is_causal[0]
 		# Self Attention
-		hidden_states, self_attn_weights, present_key_value = self.self_attn(
-			hidden_states=hidden_states,
-			attention_mask=attention_mask,
-			is_causal=is_causal,
-			position_ids=position_ids,
-			past_key_value=past_key_value,
-			output_attentions=output_attentions,
-			use_cache=use_cache,
-			cache_position=cache_position,
-			position_embeddings=position_embeddings,
-			**kwargs,
-		)
+		if self.config.attn_mode == "sparse":
+			hidden_states = self.self_attn(
+				hidden_states
+			)
+		else:
+			hidden_states, self_attn_weights, present_key_value = self.self_attn(
+				hidden_states=hidden_states,
+				attention_mask=attention_mask,
+				is_causal=is_causal,
+				position_ids=position_ids,
+				past_key_value=past_key_value,
+				output_attentions=output_attentions,
+				use_cache=use_cache,
+				cache_position=cache_position,
+				position_embeddings=position_embeddings,
+				**kwargs,
+			)
 		hidden_states = residual + hidden_states
 		# Fully Connected
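For reference, the sparse path only kicks in when the config asks for it and the optional dependency imported cleanly; note that it is called with the hidden states alone, so attention masks and the kv cache are ignored on that path for now. A small sketch of the opt-in (the "sdpa" fallback string is an assumption, not something set by this commit):

	# assumed opt-in: only request sparse attention if the import above registered it
	config.attn_mode = "sparse" if "sparse" in AVAILABLE_ATTENTIONS else "sdpa"
	layer = DecoderLayer(config, layer_idx=0) # __init__ then picks NativeSparseAttention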

View File

@ -18,6 +18,7 @@ Adam = torch.optim.Adam
 AdamW = torch.optim.AdamW
 SGD = torch.optim.SGD
 Adagrad = torch.optim.Adagrad
+Adafactor = torch.optim.Adafactor
 OneCycleLR = torch.optim.lr_scheduler.OneCycleLR
 CosineAnnealingLR = torch.optim.lr_scheduler.CosineAnnealingLR
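One caveat with the new alias: torch.optim.Adafactor only exists in recent PyTorch releases, so the plain assignment assumes an up-to-date install. A defensive sketch (the AdamW fallback is an assumption, not part of this commit):

	Adafactor = getattr(torch.optim, "Adafactor", None)
	if Adafactor is None:
		# assumption: fall back to AdamW instead of failing at import time
		Adafactor = torch.optim.AdamW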