a birdie tells me i should probably use a different optimizer (also preliminary support for native sparse attention but I don't know if I'll use it)
This commit is contained in:
parent
0451f75e33
commit
1cd24f3381
|
@ -143,6 +143,8 @@ def load_engines(training=True, **model_kwargs):
|
|||
"update_proj_gap": 1,
|
||||
"proj_type": "std",
|
||||
})
|
||||
elif cfg.hyperparameters.optimizer.lower() == "adafactor":
|
||||
optimizer_class = ml.Adafactor
|
||||
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
|
||||
optimizer_class = ml.Adagrad
|
||||
elif cfg.hyperparameters.optimizer.lower() == "muon":
|
||||
|
|
|
@ -876,7 +876,7 @@ def example_usage():
|
|||
available_tasks = ["tts-nar"]
|
||||
|
||||
model = AR_NAR_V2(**kwargs).to(cfg.device)
|
||||
steps = 250 // batch_size
|
||||
steps = 500 // batch_size
|
||||
|
||||
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
|
||||
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
|
||||
|
|
|
@ -1,3 +1,8 @@
|
|||
"""
|
||||
Contains the zoo of alternative attention mechanisms
|
||||
To-do: align it better with nu-modelling_llama.py's attention selection mechanism
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
|
||||
|
@ -36,6 +41,63 @@ try:
|
|||
except Exception as e:
|
||||
_logger.warning(f"Error while querying for `fused_attn` support: {str(e)}")
|
||||
|
||||
# https://github.com/lucidrains/native-sparse-attention-pytorch/
|
||||
try:
|
||||
from native_sparse_attention_pytorch.native_sparse_attention import SparseAttention
|
||||
from native_sparse_attention_pytorch.compress_networks import GroupedMLP
|
||||
|
||||
# patiently waiting for specifying attention masks both for padded sequences and non-causal ones
|
||||
# it might not be necessary since ar+nar-len-llama-8 was able to be "repaired" from the NAR being trained with a causal mask initially
|
||||
class NativeSparseAttention(SparseAttention):
|
||||
def __init__(self, config, layer_idx):
|
||||
dim = config.hidden_size
|
||||
heads = config.num_attention_heads
|
||||
dim_head = getattr(config, "head_dim", dim // heads)
|
||||
kv_heads = config.num_key_value_heads
|
||||
|
||||
# to-do: figure out these settings best for VALL-E
|
||||
compress_block_size = 16
|
||||
sliding_window_size = 64 # really don't want sliding attention due to the nature of the sequence
|
||||
selection_block_size = 16
|
||||
num_selected_blocks = 4
|
||||
num_compressed_mem_kv = 1
|
||||
|
||||
compress_mlp = GroupedMLP(
|
||||
dim_head = dim_head,
|
||||
compress_block_size = compress_block_size,
|
||||
heads = heads,
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
super().__init__(
|
||||
dim = dim,
|
||||
dim_head = dim_head,
|
||||
heads = heads,
|
||||
kv_heads = kv_heads,
|
||||
|
||||
sliding_window_size = sliding_window_size,
|
||||
compress_block_size = compress_block_size,
|
||||
selection_block_size = selection_block_size,
|
||||
num_selected_blocks = num_selected_blocks,
|
||||
num_compressed_mem_kv = num_compressed_mem_kv,
|
||||
|
||||
norm = False, # pre/post norm is done here already
|
||||
use_diff_topk = True,
|
||||
use_triton_kernel = False,
|
||||
interpolated_importance_score = False,
|
||||
query_heads_share_selected_kv = True, # if set to True, importance score is averaged across query heads to select top-n buckets of kv per kv head - but can be set to False for each query head within a group to look at different sets of kv buckets. will be more memory and compute of course
|
||||
compress_mlp = compress_mlp,
|
||||
compress_mlp_expand_factor = 4.,
|
||||
strategy_combine_mlp = None
|
||||
)
|
||||
|
||||
AVAILABLE_ATTENTIONS.append("sparse")
|
||||
except Exception as e:
|
||||
raise e
|
||||
_logger.warning(f"Error while querying for `SparseAttention` support: {str(e)}")
|
||||
pass
|
||||
|
||||
is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count()))
|
||||
is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))
|
||||
|
|
|
@ -451,9 +451,13 @@ class DecoderLayer(nn.Module):
|
|||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = Attention(config=config, layer_idx=layer_idx)
|
||||
if config.attn_mode == "sparse":
|
||||
self.self_attn = NativeSparseAttention(config=config, layer_idx=layer_idx)
|
||||
else:
|
||||
self.self_attn = Attention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = MLP(config)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
@ -481,18 +485,23 @@ class DecoderLayer(nn.Module):
|
|||
is_causal = is_causal[0]
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
is_causal=is_causal,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
if self.config.attn_mode == "sparse":
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states
|
||||
)
|
||||
else:
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
is_causal=is_causal,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
|
|
|
@ -18,6 +18,7 @@ Adam = torch.optim.Adam
|
|||
AdamW = torch.optim.AdamW
|
||||
SGD = torch.optim.SGD
|
||||
Adagrad = torch.optim.Adagrad
|
||||
Adafactor = torch.optim.Adafactor
|
||||
|
||||
OneCycleLR = torch.optim.lr_scheduler.OneCycleLR
|
||||
CosineAnnealingLR = torch.optim.lr_scheduler.CosineAnnealingLR
|
||||
|
|
Loading…
Reference in New Issue
Block a user