a birdie tells me i should probably use a different optimizer (also preliminary support for native sparse attention but I don't know if I'll use it)
This commit is contained in:
parent 0451f75e33
commit 1cd24f3381
@@ -143,6 +143,8 @@ def load_engines(training=True, **model_kwargs):
             "update_proj_gap": 1,
             "proj_type": "std",
         })
+    elif cfg.hyperparameters.optimizer.lower() == "adafactor":
+        optimizer_class = ml.Adafactor
     elif cfg.hyperparameters.optimizer.lower() == "adagrad":
         optimizer_class = ml.Adagrad
     elif cfg.hyperparameters.optimizer.lower() == "muon":
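For reference, the new "adafactor" branch only swaps in `optimizer_class = ml.Adafactor`, which the final hunk of this commit aliases to `torch.optim.Adafactor`. Below is a minimal sketch of what that dispatch ends up constructing, assuming a recent PyTorch build that ships the native Adafactor implementation; the stand-in `model` and the learning rate are illustrative, not the values `cfg.hyperparameters` would actually supply.

import torch
from torch import nn

# hypothetical stand-in for whatever module load_engines() is building an optimizer for
model = nn.Linear(1024, 1024)

# roughly what the "adafactor" branch resolves to; ml.Adafactor is just torch.optim.Adafactor
optimizer = torch.optim.Adafactor(model.parameters(), lr=1e-2)

loss = model(torch.randn(4, 1024)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()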
@@ -876,7 +876,7 @@ def example_usage():
     available_tasks = ["tts-nar"]
 
     model = AR_NAR_V2(**kwargs).to(cfg.device)
-    steps = 250 // batch_size
+    steps = 500 // batch_size
 
     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -1,3 +1,8 @@
+"""
+Contains the zoo of alternative attention mechanisms
+To-do: align it better with nu-modelling_llama.py's attention selection mechanism
+"""
+
 import logging
 import torch
 
@@ -36,6 +41,63 @@ try:
 except Exception as e:
     _logger.warning(f"Error while querying for `fused_attn` support: {str(e)}")
 
+# https://github.com/lucidrains/native-sparse-attention-pytorch/
+try:
+    from native_sparse_attention_pytorch.native_sparse_attention import SparseAttention
+    from native_sparse_attention_pytorch.compress_networks import GroupedMLP
+
+    # patiently waiting for specifying attention masks both for padded sequences and non-causal ones
+    # it might not be necessary since ar+nar-len-llama-8 was able to be "repaired" from the NAR being trained with a causal mask initially
+    class NativeSparseAttention(SparseAttention):
+        def __init__(self, config, layer_idx):
+            dim = config.hidden_size
+            heads = config.num_attention_heads
+            dim_head = getattr(config, "head_dim", dim // heads)
+            kv_heads = config.num_key_value_heads
+
+            # to-do: figure out these settings best for VALL-E
+            compress_block_size = 16
+            sliding_window_size = 64 # really don't want sliding attention due to the nature of the sequence
+            selection_block_size = 16
+            num_selected_blocks = 4
+            num_compressed_mem_kv = 1
+
+            compress_mlp = GroupedMLP(
+                dim_head = dim_head,
+                compress_block_size = compress_block_size,
+                heads = heads,
+            )
+
+            self.config = config
+            self.layer_idx = layer_idx
+
+            super().__init__(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                kv_heads = kv_heads,
+
+                sliding_window_size = sliding_window_size,
+                compress_block_size = compress_block_size,
+                selection_block_size = selection_block_size,
+                num_selected_blocks = num_selected_blocks,
+                num_compressed_mem_kv = num_compressed_mem_kv,
+
+                norm = False, # pre/post norm is done here already
+                use_diff_topk = True,
+                use_triton_kernel = False,
+                interpolated_importance_score = False,
+                query_heads_share_selected_kv = True, # if set to True, importance score is averaged across query heads to select top-n buckets of kv per kv head - but can be set to False for each query head within a group to look at different sets of kv buckets. will be more memory and compute of course
+                compress_mlp = compress_mlp,
+                compress_mlp_expand_factor = 4.,
+                strategy_combine_mlp = None
+            )
+
+    AVAILABLE_ATTENTIONS.append("sparse")
+except Exception as e:
+    raise e
+    _logger.warning(f"Error while querying for `SparseAttention` support: {str(e)}")
+    pass
 
 is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count()))
 is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))
 
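For a standalone sanity check of the wrapper above, the snippet below constructs the same `SparseAttention` from lucidrains' native-sparse-attention-pytorch with the exact hyperparameters this commit hard-codes. The model dimensions, the `(batch, seq_len, dim)` input layout, and the bare single-tensor call are assumptions mirroring how `DecoderLayer` invokes it further down, not part of the commit itself.

import torch
from native_sparse_attention_pytorch.native_sparse_attention import SparseAttention
from native_sparse_attention_pytorch.compress_networks import GroupedMLP

# assumed model dimensions, for illustration only
dim, heads, kv_heads = 1024, 16, 4
dim_head = dim // heads

# same compression network NativeSparseAttention builds above
compress_mlp = GroupedMLP(
    dim_head = dim_head,
    compress_block_size = 16,
    heads = heads,
)

attn = SparseAttention(
    dim = dim,
    dim_head = dim_head,
    heads = heads,
    kv_heads = kv_heads,
    sliding_window_size = 64,
    compress_block_size = 16,
    selection_block_size = 16,
    num_selected_blocks = 4,
    num_compressed_mem_kv = 1,
    norm = False,
    use_diff_topk = True,
    use_triton_kernel = False,
    interpolated_importance_score = False,
    query_heads_share_selected_kv = True,
    compress_mlp = compress_mlp,
    compress_mlp_expand_factor = 4.,
    strategy_combine_mlp = None,
)

x = torch.randn(1, 256, dim)  # (batch, seq_len, dim) layout is assumed
out = attn(x)                 # called with a bare tensor, the same way DecoderLayer does below
print(out.shape)              # expected to match x.shape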
@@ -451,9 +451,13 @@ class DecoderLayer(nn.Module):
     def __init__(self, config, layer_idx):
         super().__init__()
 
+        self.config = config
         self.hidden_size = config.hidden_size
 
-        self.self_attn = Attention(config=config, layer_idx=layer_idx)
+        if config.attn_mode == "sparse":
+            self.self_attn = NativeSparseAttention(config=config, layer_idx=layer_idx)
+        else:
+            self.self_attn = Attention(config=config, layer_idx=layer_idx)
 
         self.mlp = MLP(config)
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -481,18 +485,23 @@ class DecoderLayer(nn.Module):
         is_causal = is_causal[0]
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            is_causal=is_causal,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            cache_position=cache_position,
-            position_embeddings=position_embeddings,
-            **kwargs,
-        )
+        if self.config.attn_mode == "sparse":
+            hidden_states = self.self_attn(
+                hidden_states
+            )
+        else:
+            hidden_states, self_attn_weights, present_key_value = self.self_attn(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                is_causal=is_causal,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
         hidden_states = residual + hidden_states
 
         # Fully Connected
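One consequence of the branch above: the sparse path is called with only the hidden states, so the attention mask, KV cache, rotary position embeddings, and attention-weight outputs are all bypassed, and only the residual add is shared between the two paths. A toy sketch of that dispatch, using a hypothetical `SimpleNamespace` in place of the model's real config object and `nn.Identity` in place of the actual attention modules:

from types import SimpleNamespace

import torch
from torch import nn

# hypothetical minimal config; the real model passes its HF-style config here
config = SimpleNamespace(attn_mode="sparse", hidden_size=1024)

class ToyDecoderLayer(nn.Module):
    def __init__(self, config, sparse_attn, dense_attn):
        super().__init__()
        self.config = config
        # mirrors the __init__ branch added in this commit
        self.self_attn = sparse_attn if config.attn_mode == "sparse" else dense_attn

    def forward(self, hidden_states, **attn_kwargs):
        residual = hidden_states
        if self.config.attn_mode == "sparse":
            # sparse path: no mask, cache, or position embeddings are forwarded
            hidden_states = self.self_attn(hidden_states)
        else:
            # dense path: the usual HF-style tuple comes back
            hidden_states, _attn_weights, _present_kv = self.self_attn(
                hidden_states=hidden_states, **attn_kwargs,
            )
        return residual + hidden_states

layer = ToyDecoderLayer(config, sparse_attn=nn.Identity(), dense_attn=None)
out = layer(torch.randn(1, 8, config.hidden_size))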
@@ -18,6 +18,7 @@ Adam = torch.optim.Adam
 AdamW = torch.optim.AdamW
 SGD = torch.optim.SGD
 Adagrad = torch.optim.Adagrad
+Adafactor = torch.optim.Adafactor
 
 OneCycleLR = torch.optim.lr_scheduler.OneCycleLR
 CosineAnnealingLR = torch.optim.lr_scheduler.CosineAnnealingLR