mrq 2024-06-05 20:30:43 -05:00
parent 880b4ecd1b
commit ff6fe6f1bc
14 changed files with 300 additions and 368 deletions

View File

@ -1,30 +0,0 @@
"""
# https://github.com/enhuiz/vall-e/
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdaLN(nn.Module):
def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):
super().__init__()
self.eps = eps
self.emb = nn.Embedding(n_levels, d_model * 2)
self.k = k
self.c = c
nn.init.zeros_(self.emb.weight)
def forward(self, x, l):
h = F.layer_norm(x, x.shape[-1:], eps=self.eps)
# The initial implementation (https://github.com/enhuiz/vall-e/blob/fbf023448c08e55c0422eefed7fc234cf8b76680/vall_e/vall_e/base.py#L135)
# performed worse than vanilla LayerNorm.
# The authors mentioned another AdaNorm paper (https://openreview.net/pdf?id=HyxndNrxLB) as they introduce AdaLN.
# Did they use AdaNorm inside AdaLN? (as follows)
h = self.c * (1 - (self.k * h).detach()) * h
logγ, β = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
y = logγ.exp() * h + β
return y
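For reference, the transform the removed forward implemented, written out per quantizer level l (my reading of the code above; whether the VALL-E authors themselves used AdaNorm inside AdaLN is exactly the open question the comment raises):

$$\bar h = \mathrm{LayerNorm}_{\varepsilon}(x), \qquad h = C\,(1 - k\,\bar h)\odot \bar h, \qquad y = e^{\log\gamma_l}\odot h + \beta_l$$

with $C = 2$, $k = 0.1$, the $(1 - k\,\bar h)$ factor detached from the gradient, and $(\log\gamma_l, \beta_l)$ read from the per-level embedding.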

vall_e/models/arch/__init__.py (Executable file, 56 lines)
View File

@ -0,0 +1,56 @@
AVAILABLE_ARCHES = []

try:
	from .transformer import SinusoidalEmbedding, Block as TransformerBlock
	AVAILABLE_ARCHES.append("transformer")
except Exception as e:
	print("Error importing `transformer` arch:", e)
	pass

try:
	from .retnet import RetNetDecoder, RetNetConfig
	AVAILABLE_ARCHES.append("retnet")
except Exception as e:
	print("Error importing `retnet` arch:", e)
	pass

try:
	from .retnet_syncdoth.retnet_ts import RetNetDecoder as RetNetDecoder_TS, RetNetConfig as RetNetConfig_TS
	AVAILABLE_ARCHES.append("retnet-ts")
except Exception as e:
	print("Error importing `retnet-ts` arch:", e)
	pass

try:
	from .retnet_syncdoth.retnet_hf import RetNetDecoder as RetNetDecoder_HF, RetNetConfig as RetNetConfig_HF, RetNetForCausalLM
	AVAILABLE_ARCHES.append("retnet-hf")
except Exception as e:
	print("Error importing `retnet-hf` arch:", e)
	pass

try:
	from .llama import LlamaModel, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Base, LlamaForCausalLM
	AVAILABLE_ARCHES.append("llama")
except Exception as e:
	print("Error importing `llama` arch:", e)
	pass

try:
	from .bitnet import BitNetTransformer
	AVAILABLE_ARCHES.append("bitnet")
except Exception as e:
	print("Error importing `bitnet` arch:", e)
	pass

try:
	from .mixtral import MixtralModel, MixtralConfig
	AVAILABLE_ARCHES.append("mixtral")
except Exception as e:
	print("Error importing `mixtral` arch:", e)

try:
	from .mamba import MambaMixelModel, MambaLMHeadModel
	AVAILABLE_ARCHES.append("mamba")
	AVAILABLE_ARCHES.append("mamba2")
except Exception as e:
	print("Error importing `mamba` arch:", e)

View File

@ -0,0 +1,51 @@
# https://github.com/kyegomez/BitNet
from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint # needed by the checkpoint() call in the forward override below

from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm

# re-enable logging because zetascale fucking sucks
import logging
logging.getLogger().setLevel(logging.DEBUG)

# override for wrapping checkpointing
def BitNetTransformerBlock_forward(self, x: Tensor, *args, **kwargs) -> Tensor:
	skip = x
	for attn, ffn in zip(self.layers, self.ffn_layers):
		if x.requires_grad and self.gradient_checkpointing:
			x, _ = checkpoint(attn, x, x, x, is_causal=True, *args, **kwargs, use_reentrant=False)
		else:
			x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
		x = x + skip
		x = ffn(x) + x
	return x

BitNetTransformerBlock.forward = BitNetTransformerBlock_forward

# override because bitnet's BitNetTransformer includes embedding input / classifier output layers inside of it, which isn't favorable
class BitNetTransformer(nn.Module):
	def __init__(
		self,
		dim: int,
		depth: int,
		num_tokens: int,
		heads=8,
		ff_mult=4,
		gradient_checkpointing = True
	):
		super().__init__()
		self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
		self.norm = BitNetRMSNorm(dim)
		self.transformer.gradient_checkpointing = gradient_checkpointing

	def forward(self, x):
		x = self.transformer(x)
		return self.norm( x )

"""
from bitnet import BitNetTransformer
def NoEmbedding_BitNetTransformer_Forward(self, x):
	x = self.transformer(x)
	return self.to_logits[0](x)

BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
"""

View File

@ -0,0 +1,92 @@
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
import torch
from typing import Literal, overload, Optional, Tuple

from torch import Tensor, nn
from transformers.cache_utils import Cache
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention as LlamaAttention_Base, apply_rotary_pos_emb

AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]

try:
	from xformers.ops import LowerTriangularMask
	from xformers.ops.fmha import memory_efficient_attention
	AVAILABLE_ATTENTIONS.append("xformers")
except Exception as e:
	print("Error while importing `xformers`", e)

try:
	from transformers.utils import is_flash_attn_2_available
	if is_flash_attn_2_available():
		AVAILABLE_ATTENTIONS.append("flash")
except Exception as e:
	print("Error while querying for `flash_attn_2` support", e)

class LlamaAttention(LlamaAttention_Base):
	def __init__(self, *args, **kwargs):
		if 'mode' in kwargs:
			self.mode = kwargs['mode']
			kwargs.pop("mode")
		else:
			self.mode = "math"

		super().__init__(*args, **kwargs)

	def forward(
		self,
		hidden_states: torch.Tensor,
		attention_mask: Optional[torch.Tensor] = None,
		position_ids: Optional[torch.LongTensor] = None,
		past_key_value: Optional[Cache] = None,
		output_attentions: bool = False,
		use_cache: bool = False,
		cache_position: Optional[torch.LongTensor] = None,
		**kwargs,
	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
		bsz, q_len, _ = hidden_states.size()

		query_states = self.q_proj(hidden_states)
		key_states = self.k_proj(hidden_states)
		value_states = self.v_proj(hidden_states)

		query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
		key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
		value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

		cos, sin = self.rotary_emb(value_states, position_ids)
		query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

		past_key_value = getattr(self, "past_key_value", past_key_value)

		if past_key_value is not None:
			# sin and cos are specific to RoPE models; cache_position needed for the static cache
			cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
			key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

		query_states = query_states.transpose(1, 2)
		key_states = key_states.transpose(1, 2)
		value_states = value_states.transpose(1, 2)

		dropout_rate = self.attention_dropout if self.training else 0.0

		if self.mode == "xformers":
			if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None, p=dropout_rate)
			else:
				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask(), p=dropout_rate)
		else:
			#torch.nn.attention.sdpa_kernel
			with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
				attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=dropout_rate)

		attn_weights = None

		attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
		attn_output = self.o_proj(attn_output)

		return attn_output, attn_weights, past_key_value
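A small hedged sketch of how the `mode` plumbing above is meant to be driven: pick the strongest backend that actually imported, then hand it to the attention-swapping helper this commit uses in `Base` (`ml.replace_attention`). The preference order here is my assumption, not something the commit specifies.

def pick_attention_mode(preference=("flash", "xformers", "mem_efficient", "math")):
	# AVAILABLE_ATTENTIONS only contains backends whose imports/queries succeeded above
	for mode in preference:
		if mode in AVAILABLE_ATTENTIONS:
			return mode
	return "math"

# model = ml.replace_attention( model, klass=LlamaAttention, target=LlamaAttention_Base, mode=pick_attention_mode() )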

View File

@ -0,0 +1,30 @@
# https://github.com/state-spaces/mamba
from torch.utils.checkpoint import checkpoint # needed by the checkpoint() call in the forward override below

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm

def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):
	if hidden_states is None:
		hidden_states = self.embedding(input_ids)

	residual = None
	for layer in self.layers:
		if self.gradient_checkpointing and hidden_states.requires_grad:
			hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False )
		else:
			hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params )

	if not self.fused_add_norm:
		residual = (hidden_states + residual) if residual is not None else hidden_states
		hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
	else:
		# Set prenorm=False here since we don't need the residual
		hidden_states = MambaLayerNormFn(
			hidden_states,
			self.norm_f.weight,
			self.norm_f.bias,
			eps=self.norm_f.eps,
			residual=residual,
			prenorm=False,
			residual_in_fp32=self.residual_in_fp32,
			is_rms_norm=isinstance(self.norm_f, MambaRMSNorm)
		)

	return hidden_states

MambaMixelModel.forward = MambaMixelModel_forward
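Note that the patched forward reads `self.gradient_checkpointing`, which the stock `MixerModel` does not appear to define; presumably the caller is expected to set it on the instance. A minimal, assumed sketch:

# assuming `model` is an already-constructed MambaMixelModel / MixerModel instance
model.gradient_checkpointing = True  # read by the patched forward; checkpointing only engages while hidden_states.requires_grad is True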

View File

@ -0,0 +1,45 @@
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
import torch
import torch.nn.functional as F # needed for the softmax in the override below

from transformers import MixtralModel, MixtralConfig
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock

# This is required because batch sizes > 1 throw errors
def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
	""" """
	batch_size, sequence_length, hidden_dim = hidden_states.shape
	hidden_states = hidden_states.reshape(-1, hidden_dim) # was view()
	# router_logits: (batch * sequence_length, n_experts)
	router_logits = self.gate(hidden_states)

	routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
	routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
	routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
	# we cast back to the input dtype
	routing_weights = routing_weights.to(hidden_states.dtype)

	final_hidden_states = torch.zeros(
		(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
	)

	expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

	for expert_idx in range(self.num_experts):
		expert_layer = self.experts[expert_idx]
		idx, top_x = torch.where(expert_mask[expert_idx])

		if top_x.shape[0] == 0:
			continue

		top_x_list = top_x.tolist()
		idx_list = idx.tolist()

		current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
		current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

		final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

	final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
	return final_hidden_states, router_logits

Original_MixtralSparseMoeBlock_forward = MixtralSparseMoeBlock.forward
MixtralSparseMoeBlock.forward = Fixed_MixtralSparseMoeBlock_forward
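For context on the one-call fix above (`view()` replaced with `reshape()`): `view()` refuses to flatten a non-contiguous tensor, which is presumably the class of error hit with batch sizes > 1, while `reshape()` falls back to a copy when it has to. A standalone repro, unrelated to the repo's actual tensors:

import torch

x = torch.randn(2, 8, 16).transpose(0, 1)   # (8, 2, 16): non-contiguous after the transpose
try:
	x.view(-1, 16)                           # raises RuntimeError: view size is not compatible with input tensor's size and stride
except RuntimeError as e:
	print("view() failed:", e)
y = x.reshape(-1, 16)                        # reshape() copies when needed and succeeds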

View File

@ -1,9 +1,9 @@
 # https://github.com/syncdoth/RetNet/
-from ..ext.retnet_hf.configuration_retnet import RetNetConfig
+from ....ext.retnet_hf.configuration_retnet import RetNetConfig
-from ..ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder, RetNetForCausalLM
+from ....ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder, RetNetForCausalLM
 # things we're overriding or required to override
-from ..ext.retnet_hf.modeling_retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, split_heads, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos
+from ....ext.retnet_hf.modeling_retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, split_heads, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos
 import torch
 import math

View File

@ -1,9 +1,9 @@
 # https://github.com/syncdoth/RetNet/
-from ..ext.retnet_ts.config import RetNetConfig
+from ....ext.retnet_ts.config import RetNetConfig
-from ..ext.retnet_ts.retnet import RetNetModel as RetNetDecoder
+from ....ext.retnet_ts.retnet import RetNetModel as RetNetDecoder
 # things we're overriding or required to override
-from ..ext.retnet_ts.retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos
+from ....ext.retnet_ts.retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos
 import torch
 import math

View File

@ -14,7 +14,7 @@ from einops import rearrange
 from torch import Tensor, einsum, nn
 from torch.utils.checkpoint import checkpoint
-from ..utils import wrapper as ml
+from ...utils import wrapper as ml

 class AdaLN(nn.Module):
 	def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):

View File

@ -16,268 +16,10 @@ from torch.nn.utils.rnn import pad_sequence
 from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision

+from .arch import *
 from ..utils import wrapper as ml
 from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample

-try:
-	from .transformer import SinusoidalEmbedding, Block as TransformerBlock
-except Exception as e:
-	print("Error importing `transformer` arch:", e)
-	pass
-try:
-	#from .retnet import RetNetDecoder, RetNetConfig
-	from .retnet_ts import RetNetDecoder, RetNetConfig
-except Exception as e:
-	print("Error importing `retnet` arch:", e)
-	pass
-from .retnet_hf import RetNetDecoder as RetNetDecoder_HF, RetNetConfig as RetNetConfig_HF
-"""
-try:
-except Exception as e:
-	print("Error importing `retnet-hf` arch:", e)
-	pass
-"""
-try:
-	from transformers import LlamaModel, LlamaConfig
-except Exception as e:
-	print("Error importing `llama` arch:", e)
-	pass
-try:
-	from transformers import MistralModel, MistralConfig
-except Exception as e:
-	print("Error importing `mistral` arch:", e)
-	pass
-try:
-	from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
-	# re-enable logging because zetascale fucking sucks
-	import logging
-	logging.getLogger().setLevel(logging.DEBUG)
-	# override for wrapping checkpointing
-	def BitNetTransformerBlock_forward(self, x: Tensor, *args, **kwargs) -> Tensor:
-		skip = x
-		for attn, ffn in zip(self.layers, self.ffn_layers):
-			if x.requires_grad and self.gradient_checkpointing:
-				x, _ = checkpoint(attn, x, x, x, is_causal=True, *args, **kwargs, use_reentrant=False)
-			else:
-				x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
-			x = x + skip
-			x = ffn(x) + x
-		return x
-	BitNetTransformerBlock.forward = BitNetTransformerBlock_forward
-	# override because bitnet's BitNetTransformer includes an embedding input / classifier output layers inside of it, which isn't favorable
-	class BitNetTransformer(nn.Module):
-		def __init__(
-			self,
-			dim: int,
-			depth: int,
-			num_tokens: int,
-			heads=8,
-			ff_mult=4,
-			gradient_checkpointing = True
-		):
-			super().__init__()
-			self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
-			self.norm = BitNetRMSNorm(dim)
-			self.transformer.gradient_checkpointing = gradient_checkpointing
-		def forward(self, x):
-			x = self.transformer(x)
-			return self.norm( x )
-	"""
-	from bitnet import BitNetTransformer
-	def NoEmbedding_BitNetTransformer_Forward(self, x):
-		x = self.transformer(x)
-		return self.to_logits[0](x)
-	BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
-	"""
-except Exception as e:
-	print("Error importing `bitnet` arch:", e)
-	pass
-try:
-	from transformers import MixtralModel, MixtralConfig
-	from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
-	# This is required because batch sizes > 1 throws errors
-	def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-		""" """
-		batch_size, sequence_length, hidden_dim = hidden_states.shape
-		hidden_states = hidden_states.reshape(-1, hidden_dim) # was view()
-		# router_logits: (batch * sequence_length, n_experts)
-		router_logits = self.gate(hidden_states)
-		routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-		routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-		routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-		# we cast back to the input dtype
-		routing_weights = routing_weights.to(hidden_states.dtype)
-		final_hidden_states = torch.zeros(
-			(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
-		)
-		expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-		for expert_idx in range(self.num_experts):
-			expert_layer = self.experts[expert_idx]
-			idx, top_x = torch.where(expert_mask[expert_idx])
-			if top_x.shape[0] == 0:
-				continue
-			top_x_list = top_x.tolist()
-			idx_list = idx.tolist()
-			current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-			current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
-			final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-		final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
-		return final_hidden_states, router_logits
-	Original_MixtralSparseMoeBlock_forward = MixtralSparseMoeBlock.forward
-	MixtralSparseMoeBlock.forward = Fixed_MixtralSparseMoeBlock_forward
-except Exception as e:
-	print("Error importing `mixtral` arch:", e)
-try:
-	from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
-	def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):
-		if hidden_states is None:
-			hidden_states = self.embedding(input_ids)
-		residual = None
-		for layer in self.layers:
-			if self.gradient_checkpointing and hidden_states.requires_grad:
-				hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False )
-			else:
-				hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params )
-		if not self.fused_add_norm:
-			residual = (hidden_states + residual) if residual is not None else hidden_states
-			hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
-		else:
-			# Set prenorm=False here since we don't need the residual
-			hidden_states = MambaLayerNormFn(
-				hidden_states,
-				self.norm_f.weight,
-				self.norm_f.bias,
-				eps=self.norm_f.eps,
-				residual=residual,
-				prenorm=False,
-				residual_in_fp32=self.residual_in_fp32,
-				is_rms_norm=isinstance(self.norm_f, MambaRMSNorm)
-			)
-		return hidden_states
-	MambaMixelModel.forward = MambaMixelModel_forward
-except Exception as e:
-	print("Error importing `mixtral` arch:", e)
-AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
-try:
-	from xformers.ops import LowerTriangularMask
-	from xformers.ops.fmha import memory_efficient_attention
-	AVAILABLE_ATTENTIONS.append("xformers")
-except Exception as e:
-	print("Error while importing `xformers`", e)
-try:
-	from transformers.utils import is_flash_attn_2_available
-	if is_flash_attn_2_available():
-		AVAILABLE_ATTENTIONS.append("flash")
-except Exception as e:
-	raise e
-try:
-	from transformers.cache_utils import Cache
-	from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-	class Llama_Attention(LlamaAttention):
-		def __init__(self, *args, **kwargs):
-			if 'mode' in kwargs:
-				self.mode = kwargs['mode']
-				kwargs.pop("mode")
-			else:
-				self.mode = "math"
-			super().__init__(*args, **kwargs)
-		def forward(
-			self,
-			hidden_states: torch.Tensor,
-			attention_mask: Optional[torch.Tensor] = None,
-			position_ids: Optional[torch.LongTensor] = None,
-			past_key_value: Optional[Cache] = None,
-			output_attentions: bool = False,
-			use_cache: bool = False,
-			cache_position: Optional[torch.LongTensor] = None,
-			**kwargs,
-		) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-			bsz, q_len, _ = hidden_states.size()
-			query_states = self.q_proj(hidden_states)
-			key_states = self.k_proj(hidden_states)
-			value_states = self.v_proj(hidden_states)
-			query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-			key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-			value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-			cos, sin = self.rotary_emb(value_states, position_ids)
-			query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-			past_key_value = getattr(self, "past_key_value", past_key_value)
-			if past_key_value is not None:
-				# sin and cos are specific to RoPE models; cache_position needed for the static cache
-				cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-				key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-			query_states = query_states.transpose(1, 2)
-			key_states = key_states.transpose(1, 2)
-			value_states = value_states.transpose(1, 2)
-			dropout_rate = self.attention_dropout if self.training else 0.0
-			if self.mode == "xformers":
-				if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
-					attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None, p=dropout_rate)
-				else:
-					attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask(), p=dropout_rate)
-			else:
-				#torch.nn.attention.sdpa_kernel
-				with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
-					attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=dropout_rate)
-			attn_weights = None
-			attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-			attn_output = self.o_proj(attn_output)
-			return attn_output, attn_weights, past_key_value
-except Exception as e:
-	print("Error creating modified `LLamaAttention`:", e)

 def _create_mask(l, device):
 	"""1 is valid region and 0 is invalid."""
 	seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@ -751,7 +493,7 @@ class Base(nn.Module):
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')

 		if self.hyper_config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
-			self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.hyper_config.attention )
+			self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.hyper_config.attention )

 		self.classifier = nn.Linear(d_model, n_resp_tokens)

View File

@ -24,73 +24,19 @@ import math
 from einops import rearrange
 from tqdm import trange

-AVAILABLE_ARCHES = []
-try:
-	from transformers import LlamaForCausalLM, LlamaConfig
-	AVAILABLE_ARCHES.append("llama")
-except Exception as e:
-	print("Error importing `llama` arch:", e)
-	pass
-try:
-	from .retnet_hf import RetNetConfig
-	from ..ext.retnet_hf.modeling_retnet import RetNetForCausalLM
-	AVAILABLE_ARCHES.append("retnet")
-except Exception as e:
-	print("Error importing `retnet` arch:", e)
-	pass
-try:
-	from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
-	def MambaMixelModel_forward(self, input_ids, inference_params=None, **mixer_kwargs):
-		hidden_states = self.embedding(input_ids)
-		residual = None
-		for layer in self.layers:
-			if self.gradient_checkpointing and hidden_states.requires_grad:
-				hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False )
-			else:
-				hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params )
-		if not self.fused_add_norm:
-			residual = (hidden_states + residual) if residual is not None else hidden_states
-			hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
-		else:
-			# Set prenorm=False here since we don't need the residual
-			hidden_states = MambaLayerNormFn(
-				hidden_states,
-				self.norm_f.weight,
-				self.norm_f.bias,
-				eps=self.norm_f.eps,
-				residual=residual,
-				prenorm=False,
-				residual_in_fp32=self.residual_in_fp32,
-				is_rms_norm=isinstance(self.norm_f, MambaRMSNorm)
-			)
-		return hidden_states
-	MambaMixelModel.forward = MambaMixelModel_forward
-	AVAILABLE_ARCHES.append("mamba")
-	AVAILABLE_ARCHES.append("mamba2")
-except Exception as e:
-	print("Error importing `mamba` arch:", e)
-	pass
-
-SELECTED_ARCH = cfg.model.arch_type
-if SELECTED_ARCH not in AVAILABLE_ARCHES:
-	raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
-
-if SELECTED_ARCH in ["mamba","mamba2"]:
+from .arch import *
+
+if cfg.model.arch_type not in AVAILABLE_ARCHES:
+	raise ValueError(f"Requesting arch `{cfg.model.arch_type}` but not available")
+
+if cfg.model.arch_type in ["mamba","mamba2"]:
 	LlmArchClass = MambaLMHeadModel
-elif SELECTED_ARCH == "llama":
+elif cfg.model.arch_type == "llama":
 	LlmArchClass = LlamaForCausalLM
-elif SELECTED_ARCH == "retnet":
+elif cfg.model.arch_type == "retnet":
 	LlmArchClass = RetNetForCausalLM
 else:
-	raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
+	raise ValueError(f"Requesting arch `{cfg.model.arch_type}` but not available")

 class Model(LlmArchClass):
 	def __init__(
@ -113,7 +59,7 @@ class Model(LlmArchClass):
 		# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
 		vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1

-		if SELECTED_ARCH == "llama":
+		if cfg.model.arch_type == "llama":
 			super().__init__(config=LlamaConfig(
 				vocab_size=vocab_size,
 				hidden_size=d_model,
@ -134,7 +80,7 @@ class Model(LlmArchClass):
 			self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
 				use_reentrant=False
 			))
-		elif SELECTED_ARCH == "retnet":
+		elif cfg.model.arch_type == "retnet":
 			super().__init__(config=RetNetConfig(
 				vocab_size=vocab_size,
 				decoder_embed_dim=d_model,
@ -156,12 +102,12 @@ class Model(LlmArchClass):
 				decoder_normalize_before=True,
 			))
-		elif SELECTED_ARCH in ["mamba","mamba2"]:
+		elif cfg.model.arch_type in ["mamba","mamba2"]:
 			super().__init__(config=MambaConfig(
 				vocab_size=vocab_size,
 				d_model=d_model,
 				n_layer=n_layers*2,
-				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if SELECTED_ARCH == "mamba2" else {},
+				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if cfg.model.arch_type == "mamba2" else {},
 				fused_add_norm=True,
 				residual_in_fp32=True,
 			))
@ -181,7 +127,7 @@ class Model(LlmArchClass):
 		*args,
 		**kwargs
 	):
-		if SELECTED_ARCH in ["mamba","mamba2"]:
+		if cfg.model.arch_type in ["mamba","mamba2"]:
 			kwargs["cg"] = True

 			if "attention_mask" in kwargs:
@ -200,7 +146,7 @@ class Model(LlmArchClass):
 		*args,
 		**kwargs,
 	):
-		if SELECTED_ARCH in ["mamba","mamba2"]:
+		if cfg.model.arch_type in ["mamba","mamba2"]:
 			if "attention_mask" in kwargs:
 				kwargs.pop("attention_mask")
@ -371,7 +317,7 @@ def example_usage():
 	torch.save( {
 		'module': model.state_dict()
-	}, f"./data/{SELECTED_ARCH}.pth" )
+	}, f"./data/{cfg.model.arch_type}.pth" )

 	print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@ -427,7 +373,7 @@ def example_usage():
 			resp_list[i] = torch.stack( resp ).t()

 		for i, batch in enumerate(resp_list):
-			_ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
+			_ = decode_to_file(batch.to(device=device), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)

 	unload_model()
@ -456,7 +402,7 @@ def example_usage():
 	torch.save( {
 		'module': model.state_dict()
-	}, f"./data/{SELECTED_ARCH}.pth" )
+	}, f"./data/{cfg.model.arch_type}.pth" )

 	#sample("init", 5)
 	train()