From 2cef97e43f42eaf58f95943d89e8bc06cbd375ba Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 21 Nov 2024 23:08:43 -0600 Subject: [PATCH] cleanup --- vall_e/ext/__init__.py | 4 - vall_e/ext/retnet_hf/__init__.py | 3 - vall_e/ext/retnet_hf/configuration_retnet.py | 117 -- vall_e/ext/retnet_hf/modeling_retnet.py | 1464 ----------------- vall_e/ext/retnet_ts/__init__.py | 0 vall_e/ext/retnet_ts/config.py | 74 - vall_e/ext/retnet_ts/retnet.py | 746 --------- vall_e/models/arch/__init__.py | 14 +- vall_e/models/arch/mamba.py | 11 +- vall_e/models/arch/mamba_vasqu/__init__.py | 1 - vall_e/models/arch/mamba_vasqu/mamba2_hf.py | 4 - .../models/arch/retnet_syncdoth/__init__.py | 0 .../models/arch/retnet_syncdoth/retnet_hf.py | 196 --- .../models/arch/retnet_syncdoth/retnet_ts.py | 277 ---- vall_e/models/base.py | 138 +- vall_e/webui.py | 3 +- 16 files changed, 55 insertions(+), 2997 deletions(-) delete mode 100644 vall_e/ext/__init__.py delete mode 100644 vall_e/ext/retnet_hf/__init__.py delete mode 100644 vall_e/ext/retnet_hf/configuration_retnet.py delete mode 100644 vall_e/ext/retnet_hf/modeling_retnet.py delete mode 100644 vall_e/ext/retnet_ts/__init__.py delete mode 100644 vall_e/ext/retnet_ts/config.py delete mode 100644 vall_e/ext/retnet_ts/retnet.py delete mode 100644 vall_e/models/arch/mamba_vasqu/__init__.py delete mode 100644 vall_e/models/arch/mamba_vasqu/mamba2_hf.py delete mode 100755 vall_e/models/arch/retnet_syncdoth/__init__.py delete mode 100644 vall_e/models/arch/retnet_syncdoth/retnet_hf.py delete mode 100644 vall_e/models/arch/retnet_syncdoth/retnet_ts.py diff --git a/vall_e/ext/__init__.py b/vall_e/ext/__init__.py deleted file mode 100644 index 03ca57d..0000000 --- a/vall_e/ext/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# from https://github.com/syncdoth/RetNet/ - -# there is no proper build system and I can't be assed to fork it or make it a submodule that plays nicely with python's import system -# this is included because torchscale's implementation recently changed and I don't want to keep maintaining a fork \ No newline at end of file diff --git a/vall_e/ext/retnet_hf/__init__.py b/vall_e/ext/retnet_hf/__init__.py deleted file mode 100644 index fcac9b7..0000000 --- a/vall_e/ext/retnet_hf/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# from https://github.com/syncdoth/RetNet/ - -# there is no proper build system and I can't be assed to fork it or make it a submodule that plays nicely with python's import system \ No newline at end of file diff --git a/vall_e/ext/retnet_hf/configuration_retnet.py b/vall_e/ext/retnet_hf/configuration_retnet.py deleted file mode 100644 index d409c39..0000000 --- a/vall_e/ext/retnet_hf/configuration_retnet.py +++ /dev/null @@ -1,117 +0,0 @@ -from dataclasses import dataclass -import json - -from transformers.configuration_utils import PretrainedConfig - - -def load_config_from_json(config_file): - with open(config_file, 'r') as f: - config = json.loads(f.read()) - config = RetNetConfig.from_dict(config) - return config - - -@dataclass -class RetNetConfig(PretrainedConfig): - model_type = "retnet" - initializer_range: float = 0.02 - activation_fn: str = "gelu" - dropout: float = 0.0 # dropout probability - activation_dropout: float = 0.0 # dropout probability after activation in FFN. 
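# [sketch, not part of the patch] The removed config classes (this HF RetNetConfig and
# the torchscale variant deleted further below) resolve their mutually exclusive
# normalization flags the same way: deepnorm takes precedence over subln, and each one
# forces decoder_normalize_before. A minimal standalone reproduction of that resolution
# logic, plain Python only:
def resolve_norm_flags(deepnorm: bool, subln: bool, normalize_before: bool):
    if deepnorm:
        normalize_before = False
        subln = False
    if subln:
        normalize_before = True
        deepnorm = False
    return deepnorm, subln, normalize_before

# deepnorm wins: sublayer norm is switched off and pre-normalization is disabled
assert resolve_norm_flags(True, True, True) == (True, False, False)
# subln alone forces pre-normalization on
assert resolve_norm_flags(False, True, False) == (False, True, True)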
- drop_path_rate: float = 0.0 - decoder_embed_dim: int = 768 # decoder embedding dimension - decoder_value_embed_dim: int = 1280 # decoder value embedding dimension - decoder_ffn_embed_dim: int = 1280 # decoder embedding dimension for FFN - decoder_layers: int = 12 # num decoder layers - decoder_retention_heads: int = 3 # num decoder retention heads - decoder_normalize_before: bool = True # apply layernorm before each decoder block - layernorm_embedding: bool = False # add layernorm to embedding - no_scale_embedding: bool = True # if True, dont scale embeddings - recurrent_chunk_size: int = 512 - use_lm_decay: bool = False - use_glu: bool = True # use GLU instead of FFN - z_loss_coeff: float = 0.0 # coefficient for z loss: TODO: 1e-4 - deepnorm: bool = False - subln: bool = True - use_ffn_rms_norm: bool = False - layernorm_eps: float = 1e-6 - tie_word_embeddings: bool = False - - def __init__( - self, - vocab_size: int = 50257, - initializer_range: float = 0.02, - is_decoder: bool = True, - pad_token_id: int = 0, - eos_token_id: int = 0, - output_retentions: bool = False, - use_cache: bool = True, - forward_impl: str = 'parallel', - activation_fn: str = "gelu", - dropout: float = 0.0, # dropout probability - activation_dropout: float = 0.0, # dropout probability after activation in FFN. - drop_path_rate: float = 0.0, - decoder_embed_dim: int = 768, # decoder embedding dimension - decoder_value_embed_dim: int = 1280, # decoder value embedding dimension - decoder_ffn_embed_dim: int = 1280, # decoder embedding dimension for FFN - decoder_layers: int = 12, # num decoder layers - decoder_retention_heads: int = 3, # num decoder retention heads - decoder_normalize_before: bool = True, # apply layernorm before each decoder block - layernorm_embedding: bool = False, # add layernorm to embedding - no_scale_embedding: bool = True, # if True, dont scale embeddings - recurrent_chunk_size: int = 512, - use_glu: bool = True, # use GLU instead of FFN - z_loss_coeff: float = 0.0, # coefficient for z loss: TODO: 1e-4 - use_lm_decay: bool = False, - deepnorm: bool = True, - subln: bool = True, - use_ffn_rms_norm: bool = False, # use RMSNorm instead of LayerNorm in FFN - layernorm_eps: float = 1e-6, - tie_word_embeddings: bool = False, - **kwargs): - self.vocab_size = vocab_size - self.initializer_range = initializer_range - self.output_retentions = output_retentions - self.use_lm_decay = use_lm_decay - self.use_glu = use_glu - self.z_loss_coeff = z_loss_coeff - # size related - self.decoder_embed_dim = decoder_embed_dim - self.decoder_value_embed_dim = decoder_value_embed_dim - self.decoder_retention_heads = decoder_retention_heads - self.decoder_ffn_embed_dim = decoder_ffn_embed_dim - self.decoder_layers = decoder_layers - # normalization related - self.decoder_normalize_before = decoder_normalize_before - self.activation_fn = activation_fn - self.dropout = dropout - self.drop_path_rate = drop_path_rate - self.activation_dropout = activation_dropout - self.no_scale_embedding = no_scale_embedding - self.layernorm_embedding = layernorm_embedding - self.deepnorm = deepnorm - self.subln = subln - self.use_ffn_rms_norm = use_ffn_rms_norm - self.layernorm_eps = layernorm_eps - # Blockwise - self.recurrent_chunk_size = recurrent_chunk_size - self.forward_impl = forward_impl - - if self.deepnorm: - self.decoder_normalize_before = False - self.subln = False - if self.subln: - self.decoder_normalize_before = True - self.deepnorm = False - - super().__init__(is_decoder=is_decoder, - pad_token_id=pad_token_id, - 
eos_token_id=eos_token_id, - use_cache=use_cache, - tie_word_embeddings=tie_word_embeddings, - **kwargs) - - def override(self, args): - for hp in self.__dict__.keys(): - if getattr(args, hp, None) is not None: - self.__dict__[hp] = getattr(args, hp, None) diff --git a/vall_e/ext/retnet_hf/modeling_retnet.py b/vall_e/ext/retnet_hf/modeling_retnet.py deleted file mode 100644 index 1f9730f..0000000 --- a/vall_e/ext/retnet_hf/modeling_retnet.py +++ /dev/null @@ -1,1464 +0,0 @@ -import math -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from timm.models.layers import drop_path -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -try: - from transformers import top_k_top_p_filtering -except Exception as e: - pass -try: - from transformers.generation.utils import top_k_top_p_filtering -except Exception as e: - pass -from transformers.modeling_outputs import ModelOutput, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging - -try: - from apex.normalization import FusedLayerNorm as LayerNorm -except ModuleNotFoundError: - from torch.nn import LayerNorm - -from .configuration_retnet import RetNetConfig - -logger = logging.get_logger(__name__) - - -# helper functions -def split_heads(tensors, bsz, seqlen, num_heads): - assert isinstance(tensors, (tuple, list)) - return [x.view(bsz, seqlen, num_heads, -1).transpose(1, 2) for x in tensors] - - -def rotate_every_two(x): - x1 = x[:, :, :, ::2] - x2 = x[:, :, :, 1::2] - x = torch.stack((-x2, x1), dim=-1) - return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... 
(d j)')\ - - -def theta_shift(x, sin, cos): - return (x * cos) + (rotate_every_two(x) * sin) - - -def get_activation_fn(activation): - if activation == "relu": - return F.relu - elif activation == "gelu": - return F.gelu - elif activation == "swish": - return F.silu - else: - raise NotImplementedError - - -class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True): - super().__init__() - self.normalized_shape = dim - self.eps = eps - self.elementwise_affine = elementwise_affine - if self.elementwise_affine: - self.weight = nn.Parameter(torch.ones(dim)) - else: - self.register_parameter("weight", None) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - if self.weight is not None: - output = output * self.weight - return output - - -class RetNetRelPos(nn.Module): - def __init__(self, config: RetNetConfig): - super().__init__() - self.config = config - num_heads = config.decoder_retention_heads - - angle = 1.0 / ( - 10000 ** torch.linspace(0, 1, config.decoder_embed_dim // num_heads // 2) - ) - angle = angle.unsqueeze(-1).repeat(1, 2).flatten() - # decay (gamma) - if config.use_lm_decay: - # NOTE: alternative way described in the paper - s = torch.log(torch.tensor(1 / 32)) - e = torch.log(torch.tensor(1 / 512)) - decay = torch.log(1 - torch.exp(torch.linspace(s, e, num_heads))) # [h,] - else: - decay = torch.log( - 1 - 2 ** (-5 - torch.arange(num_heads, dtype=torch.float)) - ) - self.register_buffer("angle", angle) - self.register_buffer("decay", decay) - self.recurrent_chunk_size = config.recurrent_chunk_size - - def forward( - self, - slen, - forward_impl="parallel", - recurrent_chunk_size=None, - retention_mask=None, - get_decay_scale=True, - ): - if forward_impl == "recurrent": - sin = torch.sin(self.angle * (slen - 1)) - cos = torch.cos(self.angle * (slen - 1)) - retention_rel_pos = ((sin, cos), self.decay.view(1, -1, 1, 1).exp()) - elif forward_impl == "chunkwise": - if recurrent_chunk_size is None: - recurrent_chunk_size = self.recurrent_chunk_size - index = torch.arange(slen).to(self.decay) - sin = torch.sin(index[:, None] * self.angle[None, :]) - cos = torch.cos(index[:, None] * self.angle[None, :]) - - block_index = torch.arange(recurrent_chunk_size).to(self.decay) - mask = torch.tril( - torch.ones(recurrent_chunk_size, recurrent_chunk_size) - ).to(self.decay) - mask = torch.masked_fill( - block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf") - ) - mask = torch.exp(mask * self.decay[:, None, None]) - mask = torch.nan_to_num(mask) - mask = mask.unsqueeze(0) # [1, h, t, t] - # TODO: need to handle retention_mask - # scaling - value_inner_decay = mask[:, :, -1] / mask[:, :, -1].sum( - dim=-1, keepdim=True - ) - value_inner_decay = value_inner_decay.unsqueeze(-1) - scale = mask.sum(dim=-1, keepdim=True).sqrt() - inner_mask = mask / scale - - cross_decay = torch.exp(self.decay * recurrent_chunk_size) - query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1)) - cross_decay = cross_decay[None, :, None, None] - query_inner_decay = query_inner_decay[None, :, :, None] / ( - scale / mask[:, :, -1].sum(dim=-1)[:, :, None, None] - ) - # decay_scale (used for kv cache) - if get_decay_scale: - decay_scale = self.compute_decay_scale(slen, retention_mask) - else: - decay_scale = None - retention_rel_pos = ( - (sin, cos), - ( - inner_mask, - cross_decay, - query_inner_decay, - value_inner_decay, - decay_scale, - ), 
- ) - else: # parallel - index = torch.arange(slen).to(self.decay) - sin = torch.sin(index[:, None] * self.angle[None, :]) - cos = torch.cos(index[:, None] * self.angle[None, :]) - mask = torch.tril(torch.ones(slen, slen)).to(self.decay) - mask = torch.masked_fill( - index[:, None] - index[None, :], ~mask.bool(), float("inf") - ) - mask = torch.exp(mask * self.decay[:, None, None]) - mask = torch.nan_to_num(mask) - mask = mask.unsqueeze(0) # [1, h, t, t] - if retention_mask is not None: - # this is required for left padding - mask = mask * retention_mask.float().view(-1, 1, 1, slen).to(mask) - - # scaling - mask = mask / mask.sum(dim=-1, keepdim=True).sqrt() - mask = torch.nan_to_num(mask, nan=0.0) - # decay_scale (used for kv cache) - if get_decay_scale: - decay_scale = self.compute_decay_scale(slen, retention_mask) - else: - decay_scale = None - # mask processing for intra decay - if retention_mask is not None: - max_non_zero = ( - torch.cumsum(retention_mask, dim=-1).max(dim=-1).indices - ) # [b,] - intra_decay = mask[range(mask.shape[0]), :, max_non_zero] - else: - intra_decay = mask[:, :, -1] - - retention_rel_pos = ((sin, cos), (mask, intra_decay, decay_scale)) - - return retention_rel_pos - - def compute_decay_scale(self, slen, retention_mask=None): - exponent = torch.arange(slen, device=self.decay.device).float() - decay_scale = self.decay.exp().view(-1, 1) ** exponent.view(1, -1) # [h, t] - if retention_mask is not None: - seqlen = retention_mask.sum(dim=-1) # [b,] - bsz = seqlen.size(0) - decay_scale = decay_scale.unsqueeze(0).repeat(bsz, 1, 1) # [b, h, t] - for i, pos in enumerate(seqlen): - # the formula for decay_scale is `sum(gamma^i) for i in [0, slen).` - # Since the retention_mask is 0 for padding, we can set the decay_scale - # to 0 for the padding positions. 
- decay_scale[i, :, pos.item() :] = 0 - else: - bsz = 1 - decay_scale = decay_scale.sum(-1).view(bsz, -1, 1, 1) # [b, h, 1, 1] - return decay_scale - - -class MultiScaleRetention(nn.Module): - def __init__( - self, - config: RetNetConfig, - gate_fn="swish", - use_bias=False, - tensor_parallel=False, - ): - super().__init__() - self.config = config - self.embed_dim = config.decoder_embed_dim - self.value_dim = config.decoder_value_embed_dim - self.num_heads = config.decoder_retention_heads - self.head_dim = self.value_dim // self.num_heads - self.key_dim = self.embed_dim // self.num_heads - self.scaling = self.key_dim**-0.5 - - self.gate_fn = get_activation_fn(activation=str(gate_fn)) - - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias) - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias) - self.v_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias) - self.g_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias) - - self.out_proj = nn.Linear(self.value_dim, self.embed_dim, bias=use_bias) - - self.group_norm = RMSNorm( - self.head_dim, eps=config.layernorm_eps, elementwise_affine=False - ) - self.reset_parameters() - - if tensor_parallel: - self.decay_proj = nn.Linear(self.num_heads, self.num_heads, bias=False) - else: - self.decay_proj = None - - def reset_parameters(self): - nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-2.5) - nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-2.5) - nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-2.5) - nn.init.xavier_uniform_(self.g_proj.weight, gain=2**-2.5) - nn.init.xavier_uniform_(self.out_proj.weight, gain=2**-1) - - def parallel_retention(self, q, k, v, decay_mask): - """ - q, # bsz * num_head * len * qk_dim - k, # bsz * num_head * len * qk_dim - v, # bsz * num_head * len * v_dim - decay_mask, # (1 or bsz) * num_head * len * len - """ - decay_mask, intra_decay, scale = decay_mask - # just return retention_rel_pos projected - # TODO: for shardformer - if self.decay_proj is not None: - decay_mask = self.decay_proj(decay_mask.transpose(-1, -3)).transpose(-3, -1) - - # [b, h, t, t] - retention = q @ k.transpose(-1, -2) # (scaled dot-product) - retention = retention * decay_mask - - # invariant after normalization - retention = retention / retention.detach().abs().sum( - dim=-1, keepdim=True - ).clamp(min=1, max=5e4) - - output = retention @ v # [b, h, t, v_dim / h] - output = output.transpose(1, 2) # [b, t, h, v_dim / h] - - if self.training: # skip cache - return output, None, retention - - if self.decay_proj is not None: - intra_decay = self.decay_proj(intra_decay.transpose(-1, -2)).transpose( - -2, -1 - ) - - # kv cache: [b, h, t, v_dim, qk_dim] - current_kv = k.unsqueeze(-2) * v.unsqueeze(-1) - intra_decay = intra_decay[:, :, :, None, None] # [b, h, t, 1, 1] - current_kv = (current_kv * intra_decay).sum(2) # [b, h, v_dim, qk_dim] - - cache = {"prev_key_value": current_kv, "scale": scale} - return output, cache, retention - - def recurrent_retention( - self, q, k, v, decay, past_key_value=None, retention_mask=None - ): - """ - q, k, v, # bsz * num_head * 1 * qkv_dim - past_key_value: - - "prev_key_value" # bsz * num_head * v_dim * qk_dim - - "scale" # (1 or bsz) * num_head * 1 * 1 - decay # (1 or bsz) * num_head * 1 * 1 - retention_mask # bsz * 1 - """ - if retention_mask is not None: - retention_mask = retention_mask.float().view(-1, 1, 1, 1).to(decay) - else: - retention_mask = torch.ones(k.size(0), 1, 1, 1).to(decay) - # (b, h, v_dim, qk_dim) - current_kv = k * 
v.transpose(-1, -2) * retention_mask - - if past_key_value is not None and "prev_key_value" in past_key_value: - prev_kv = past_key_value["prev_key_value"] - prev_scale = past_key_value["scale"] - scale = torch.where(retention_mask == 0, prev_scale, prev_scale * decay + 1) - # connect prev_kv and current_kv - # how much to decay prev_kv - decay_amount = prev_scale.sqrt() * decay / scale.sqrt() - decay_amount = torch.where(retention_mask == 0, 1, decay_amount) - prev_kv = prev_kv * decay_amount # decay prev_kv - current_kv = current_kv / scale.sqrt() # scale current_kv - current_kv = torch.nan_to_num( - current_kv, nan=0.0 - ) # remove nan, scale might be 0 - - current_kv = prev_kv + current_kv - else: - scale = torch.ones_like(decay) - # when retention_mask is 0 at the beginning, setting scale to 1 will - # make the first retention to use the padding incorrectly. Hence, - # setting it to 0 here. This is a little ugly, so we might want to - # change this later. TODO: improve - scale = torch.where(retention_mask == 0, torch.zeros_like(decay), scale) - - output = torch.sum(q * current_kv, dim=3).unsqueeze(1) # (b, 1, h, d_v) - - cache = {"prev_key_value": current_kv, "scale": scale} - return output, cache - - def chunkwise_retention(self, q, k, v, decay_mask): - """ - q, k, v, # bsz * num_head * seqlen * qkv_dim - past_key_value: - - "prev_key_value" # bsz * num_head * v_dim * qk_dim - - "scale" # (1 or bsz) * num_head * 1 * 1 - decay_mask, # 1 * num_head * chunk_size * chunk_size - cross_decay, # 1 * num_head * 1 * 1 - inner_decay, # 1 * num_head * chunk_size * 1 - """ - # TODO: not working properly - ( - decay_mask, - cross_decay, - query_inner_decay, - value_inner_decay, - decay_scale, - ) = decay_mask - bsz, _, tgt_len, _ = v.size() - chunk_len = decay_mask.size(-1) - assert tgt_len % chunk_len == 0 - num_chunks = tgt_len // chunk_len - - # [b, n_c, h, t_c, qkv_dim] - q = q.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose( - 1, 2 - ) - k = k.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose( - 1, 2 - ) - v = v.view(bsz, self.num_heads, num_chunks, chunk_len, self.head_dim).transpose( - 1, 2 - ) - - k_t = k.transpose(-1, -2) - - qk_mat = q @ k_t # [b, n_c, h, t_c, t_c] - qk_mat = qk_mat * decay_mask.unsqueeze(1) - inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1) - qk_mat = qk_mat / inner_scale - # [b, n_c, h, t_c, v_dim] - inner_output = torch.matmul(qk_mat, v) - - # reduce kv in one chunk - # [b, n_c, h, qk_dim, v_dim] - kv = k_t @ (v * value_inner_decay) - # kv = kv.view(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim) - - kv_recurrent = [] - cross_scale = [] - kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v) - kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v) - - # accumulate kv by loop - for i in range(num_chunks): - kv_recurrent.append(kv_state / kv_scale) - cross_scale.append(kv_scale) - kv_state = kv_state * cross_decay + kv[:, i] - kv_scale = ( - kv_state.detach() - .abs() - .sum(dim=-2, keepdim=True) - .max(dim=-1, keepdim=True) - .values.clamp(min=1) - ) - - kv_recurrent = torch.stack(kv_recurrent, dim=1) - cross_scale = torch.stack(cross_scale, dim=1) - - all_scale = torch.maximum(inner_scale, cross_scale) - align_inner_scale = all_scale / inner_scale - align_cross_scale = all_scale / cross_scale - - cross_output = (q * query_inner_decay.unsqueeze(1)) @ kv_recurrent - output = inner_output / align_inner_scale + cross_output / align_cross_scale - output = 
output.transpose(2, 3) # [b, n_c, t_c, h, v_dim] - - cache = {"prev_key_value": kv_state.transpose(-2, -1), "scale": decay_scale} - return output, cache - - def forward( - self, - hidden_states: torch.Tensor, - rel_pos: Tuple[Tuple[torch.Tensor]], - retention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - forward_impl: str = "parallel", - output_retentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]: - B, T, H = hidden_states.size() - (sin, cos), decay_mask = rel_pos - # projections - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - g = self.g_proj(hidden_states) - # multi-head - q, k, v = split_heads((q, k, v), B, T, self.num_heads) - k *= self.scaling # for scaled dot product - # rotate - # NOTE: theta_shift has bug with mps device. - qr = theta_shift(q, sin, cos) - kr = theta_shift(k, sin, cos) - - # retention - if forward_impl == "parallel": - retention_out, curr_kv, retention_weights = self.parallel_retention( - qr, kr, v, decay_mask - ) - elif forward_impl == "recurrent": - retention_out, curr_kv = self.recurrent_retention( - qr, - kr, - v, - decay_mask, - past_key_value=past_key_value, - retention_mask=retention_mask, - ) - elif forward_impl == "chunkwise": - retention_out, curr_kv = self.chunkwise_retention(qr, kr, v, decay_mask) - else: - raise ValueError(f"forward_impl {forward_impl} not supported.") - - # concaat heads - normed = self.group_norm(retention_out).reshape(B, T, self.value_dim) - # out gate & proj - out = self.gate_fn(g) * normed - out = self.out_proj(out) - - outputs = (out, curr_kv) - if output_retentions: - outputs += (retention_weights,) if forward_impl == "parallel" else (None,) - return outputs - - -class FeedForwardNetwork(nn.Module): - def __init__( - self, - embed_dim, - ffn_dim, - activation_fn, - dropout, - activation_dropout, - layernorm_eps, - subln=False, - use_rms_norm=False, - ): - super().__init__() - self.embed_dim = embed_dim - self.activation_fn = get_activation_fn(activation=str(activation_fn)) - self.activation_dropout_module = torch.nn.Dropout(activation_dropout) - self.dropout_module = torch.nn.Dropout(dropout) - self.fc1 = nn.Linear(self.embed_dim, ffn_dim) - self.fc2 = nn.Linear(ffn_dim, self.embed_dim) - if subln: - if use_rms_norm: - self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps) - else: - self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps) - else: - self.ffn_layernorm = None - - def reset_parameters(self): - self.fc1.reset_parameters() - self.fc2.reset_parameters() - if self.ffn_layernorm is not None: - self.ffn_layernorm.reset_parameters() - - def forward(self, x): - x_shape = x.shape - x = x.reshape(-1, x.size(-1)) - x = self.fc1(x) - x = self.activation_fn(x.float()).type_as(x) - x = self.activation_dropout_module(x) - if self.ffn_layernorm is not None: - x = self.ffn_layernorm(x) - x = self.fc2(x) - x = x.view(x_shape) - x = self.dropout_module(x) - return x - - -class GLU(nn.Module): - def __init__( - self, - embed_dim, - ffn_dim, - activation_fn, - dropout, - activation_dropout, - ): - super().__init__() - self.embed_dim = embed_dim - self.activation_fn = get_activation_fn(activation=str(activation_fn)) - self.activation_dropout_module = torch.nn.Dropout(activation_dropout) - self.dropout_module = torch.nn.Dropout(dropout) - self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False) - self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False) - 
self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False) - - def reset_parameters(self): - self.fc1.reset_parameters() - self.fc2.reset_parameters() - self.gate.reset_parameters() - - def forward(self, x): - x_shape = x.shape - x = x.reshape(-1, x.size(-1)) - g = self.gate(x) - x = self.fc1(x) - x = self.activation_fn(x.float()).type_as(x) * g - x = self.activation_dropout_module(x) - x = self.fc2(x) - x = x.view(x_shape) - x = self.dropout_module(x) - return x - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - def extra_repr(self): - return "p={}".format(self.drop_prob) - - -class RetNetDecoderLayer(nn.Module): - def __init__(self, config: RetNetConfig, depth: int, tensor_parallel: bool = False): - super().__init__() - self.config = config - self.embed_dim = config.decoder_embed_dim - self.dropout_module = torch.nn.Dropout(config.dropout) - - if config.drop_path_rate > 0: - drop_path_prob = np.linspace( - 0, config.drop_path_rate, config.decoder_layers - )[depth] - self.drop_path = DropPath(drop_path_prob) - else: - self.drop_path = None - - self.retention = MultiScaleRetention( - config, use_bias=False, tensor_parallel=tensor_parallel - ) - - self.normalize_before = config.decoder_normalize_before - - self.retention_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps) - - self.ffn_dim = config.decoder_ffn_embed_dim - - self.ffn = self.build_ffn() - - self.final_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps) - - if config.deepnorm: - self.alpha = math.pow(2.0 * config.decoder_layers, 0.25) - else: - self.alpha = 1.0 - - def build_ffn(self): - if self.config.use_glu: - return GLU( - self.embed_dim, - self.ffn_dim, - self.config.activation_fn, - self.config.dropout, - self.config.activation_dropout, - ) - else: - return FeedForwardNetwork( - self.embed_dim, - self.ffn_dim, - self.config.activation_fn, - self.config.dropout, - self.config.activation_dropout, - self.config.layernorm_eps, - self.config.subln, - self.config.use_ffn_rms_norm, - ) - - def residual_connection(self, x, residual): - return residual * self.alpha + x - - def forward( - self, - hidden_states: torch.Tensor, - retention_rel_pos: Tuple[Tuple[torch.Tensor]], - retention_mask: Optional[torch.Tensor] = None, - forward_impl: str = "parallel", - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_retentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]: - residual = hidden_states - if self.normalize_before: - hidden_states = self.retention_layer_norm(hidden_states) - - msr_outs = self.retention( - hidden_states, - retention_rel_pos, - retention_mask=retention_mask, - past_key_value=past_key_value, - forward_impl=forward_impl, - output_retentions=output_retentions, - ) - hidden_states = msr_outs[0] - curr_kv = msr_outs[1] - - hidden_states = self.dropout_module(hidden_states) - - if self.drop_path is not None: - hidden_states = self.drop_path(hidden_states) - - hidden_states = self.residual_connection(hidden_states, residual) - if not self.normalize_before: - hidden_states = self.retention_layer_norm(hidden_states) - - residual = hidden_states - if self.normalize_before: - hidden_states = self.final_layer_norm(hidden_states) - - hidden_states = self.ffn(hidden_states) - - 
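# [sketch, not part of the patch] When config.use_glu is True (the default in the
# removed configs), the self.ffn built above is the gated feed-forward shown here:
# fc2(act(fc1(x)) * gate(x)) with bias-free projections. A minimal standalone PyTorch
# version, with dropout omitted and gelu assumed as the activation:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GLUSketch(nn.Module):
    def __init__(self, embed_dim: int, ffn_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, ffn_dim, bias=False)   # value path
        self.gate = nn.Linear(embed_dim, ffn_dim, bias=False)  # gate path
        self.fc2 = nn.Linear(ffn_dim, embed_dim, bias=False)   # back to model dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the activation applies to the value path only; the gate then modulates it
        return self.fc2(F.gelu(self.fc1(x)) * self.gate(x))

out = GLUSketch(8, 16)(torch.randn(2, 4, 8))  # -> shape (2, 4, 8)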
if self.drop_path is not None: - hidden_states = self.drop_path(hidden_states) - - hidden_states = self.residual_connection(hidden_states, residual) - if not self.normalize_before: - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states, curr_kv) - - if output_retentions: - outputs += (msr_outs[2],) - return outputs - - -class RetNetPreTrainedModel(PreTrainedModel): - # copied from LlamaPretrainedModel - config_class = RetNetConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["RetNetDecoderLayer"] - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - - def _init_weights(self, module): - """ - Following original retnet, weights are already initialized in their own - ways within their own init. - """ - pass - # below is copied from LlamaPretrainedModel - # std = self.config.initializer_range - # if isinstance(module, nn.Linear): - # module.weight.data.normal_(mean=0.0, std=std) - # if module.bias is not None: - # module.bias.data.zero_() - # elif isinstance(module, nn.Embedding): - # module.weight.data.normal_(mean=0.0, std=std) - # if module.padding_idx is not None: - # module.weight.data[module.padding_idx].zero_() - - -@dataclass -class RetNetOutputWithPast(ModelOutput): - """ - class for RetNet model's outputs that may also contain a past key/values (to speed up sequential decoding). - - config: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, decoder_embed_dim)`): - Sequence of hidden-states at the output of the last layer of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - decoder_embed_dim)` is output. - past_key_values (`List(Dict(str, torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - - "prev_key_value": shape=(bsz * num_head * v_dim * qk_dim) - - "scale": shape=((1 or bsz) * num_head * 1 * 1) - - Contains pre-computed hidden-states (key and values in the multi-scale retention blocks) - that can be used (see `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, decoder_embed_dim)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - retentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_retentions=True` is passed or when `config.output_retentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Retentions weights, used for visualization. - - attentions (`tuple(torch.FloatTensor)`, *optional*, for backward compatibility. Same as retentions. 
- """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - retentions: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -class RetNetModel(RetNetPreTrainedModel): - def __init__( - self, - config: RetNetConfig, - embed_tokens: nn.Embedding = None, - tensor_parallel: bool = False, - ): - super().__init__(config) - self.config = config - - self.dropout_module = torch.nn.Dropout(config.dropout) - - self.embed_dim = config.decoder_embed_dim - self.embed_scale = ( - 1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim) - ) - - if embed_tokens is None: - embed_tokens = nn.Embedding( - config.vocab_size, config.decoder_embed_dim, config.pad_token_id - ) - self.embed_tokens = embed_tokens - - if config.layernorm_embedding: - self.layernorm_embedding = RMSNorm(self.embed_dim, eps=config.layernorm_eps) - else: - self.layernorm_embedding = None - - self.layers = nn.ModuleList([]) - - for i in range(config.decoder_layers): - self.layers.append( - RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel) - ) - - self.decoder_layers = len(self.layers) - - if config.decoder_normalize_before: - self.layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps) - else: - self.layer_norm = None - - self.retnet_rel_pos = RetNetRelPos(config) - self.recurrent_chunk_size = config.recurrent_chunk_size - - if config.deepnorm: - init_scale = math.pow(8.0 * config.decoder_layers, 0.25) - for name, p in self.named_parameters(): - if ( - "fc1" in name - or "fc2" in name - or "out_proj" in name - or "v_proj" in name - ): - p.data.div_(init_scale) - - if config.subln and not config.use_glu: - init_scale = math.sqrt(math.log(config.decoder_layers * 2)) - for name, p in self.named_parameters(): - if ( - "fc1" in name - or "fc2" in name - or "out_proj" in name - or "v_proj" in name - ): - p.data.mul_(init_scale) - - self.gradient_checkpointing = False - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward_embedding( - self, - input_ids, - forward_impl, - inputs_embeds=None, - past_key_values=None, - ): - # if past_key_values is not None: - if forward_impl == "recurrent": - input_ids = input_ids[:, -1:] - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - embed = self.embed_scale * inputs_embeds - - if self.layernorm_embedding is not None: - embed = self.layernorm_embedding(embed) - - embed = self.dropout_module(embed) - - return embed - - def forward( - self, - input_ids: torch.LongTensor = None, - retention_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - output_retentions: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - use_cache: Optional[bool] = None, - return_dict: Optional[bool] = None, - forward_impl: Optional[str] = "parallel", - recurrent_chunk_size: Optional[int] = None, - retention_rel_pos: Optional[Tuple[torch.Tensor]] = None, - ) -> Union[Tuple, RetNetOutputWithPast]: - if output_retentions is None and output_attentions is not None: - output_retentions = output_attentions - output_retentions = ( - output_retentions - if output_retentions is 
not None - else self.config.output_retentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - # embed tokens - if inputs_embeds is None: - inputs_embeds = self.forward_embedding( - input_ids, forward_impl, inputs_embeds, past_key_values - ) - - if retention_mask is None and attention_mask is not None: - retention_mask = attention_mask - if retention_mask is not None and forward_impl == "recurrent": - retention_mask = retention_mask[:, -1:] - - hidden_states = inputs_embeds - - # handling chunking here - if recurrent_chunk_size is None: - recurrent_chunk_size = self.recurrent_chunk_size - need_pad_for_chunkwise = ( - forward_impl == "chunkwise" and seq_length % recurrent_chunk_size != 0 - ) - if need_pad_for_chunkwise: - padding_len = recurrent_chunk_size - seq_length % recurrent_chunk_size - slen = seq_length + padding_len - hidden_states = F.pad(hidden_states, (0, 0, 0, padding_len)) - else: - slen = seq_length - # relative position - if retention_rel_pos is None: - retention_rel_pos = self.retnet_rel_pos( - slen, - forward_impl=forward_impl, - recurrent_chunk_size=recurrent_chunk_size, - retention_mask=retention_mask, - get_decay_scale=not self.training, - ) - - # start running through the decoder layers - all_hidden_states = () if output_hidden_states else None - all_retentions = () if output_retentions else None - # layers * [bsz, num_head, qk_dim, decoder_embed_dim] - next_decoder_cache = () if use_cache else None - - for idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_retentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), - hidden_states, - retention_rel_pos, - retention_mask, - forward_impl, - past_key_value, - use_reentrant=True, - ) - else: - layer_outputs = layer( - hidden_states, - retention_rel_pos, - retention_mask=retention_mask, - forward_impl=forward_impl, - past_key_value=past_key_value, - output_retentions=output_retentions, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[1],) - - if output_retentions: - all_retentions += (layer_outputs[2],) - - next_cache = next_decoder_cache if use_cache else None - - if need_pad_for_chunkwise: - hidden_states = hidden_states[:, :seq_length, :] - - if self.layer_norm is not None: - hidden_states = self.layer_norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, 
all_hidden_states, all_retentions] - if v is not None - ) - return RetNetOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - retentions=all_retentions, - attentions=all_retentions, - ) - - -@dataclass -class RetNetCausalLMOutputWithPast(ModelOutput): - """ - class for RetNet causal language model (or autoregressive) outputs. - - config: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`List(Dict(str, torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - - "prev_key_value": shape=(bsz * num_head * v_dim * qk_dim) - - "scale": shape=((1 or bsz) * num_head * 1 * 1) - - Contains pre-computed hidden-states (key and values in the multi-scale retention blocks) - that can be used (see `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, decoder_embed_dim)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - retentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_retentions=True` is passed or when `config.output_retentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Retentions weights, used for visualization. - - attentions (`tuple(torch.FloatTensor)`, *optional*, for backward compatibility. Same as retentions. 
- """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - retentions: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -class RetNetForCausalLM(RetNetPreTrainedModel): - def __init__( - self, - config: RetNetConfig, - embed_tokens: nn.Embedding = None, - tensor_parallel: bool = False, - ) -> None: - super().__init__(config) - self.model = RetNetModel( - config, embed_tokens=embed_tokens, tensor_parallel=tensor_parallel - ) - self.lm_head = nn.Linear( - config.decoder_embed_dim, config.vocab_size, bias=False - ) - # init here - torch.nn.init.normal_( - self.lm_head.weight, mean=0, std=config.decoder_embed_dim**-0.5 - ) - - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - retention_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_retentions: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - forward_impl: Optional[str] = None, - recurrent_chunk_size: Optional[int] = None, - retention_rel_pos: Optional[Tuple[torch.Tensor]] = None, - ) -> Union[Tuple, RetNetCausalLMOutputWithPast]: - if output_retentions is None and output_attentions is not None: - output_retentions = output_attentions - output_retentions = ( - output_retentions - if output_retentions is not None - else self.config.output_retentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - forward_impl = ( - forward_impl if forward_impl is not None else self.config.forward_impl - ) - recurrent_chunk_size = ( - recurrent_chunk_size - if recurrent_chunk_size is not None - else self.config.recurrent_chunk_size - ) - - if retention_mask is None and attention_mask is not None: - retention_mask = attention_mask - - outputs = self.model( - input_ids, - retention_mask=retention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_retentions=output_retentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - forward_impl=forward_impl, - use_cache=use_cache, - recurrent_chunk_size=recurrent_chunk_size, - retention_rel_pos=retention_rel_pos, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - 
# Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if self.config.z_loss_coeff > 0: - # z_loss from PaLM paper - # z_loss = 1e-4 * log(log(z)), where z = sum(exp(logits)) - z_loss = torch.logsumexp(shift_logits, dim=-1).log().mean() - loss += self.config.z_loss_coeff * z_loss - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return RetNetCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - retentions=outputs.retentions, - attentions=outputs.retentions, - ) - - def _crop_past_key_values(model, past_key_values, maximum_length): - """Since retnet's kv do not have length, no need to crop. Just return""" - return past_key_values - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - forward_impl = kwargs.get("forward_impl", "parallel") - if past_key_values is not None: - forward_impl = "recurrent" - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "forward_impl": forward_impl, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: # dict - layer_past_kv = layer_past["prev_key_value"] # [b, h, v_dim / h, qk_dim] - layer_past_scale = layer_past["scale"] # [b, h, 1, 1] - if layer_past_scale.size(0) > 1: - # this means that retention_mask is not None, so the scale for - # each batch is different. We need to select the correct scale then. - # NOTE: during huggingface generate, it will generate attention_mask - # if it is None, so this linke will always be true. Still, having - # this line here for safety. 
- layer_past_scale = layer_past_scale.index_select(0, beam_idx) - reordered_past += ( - { - "prev_key_value": layer_past_kv.index_select(0, beam_idx), - "scale": layer_past_scale, - }, - ) - return reordered_past - - def sample_token(self, logit, do_sample=False, top_k=1, top_p=1.0, temperature=1.0): - if not do_sample: - return torch.argmax(logit, dim=-1, keepdim=True) - filtered = top_k_top_p_filtering(logit / temperature, top_k=top_k, top_p=top_p) - return torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1) - - @torch.inference_mode() - def custom_generate( - self, - input_ids: torch.LongTensor = None, - retention_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - parallel_compute_prompt=True, - max_new_tokens=20, - bos_token_id=0, - eos_token_id=0, - do_sample=False, - top_k=0, - top_p=1.0, - temperature=1.0, - early_stopping=True, - ): - if retention_mask is None and attention_mask is not None: - retention_mask = attention_mask - - if input_ids is not None: - if input_ids.shape[1] == 1: - past_key_values = None - elif parallel_compute_prompt: - ret_mask = ( - retention_mask[:, :-1] if retention_mask is not None else None - ) - outputs = self( - input_ids[:, :-1], - retention_mask=ret_mask, - forward_impl="parallel", - return_dict=True, - use_cache=True, - ) - past_key_values = outputs.past_key_values - else: - past_key_values = None - for p_i in range(input_ids.shape[1] - 1): - ret_mask = ( - retention_mask[:, : p_i + 1] - if retention_mask is not None - else None - ) - outputs = self( - input_ids[:, : p_i + 1], - retention_mask=ret_mask, - forward_impl="recurrent", - past_key_values=past_key_values, - return_dict=True, - use_cache=True, - ) - past_key_values = outputs.past_key_values - - generated = input_ids - else: - generated = torch.tensor([[bos_token_id]]).to(self.lm_head.weight.device) - past_key_values = None - - for i in range(max_new_tokens): - outputs = self( - generated, - retention_mask=retention_mask, - forward_impl="recurrent", - past_key_values=past_key_values, - use_cache=True, - return_dict=True, - ) - logit = outputs.logits[:, -1, :] # [batch_size, vocab_size] - past_key_values = outputs.past_key_values - token = self.sample_token( - logit, - do_sample=do_sample, - top_k=top_k, - top_p=top_p, - temperature=temperature, - ) - generated = torch.cat([generated, token], dim=-1) - if retention_mask is not None: - retention_mask = torch.cat( - [retention_mask, torch.ones_like(token)], dim=-1 - ) - if early_stopping and (token == eos_token_id).all(): - break - return generated - - -class RetNetForSequenceClassification(RetNetPreTrainedModel): - def __init__(self, config, tensor_parallel=False): - super().__init__(config) - self.num_labels = config.num_labels - self.model = RetNetModel(config, tensor_parallel=tensor_parallel) - self.score = nn.Linear(config.decoder_embed_dim, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def forward( - self, - input_ids: torch.LongTensor = None, - retention_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_retentions: Optional[bool] = None, - 
output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - forward_impl: Optional[str] = None, - recurrent_chunk_size: Optional[int] = None, - retention_rel_pos: Optional[Tuple[torch.Tensor]] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - if output_retentions is None and output_attentions is not None: - output_retentions = output_attentions - output_retentions = ( - output_retentions - if output_retentions is not None - else self.config.output_retentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - forward_impl = ( - forward_impl if forward_impl is not None else self.config.forward_impl - ) - recurrent_chunk_size = ( - recurrent_chunk_size - if recurrent_chunk_size is not None - else self.config.recurrent_chunk_size - ) - - if retention_mask is None and attention_mask is not None: - retention_mask = attention_mask - - outputs = self.model( - input_ids, - retention_mask=retention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_retentions=output_retentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - forward_impl=forward_impl, - use_cache=use_cache, - recurrent_chunk_size=recurrent_chunk_size, - retention_rel_pos=retention_rel_pos, - ) - - hidden_states = outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError( - "Cannot handle batch sizes > 1 if no padding token is defined." 
- ) - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = ( - torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1 - ).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[ - torch.arange(batch_size, device=logits.device), sequence_lengths - ] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and ( - labels.dtype == torch.long or labels.dtype == torch.int - ): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct( - pooled_logits.view(-1, self.num_labels), labels.view(-1) - ) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/vall_e/ext/retnet_ts/__init__.py b/vall_e/ext/retnet_ts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/vall_e/ext/retnet_ts/config.py b/vall_e/ext/retnet_ts/config.py deleted file mode 100644 index 2deea15..0000000 --- a/vall_e/ext/retnet_ts/config.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) 2022 Microsoft -# Licensed under The MIT License [see LICENSE for details] - - -class RetNetConfig(object): - - def __init__(self, **kwargs): - self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768) - self.decoder_value_embed_dim = kwargs.pop("decoder_value_embed_dim", 1280) - self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3) - self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1280) - self.decoder_layers = kwargs.pop("decoder_layers", 12) - self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True) - self.activation_fn = kwargs.pop("activation_fn", "gelu") - self.dropout = kwargs.pop("dropout", 0.0) - self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0) - self.activation_dropout = kwargs.pop("activation_dropout", 0.0) - self.no_scale_embedding = kwargs.pop("no_scale_embedding", True) - self.layernorm_embedding = kwargs.pop("layernorm_embedding", False) - self.moe_freq = kwargs.pop("moe_freq", 0) - self.moe_top1_expert = kwargs.pop("moe_top1_expert", False) - self.moe_expert_count = kwargs.pop("moe_expert_count", 0) - self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True) - self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25) - self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random") - self.moe_normalize_gate_prob_before_dropping = kwargs.pop( - "moe_normalize_gate_prob_before_dropping", False) - self.use_xmoe = kwargs.pop("use_xmoe", False) - self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0) - self.max_rel_pos = 
kwargs.pop("max_rel_pos", 0) - self.deepnorm = kwargs.pop("deepnorm", False) - self.subln = kwargs.pop("subln", True) - self.use_ffn_rms_norm = kwargs.pop("use_ffn_rms_norm", False) - self.use_glu = kwargs.pop("use_glu", True) - self.use_lm_decay = kwargs.pop("use_lm_decay", False) - self.z_loss_coeff = kwargs.pop("z_loss_coeff", 0.0) # TODO: 1e-4 - self.multiway = kwargs.pop("multiway", False) - self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", - False) - self.max_target_positions = kwargs.pop("max_target_positions", 1024) - self.no_output_layer = kwargs.pop("no_output_layer", True) - self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-6) - # Blockwise - self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False) - self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512) - # Text - self.vocab_size = kwargs.pop("vocab_size", -1) - # Fairscale - self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) - self.fsdp = kwargs.pop("fsdp", False) - self.ddp_rank = kwargs.pop("ddp_rank", 0) - self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) - self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) - # token id - self.pad_token_id = kwargs.pop("pad_token_id", 0) - self.postprocessing() - - def postprocessing(self): - if self.deepnorm: - self.decoder_normalize_before = False - self.subln = False - if self.subln: - self.decoder_normalize_before = True - self.deepnorm = False - if self.use_xmoe: - self.moe_normalize_gate_prob_before_dropping = True - self.moe_second_expert_policy = "random" - assert self.moe_freq > 0 and self.moe_expert_count > 0 - - def override(self, args): - for hp in self.__dict__.keys(): - if getattr(args, hp, None) is not None: - self.__dict__[hp] = getattr(args, hp, None) - self.postprocessing() \ No newline at end of file diff --git a/vall_e/ext/retnet_ts/retnet.py b/vall_e/ext/retnet_ts/retnet.py deleted file mode 100644 index d4d194b..0000000 --- a/vall_e/ext/retnet_ts/retnet.py +++ /dev/null @@ -1,746 +0,0 @@ -# Copyright (c) 2022 Microsoft -# Licensed under The MIT License [see LICENSE for details] - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -# from fairscale.nn import checkpoint_wrapper, wrap -from timm.models.layers import drop_path -from transformers.modeling_outputs import CausalLMOutputWithPast - -try: - from apex.normalization import FusedLayerNorm as LayerNorm -except ModuleNotFoundError: - from torch.nn import LayerNorm - - -def rotate_every_two(x): - x1 = x[:, :, :, ::2] - x2 = x[:, :, :, 1::2] - x = torch.stack((-x2, x1), dim=-1) - return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... 
(d j)')\ - - -def theta_shift(x, sin, cos): - return (x * cos) + (rotate_every_two(x) * sin) - - -def get_activation_fn(activation): - if activation == "relu": - return F.relu - elif activation == "gelu": - return F.gelu - elif activation == "swish": - return F.silu - else: - raise NotImplementedError - - -class RMSNorm(nn.Module): - - def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True): - super().__init__() - self.eps = eps - self.elementwise_affine = elementwise_affine - if self.elementwise_affine: - self.weight = nn.Parameter(torch.ones(dim)) - else: - self.register_parameter('weight', None) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - if self.weight is not None: - output = output * self.weight - return output - - -class RetNetRelPos(nn.Module): - - def __init__(self, config): - super().__init__() - num_heads = config.decoder_retention_heads - - angle = 1.0 / (10000**torch.linspace(0, 1, config.decoder_embed_dim // num_heads // 2)) - angle = angle.unsqueeze(-1).repeat(1, 2).flatten() - if config.use_lm_decay: - # NOTE: alternative way described in the paper - s = torch.log(torch.tensor(1 / 32)) - e = torch.log(torch.tensor(1 / 512)) - decay = torch.log(1 - torch.exp(torch.linspace(s, e, num_heads))) # [h,] - else: - decay = torch.log(1 - 2**(-5 - torch.arange(num_heads, dtype=torch.float))) - self.register_buffer("angle", angle) - self.register_buffer("decay", decay) - self.recurrent_chunk_size = config.recurrent_chunk_size - - def forward(self, slen, activate_recurrent=False, chunkwise_recurrent=False): - if activate_recurrent: - sin = torch.sin(self.angle * (slen - 1)) - cos = torch.cos(self.angle * (slen - 1)) - retention_rel_pos = ((sin, cos), self.decay.exp()) - elif chunkwise_recurrent: - index = torch.arange(slen).to(self.decay) - sin = torch.sin(index[:, None] * self.angle[None, :]) - cos = torch.cos(index[:, None] * self.angle[None, :]) - - block_index = torch.arange(self.recurrent_chunk_size).to(self.decay) - mask = torch.tril(torch.ones(self.recurrent_chunk_size, - self.recurrent_chunk_size)).to(self.decay) - mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(), - float("inf")) - mask = torch.exp(mask * self.decay[:, None, None]) - mask = torch.nan_to_num(mask) - - value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True) - value_inner_decay = value_inner_decay.unsqueeze(-1) - scale = mask.sum(dim=-1, keepdim=True).sqrt() - inner_mask = mask / scale - - cross_decay = torch.exp(self.decay * self.recurrent_chunk_size) - query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1)) - query_inner_decay = query_inner_decay[:, :, None] / ( - scale / mask[:, -1].sum(dim=-1)[:, None, None]) - cross_decay = cross_decay[:, None, None] - retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay, - value_inner_decay)) - else: - index = torch.arange(slen).to(self.decay) - sin = torch.sin(index[:, None] * self.angle[None, :]) - cos = torch.cos(index[:, None] * self.angle[None, :]) - mask = torch.tril(torch.ones(slen, slen)).to(self.decay) - mask = torch.masked_fill(index[:, None] - index[None, :], ~mask.bool(), float("inf")) - mask = torch.exp(mask * self.decay[:, None, None]) - mask = torch.nan_to_num(mask) - mask = mask / mask.sum(dim=-1, keepdim=True).sqrt() - retention_rel_pos = ((sin, cos), mask) - - return retention_rel_pos - - -class MultiScaleRetention(nn.Module): - - def 
__init__( - self, - config, - embed_dim, - value_dim, - num_heads, - gate_fn="swish", - ): - super().__init__() - self.config = config - self.embed_dim = embed_dim - self.value_dim = value_dim - self.num_heads = num_heads - self.head_dim = self.value_dim // num_heads - self.key_dim = self.embed_dim // num_heads - self.scaling = self.key_dim**-0.5 - - self.gate_fn = get_activation_fn(activation=str(gate_fn)) - - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) - self.v_proj = nn.Linear(embed_dim, value_dim, bias=False) - self.g_proj = nn.Linear(embed_dim, value_dim, bias=False) - - self.out_proj = nn.Linear(value_dim, embed_dim, bias=False) - - self.group_norm = RMSNorm(self.head_dim, eps=config.layernorm_eps, elementwise_affine=False) - self.reset_parameters() - - def reset_parameters(self): - nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-2.5) - nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-2.5) - nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-2.5) - nn.init.xavier_uniform_(self.g_proj.weight, gain=2**-2.5) - nn.init.xavier_uniform_(self.out_proj.weight, gain=2**-1) - - def parallel_forward(self, qr, kr, v, mask): - bsz, tgt_len, embed_dim = v.size() - - vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) - - qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len - qk_mat = qk_mat * mask - # invariant after normalization - qk_mat = qk_mat / qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1, max=5e4) - output = torch.matmul(qk_mat, vr) - output = output.transpose(1, 2) - return output - - def recurrent_forward(self, qr, kr, v, decay, incremental_state): - bsz = v.size(0) - - v = v.view(bsz, self.num_heads, self.head_dim, 1) - kv = kr * v - if "prev_key_value" in incremental_state: - prev_kv = incremental_state["prev_key_value"] - prev_scale = incremental_state["scale"] - scale = prev_scale * decay + 1 - kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view( - self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1) - # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv - else: - scale = torch.ones_like(decay) - - incremental_state["prev_key_value"] = kv - incremental_state["scale"] = scale - - output = torch.sum(qr * kv, dim=3) - return output - - def chunk_recurrent_forward(self, qr, kr, v, inner_mask): - mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask - bsz, tgt_len, embed_dim = v.size() - chunk_len = mask.size(1) - num_chunks = tgt_len // chunk_len - - assert tgt_len % chunk_len == 0 - - qr = qr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2) - kr = kr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2) - v = v.view(bsz, num_chunks, chunk_len, self.num_heads, self.head_dim).transpose(2, 3) - - kr_t = kr.transpose(-1, -2) - - qk_mat = qr @ kr_t # bsz * num_heads * chunk_len * chunk_len - qk_mat = qk_mat * mask - inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1) - qk_mat = qk_mat / inner_scale - inner_output = torch.matmul(qk_mat, - v) # bsz * num_heads * num_value_heads * chunk_len * head_dim - - # reduce kv in one chunk - kv = kr_t @ (v * value_inner_decay) - - kv_recurrent = [] - cross_scale = [] - kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v) - kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v) - - # accumulate kv by loop - for i in range(num_chunks): - kv_recurrent.append(kv_state / kv_scale) - 
cross_scale.append(kv_scale) - kv_state = kv_state * cross_decay + kv[:, i] - kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max( - dim=-1, keepdim=True).values.clamp(min=1) - - kv_recurrent = torch.stack(kv_recurrent, dim=1) - cross_scale = torch.stack(cross_scale, dim=1) - - all_scale = torch.maximum(inner_scale, cross_scale) - align_inner_scale = all_scale / inner_scale - align_cross_scale = all_scale / cross_scale - - cross_output = (qr * query_inner_decay) @ kv_recurrent - output = inner_output / align_inner_scale + cross_output / align_cross_scale - # output = inner_output / cross_scale + cross_output / inner_scale - - output = output.transpose(2, 3) - return output - - def forward(self, x, rel_pos, chunkwise_recurrent=False, incremental_state=None): - bsz, tgt_len, _ = x.size() - (sin, cos), inner_mask = rel_pos - - q = self.q_proj(x) - k = self.k_proj(x) - v = self.v_proj(x) - g = self.g_proj(x) - - k *= self.scaling - q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2) - k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2) - - qr = theta_shift(q, sin, cos) - kr = theta_shift(k, sin, cos) - - if incremental_state is not None: - output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state) - elif chunkwise_recurrent: - output = self.chunk_recurrent_forward(qr, kr, v, inner_mask) - else: - output = self.parallel_forward(qr, kr, v, inner_mask) - - output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads) - - output = self.gate_fn(g) * output - - output = self.out_proj(output) - - return output - - -class FeedForwardNetwork(nn.Module): - - def __init__( - self, - embed_dim, - ffn_dim, - activation_fn, - dropout, - activation_dropout, - layernorm_eps, - subln=False, - use_rms_norm=False, - ): - super().__init__() - self.embed_dim = embed_dim - self.activation_fn = get_activation_fn(activation=str(activation_fn)) - self.activation_dropout_module = torch.nn.Dropout(activation_dropout) - self.dropout_module = torch.nn.Dropout(dropout) - self.fc1 = nn.Linear(self.embed_dim, ffn_dim) - self.fc2 = nn.Linear(ffn_dim, self.embed_dim) - if subln: - if use_rms_norm: - self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps) - else: - self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps) - else: - self.ffn_layernorm = None - - def reset_parameters(self): - self.fc1.reset_parameters() - self.fc2.reset_parameters() - if self.ffn_layernorm is not None: - self.ffn_layernorm.reset_parameters() - - def forward(self, x): - x_shape = x.shape - x = x.reshape(-1, x.size(-1)) - x = self.fc1(x) - x = self.activation_fn(x.float()).type_as(x) - x = self.activation_dropout_module(x) - if self.ffn_layernorm is not None: - x = self.ffn_layernorm(x) - x = self.fc2(x) - x = x.view(x_shape) - x = self.dropout_module(x) - return x - - -class GLU(nn.Module): - - def __init__( - self, - embed_dim, - ffn_dim, - activation_fn, - dropout, - activation_dropout, - ): - super().__init__() - self.embed_dim = embed_dim - self.activation_fn = get_activation_fn(activation=str(activation_fn)) - self.activation_dropout_module = torch.nn.Dropout(activation_dropout) - self.dropout_module = torch.nn.Dropout(dropout) - self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False) - self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False) - self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False) - - def reset_parameters(self): - self.fc1.reset_parameters() - self.fc2.reset_parameters() - self.gate.reset_parameters() - - def 
forward(self, x): - x_shape = x.shape - x = x.reshape(-1, x.size(-1)) - g = self.gate(x) - x = self.fc1(x) - x = self.activation_fn(x.float()).type_as(x) * g - x = self.activation_dropout_module(x) - x = self.fc2(x) - x = x.view(x_shape) - x = self.dropout_module(x) - return x - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - def extra_repr(self): - return "p={}".format(self.drop_prob) - - -class RetNetDecoderLayer(nn.Module): - - def __init__( - self, - config, - depth, - ): - super().__init__() - self.config = config - self.embed_dim = config.decoder_embed_dim - self.dropout_module = torch.nn.Dropout(config.dropout) - - if config.drop_path_rate > 0: - drop_path_prob = np.linspace(0, config.drop_path_rate, config.decoder_layers)[depth] - self.drop_path = DropPath(drop_path_prob) - else: - self.drop_path = None - - self.retention = MultiScaleRetention( - config, - self.embed_dim, - config.decoder_value_embed_dim, - config.decoder_retention_heads, - ) - - self.normalize_before = config.decoder_normalize_before - - self.retention_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps) - - self.ffn_dim = config.decoder_ffn_embed_dim - - self.ffn = self.build_ffn() - - self.final_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps) - - if config.deepnorm: - self.alpha = math.pow(2.0 * config.decoder_layers, 0.25) - else: - self.alpha = 1.0 - - def build_ffn(self): - if self.config.use_glu: - return GLU( - self.embed_dim, - self.ffn_dim, - self.config.activation_fn, - self.config.dropout, - self.config.activation_dropout, - ) - else: - return FeedForwardNetwork( - self.embed_dim, - self.ffn_dim, - self.config.activation_fn, - self.config.dropout, - self.config.activation_dropout, - self.config.layernorm_eps, - self.config.subln, - self.config.use_ffn_rms_norm, - ) - - def residual_connection(self, x, residual): - return residual * self.alpha + x - - def forward( - self, - x, - incremental_state=None, - chunkwise_recurrent=False, - retention_rel_pos=None, - ): - residual = x - if self.normalize_before: - x = self.retention_layer_norm(x) - - x = self.retention( - x, - incremental_state=incremental_state, - rel_pos=retention_rel_pos, - chunkwise_recurrent=chunkwise_recurrent, - ) - x = self.dropout_module(x) - - if self.drop_path is not None: - x = self.drop_path(x) - - x = self.residual_connection(x, residual) - if not self.normalize_before: - x = self.retention_layer_norm(x) - - residual = x - if self.normalize_before: - x = self.final_layer_norm(x) - - x = self.ffn(x) - - if self.drop_path is not None: - x = self.drop_path(x) - - x = self.residual_connection(x, residual) - if not self.normalize_before: - x = self.final_layer_norm(x) - - return x - - -class RetNetModel(nn.Module): - - def __init__(self, config, embed_tokens=None, output_projection=None, **kwargs): - super().__init__(**kwargs) - self.config = config - - self.dropout_module = torch.nn.Dropout(config.dropout) - - embed_dim = config.decoder_embed_dim - self.embed_dim = embed_dim - self.embed_scale = 1.0 if config.no_scale_embedding else math.sqrt(embed_dim) - - self.embed_tokens = embed_tokens - - if (output_projection is None and not config.no_output_layer and config.vocab_size > 0): - self.output_projection = self.build_output_projection(config) - else: 
- self.output_projection = output_projection - - if config.layernorm_embedding: - self.layernorm_embedding = RMSNorm(embed_dim, eps=config.layernorm_eps) - else: - self.layernorm_embedding = None - - self.layers = nn.ModuleList([]) - - for i in range(config.decoder_layers): - self.layers.append(self.build_decoder_layer( - config, - depth=i, - )) - - self.num_layers = len(self.layers) - - if config.decoder_normalize_before: - self.layer_norm = RMSNorm(embed_dim, eps=config.layernorm_eps) - else: - self.layer_norm = None - - self.retnet_rel_pos = RetNetRelPos(config) - self.chunkwise_recurrent = config.chunkwise_recurrent - self.recurrent_chunk_size = config.recurrent_chunk_size - - if config.deepnorm: - init_scale = math.pow(8.0 * config.decoder_layers, 0.25) - for name, p in self.named_parameters(): - if ("fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name): - p.data.div_(init_scale) - - if config.subln and not config.use_glu: - init_scale = math.sqrt(math.log(config.decoder_layers * 2)) - for name, p in self.named_parameters(): - if ("fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name): - p.data.mul_(init_scale) - - def build_output_projection( - self, - config, - ): - if config.share_decoder_input_output_embed: - output_projection = torch.nn.Linear( - self.embed_tokens.weight.shape[1], - self.embed_tokens.weight.shape[0], - bias=False, - ) - output_projection.weight = self.embed_tokens.weight - else: - output_projection = torch.nn.Linear(config.decoder_embed_dim, - config.vocab_size, - bias=False) - torch.nn.init.normal_(output_projection.weight, - mean=0, - std=config.decoder_embed_dim**-0.5) - return output_projection - - def build_decoder_layer(self, config, depth): - layer = RetNetDecoderLayer( - config, - depth, - ) - # if config.checkpoint_activations: - # layer = checkpoint_wrapper(layer) - # if config.fsdp: - # layer = wrap(layer) - return layer - - def forward_embedding( - self, - tokens, - token_embedding=None, - incremental_state=None, - ): - if incremental_state is not None and not self.is_first_step(incremental_state): - tokens = tokens[:, -1:] - - if token_embedding is None: - token_embedding = self.embed_tokens(tokens) - - x = embed = self.embed_scale * token_embedding - - if self.layernorm_embedding is not None: - x = self.layernorm_embedding(x) - - x = self.dropout_module(x) - - return x, embed - - def is_first_step(self, incremental_state): - if incremental_state is None: - return False - return incremental_state.get("is_first_step", False) - - def forward(self, - prev_output_tokens, - incremental_state=None, - features_only=False, - token_embeddings=None): - # embed tokens - x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state) - is_first_step = self.is_first_step(incremental_state) - - if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0: - padding_len = self.recurrent_chunk_size - prev_output_tokens.size( - 1) % self.recurrent_chunk_size - slen = prev_output_tokens.size(1) + padding_len - x = F.pad(x, (0, 0, 0, padding_len)) - else: - slen = prev_output_tokens.size(1) - # relative position - retention_rel_pos = self.retnet_rel_pos(slen, - incremental_state is not None and not is_first_step, - chunkwise_recurrent=self.chunkwise_recurrent) - - # decoder layers - inner_states = [x] - - for idx, layer in enumerate(self.layers): - if incremental_state is None or is_first_step: - if is_first_step and incremental_state is not None: - if idx not in 
incremental_state: - incremental_state[idx] = {} - else: - if idx not in incremental_state: - incremental_state[idx] = {} - - x = layer( - x, - incremental_state[idx] if incremental_state is not None else None, - retention_rel_pos=retention_rel_pos, - chunkwise_recurrent=self.chunkwise_recurrent, - ) - inner_states.append(x) - - if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0: - x = x[:, :prev_output_tokens.size(1), :] - - if self.layer_norm is not None: - x = self.layer_norm(x) - - if not features_only: - x = self.output_layer(x) - - return x, { - "inner_states": inner_states, - "attn": None, - } - - def output_layer(self, features): - return self.output_projection(features) - - -class RetNetForCausalLM(nn.Module): - - def __init__(self, config, embed_tokens=None, output_projection=None, **kwargs): - super().__init__(**kwargs) - assert config.vocab_size > 0, "you must specify vocab size" - if output_projection is None: - config.no_output_layer = False - if embed_tokens is None: - embed_tokens = nn.Embedding(config.vocab_size, config.decoder_embed_dim, - config.pad_token_id) - - self.config = config - self.model = RetNetModel(config, - embed_tokens=embed_tokens, - output_projection=output_projection, - **kwargs) - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.model.output_projection - - def set_output_embeddings(self, new_embeddings): - self.model.output_projection = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - retention_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_retentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - recurrent_chunk_size: Optional[int] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - outputs = self.model( - input_ids, - incremental_state=past_key_values, - features_only=False, - token_embeddings=inputs_embeds, - ) - - logits, inner_hidden_states = outputs[0], outputs[1]['inner_states'] - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if self.config.z_loss_coeff > 0: - # z_loss from PaLM paper - # z_loss = 1e-4 * log(log(z)), where z = sum(exp(logits)) - z_loss = torch.logsumexp(shift_logits, dim=-1).log().mean() - loss += self.config.z_loss_coeff * z_loss - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=past_key_values, - hidden_states=inner_hidden_states, - attentions=None, - ) \ No newline at end of file diff --git a/vall_e/models/arch/__init__.py 
b/vall_e/models/arch/__init__.py index 65a2a94..45b2212 100755 --- a/vall_e/models/arch/__init__.py +++ b/vall_e/models/arch/__init__.py @@ -15,6 +15,7 @@ except Exception as e: ERROR_ARCHES["retnet"] = e pass +""" try: from .retnet_syncdoth.retnet_ts import RetNetDecoder as RetNetDecoder_TS, RetNetConfig as RetNetConfig_TS AVAILABLE_ARCHES.append("retnet-ts") @@ -28,6 +29,7 @@ try: except Exception as e: ERROR_ARCHES["retnet-hf"] = e pass +""" try: from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaDecoderLayer, LlamaDecoderLayer_Adapted, LlamaForCausalLM @@ -50,6 +52,15 @@ try: except Exception as e: ERROR_ARCHES["mixtral"] = e +try: + from .mamba import MambaModel, Mamba2Model, MambaConfig, Mamba2Config + AVAILABLE_ARCHES.append("mamba") + AVAILABLE_ARCHES.append("mamba2") +except Exception as e: + ERROR_ARCHES["mamba"] = e + ERROR_ARCHES["mamba2"] = e + +""" try: from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig AVAILABLE_ARCHES.append("mamba") @@ -62,4 +73,5 @@ try: from .mamba_vasqu import Mamba2Model_HF, Mamba2Config_HF AVAILABLE_ARCHES.append("mamba2-hf") except Exception as e: - ERROR_ARCHES["mamba2-hf"] = e \ No newline at end of file + ERROR_ARCHES["mamba2-hf"] = e +""" \ No newline at end of file diff --git a/vall_e/models/arch/mamba.py b/vall_e/models/arch/mamba.py index e9ae498..78e20ce 100644 --- a/vall_e/models/arch/mamba.py +++ b/vall_e/models/arch/mamba.py @@ -1,3 +1,11 @@ + +from transformers.models.mamba.modeling_mamba import MambaModel +from transformers.models.mamba2.modeling_mamba2 import Mamba2Model + +from transformers.models.mamba.configuration_mamba import MambaConfig +from transformers.models.mamba2.configuration_mamba2 import Mamba2Config + +""" # https://github.com/state-spaces/mamba from torch.utils.checkpoint import checkpoint @@ -29,4 +37,5 @@ def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_ ) return hidden_states -MambaMixelModel.forward = MambaMixelModel_forward \ No newline at end of file +MambaMixelModel.forward = MambaMixelModel_forward +""" \ No newline at end of file diff --git a/vall_e/models/arch/mamba_vasqu/__init__.py b/vall_e/models/arch/mamba_vasqu/__init__.py deleted file mode 100644 index 0c20b1b..0000000 --- a/vall_e/models/arch/mamba_vasqu/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .mamba2_hf import * \ No newline at end of file diff --git a/vall_e/models/arch/mamba_vasqu/mamba2_hf.py b/vall_e/models/arch/mamba_vasqu/mamba2_hf.py deleted file mode 100644 index f285963..0000000 --- a/vall_e/models/arch/mamba_vasqu/mamba2_hf.py +++ /dev/null @@ -1,4 +0,0 @@ -# https://github.com/vasqu/mamba2-torch -# NOTE: edit `src/mamba2_torch/__init__.py` to remove reference to .src. 
because of how pip treats packages - -from mamba2_torch import Mamba2Model as Mamba2Model_HF, Mamba2Config as Mamba2Config_HF \ No newline at end of file diff --git a/vall_e/models/arch/retnet_syncdoth/__init__.py b/vall_e/models/arch/retnet_syncdoth/__init__.py deleted file mode 100755 index e69de29..0000000 diff --git a/vall_e/models/arch/retnet_syncdoth/retnet_hf.py b/vall_e/models/arch/retnet_syncdoth/retnet_hf.py deleted file mode 100644 index 93fc532..0000000 --- a/vall_e/models/arch/retnet_syncdoth/retnet_hf.py +++ /dev/null @@ -1,196 +0,0 @@ -# https://github.com/syncdoth/RetNet/ -from ....ext.retnet_hf.configuration_retnet import RetNetConfig -from ....ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder, RetNetForCausalLM - -# things we're overriding or required to override -from ....ext.retnet_hf.modeling_retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, split_heads, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos - -import torch -import math -from typing import Dict, List, Optional, Tuple, Union - -# required to have compatibile LayerNorm -def FeedForwardNetwork_init( - self, - embed_dim, - ffn_dim, - activation_fn, - dropout, - activation_dropout, - layernorm_eps, - subln=True, - use_rms_norm=False, -): - super(FeedForwardNetwork, self).__init__() - self.embed_dim = embed_dim - self.activation_fn = get_activation_fn(activation=str(activation_fn)) - self.activation_dropout_module = torch.nn.Dropout(activation_dropout) - self.dropout_module = torch.nn.Dropout(dropout) - self.fc1 = torch.nn.Linear(self.embed_dim, ffn_dim) - self.fc2 = torch.nn.Linear(ffn_dim, self.embed_dim) - self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None - -FeedForwardNetwork.__init__ = FeedForwardNetwork_init - -def RetNetModel_init( - self, - config: RetNetConfig, - embed_tokens: torch.nn.Embedding = None, - tensor_parallel: bool = False, -): - super(RetNetDecoder, self).__init__(config) - self.config = config - - self.dropout_module = torch.nn.Dropout(config.dropout) - - self.embed_dim = config.decoder_embed_dim - self.embed_scale = ( - 1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim) - ) - - if embed_tokens is None and config.vocab_size: - embed_tokens = torch.nn.Embedding( - config.vocab_size, config.decoder_embed_dim, config.pad_token_id - ) - self.embed_tokens = embed_tokens - - if config.layernorm_embedding: - self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm - else: - self.layernorm_embedding = None - - self.layers = torch.nn.ModuleList([]) - - for i in range(config.decoder_layers): - self.layers.append( - RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel) - ) - - self.decoder_layers = len(self.layers) - - if config.decoder_normalize_before: - self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm - else: - self.layer_norm = None - - self.retnet_rel_pos = RetNetRelPos(config) - self.recurrent_chunk_size = config.recurrent_chunk_size - - if config.deepnorm: - init_scale = math.pow(8.0 * config.decoder_layers, 0.25) - for name, p in self.named_parameters(): - if ( - "fc1" in name - or "fc2" in name - or "out_proj" in name - or "v_proj" in name - ): - p.data.div_(init_scale) - - if config.subln and not config.use_glu: - init_scale = math.sqrt(math.log(config.decoder_layers * 2)) - for name, p in self.named_parameters(): - if ( - "fc1" in name - or "fc2" in name - or "out_proj" in name - or "v_proj" in name - ): - 
p.data.mul_(init_scale) - - self.gradient_checkpointing = True - self.post_init() - -RetNetDecoder.__init__ = RetNetModel_init - -# restores bias in our FFNs -def RetNetDecoderLayer_init(self, config: RetNetConfig, depth: int, tensor_parallel: bool = False): - super(RetNetDecoderLayer, self).__init__() - self.config = config - self.embed_dim = config.decoder_embed_dim - self.dropout_module = torch.nn.Dropout(config.dropout) - - if config.drop_path_rate > 0: - drop_path_prob = np.linspace( - 0, config.drop_path_rate, config.decoder_layers - )[depth] - self.drop_path = DropPath(drop_path_prob) - else: - self.drop_path = None - - self.retention = MultiScaleRetention( - config, use_bias=True, tensor_parallel=tensor_parallel - ) - - self.normalize_before = config.decoder_normalize_before - - self.retention_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm - - self.ffn_dim = config.decoder_ffn_embed_dim - - self.ffn = self.build_ffn() - - self.final_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm - - if config.deepnorm: - self.alpha = math.pow(2.0 * config.decoder_layers, 0.25) - else: - self.alpha = 1.0 - -RetNetDecoderLayer.__init__ = RetNetDecoderLayer_init -# fixes backwards when using te's autocast -def MultiScaleRetention_forward( - self, - hidden_states: torch.Tensor, - rel_pos: Tuple[Tuple[torch.Tensor]], - retention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - forward_impl: str = "parallel", - output_retentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]: - B, T, H = hidden_states.size() - (sin, cos), decay_mask = rel_pos - # projections - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) * self.scaling # for scaled dot product - v = self.v_proj(hidden_states) - g = self.g_proj(hidden_states) - # multi-head - q, k, v = split_heads((q, k, v), B, T, self.num_heads) - - # rotate - # NOTE: theta_shift has bug with mps device. 
- qr = theta_shift(q, sin, cos) - kr = theta_shift(k, sin, cos) - - # retention - if forward_impl == "parallel": - retention_out, curr_kv, retention_weights = self.parallel_retention( - qr, kr, v, decay_mask - ) - elif forward_impl == "recurrent": - retention_out, curr_kv = self.recurrent_retention( - qr, - kr, - v, - decay_mask, - past_key_value=past_key_value, - retention_mask=retention_mask, - ) - elif forward_impl == "chunkwise": - retention_out, curr_kv = self.chunkwise_retention(qr, kr, v, decay_mask) - else: - raise ValueError(f"forward_impl {forward_impl} not supported.") - - # concaat heads - normed = self.group_norm(retention_out).reshape(B, T, self.value_dim) - # out gate & proj - out = self.gate_fn(g) * normed - out = self.out_proj(out) - - outputs = (out, curr_kv) - if output_retentions: - outputs += (retention_weights,) if forward_impl == "parallel" else (None,) - return outputs - -MultiScaleRetention.forward = MultiScaleRetention_forward \ No newline at end of file diff --git a/vall_e/models/arch/retnet_syncdoth/retnet_ts.py b/vall_e/models/arch/retnet_syncdoth/retnet_ts.py deleted file mode 100644 index 093f7fe..0000000 --- a/vall_e/models/arch/retnet_syncdoth/retnet_ts.py +++ /dev/null @@ -1,277 +0,0 @@ -# https://github.com/syncdoth/RetNet/ -from ....ext.retnet_ts.config import RetNetConfig -from ....ext.retnet_ts.retnet import RetNetModel as RetNetDecoder - -# things we're overriding or required to override -from ....ext.retnet_ts.retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos - -import torch -import math -from typing import Dict, List, Optional, Tuple, Union - -from torch.utils.checkpoint import checkpoint - -# required to have compatibile LayerNorm -def FeedForwardNetwork_init( - self, - embed_dim, - ffn_dim, - activation_fn, - dropout, - activation_dropout, - layernorm_eps, - subln=True, - use_rms_norm=False, -): - super(FeedForwardNetwork, self).__init__() - self.embed_dim = embed_dim - self.activation_fn = get_activation_fn(activation=str(activation_fn)) - self.activation_dropout_module = torch.nn.Dropout(activation_dropout) - self.dropout_module = torch.nn.Dropout(dropout) - self.fc1 = torch.nn.Linear(self.embed_dim, ffn_dim) - self.fc2 = torch.nn.Linear(ffn_dim, self.embed_dim) - self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None - -FeedForwardNetwork.__init__ = FeedForwardNetwork_init - -# removes embed_tokens -def RetNetModel_init( - self, config, embed_tokens=None, output_projection=None, **kwargs -): - super(RetNetDecoder, self).__init__(**kwargs) - self.config = config - - self.dropout_module = torch.nn.Dropout(config.dropout) - - self.embed_dim = config.decoder_embed_dim - self.embed_scale = ( - 1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim) - ) - - if embed_tokens is None and config.vocab_size: - embed_tokens = torch.nn.Embedding( - config.vocab_size, config.decoder_embed_dim, config.pad_token_id - ) - self.embed_tokens = embed_tokens - - if (output_projection is None and not config.no_output_layer and config.vocab_size > 0): - self.output_projection = self.build_output_projection(config) - else: - self.output_projection = output_projection - - if config.layernorm_embedding: - self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm - else: - self.layernorm_embedding = None - - self.layers = torch.nn.ModuleList([]) - - for i in range(config.decoder_layers): - layer = 
self.build_decoder_layer( - config, - depth=i, - ) - """ - if config.checkpoint_activations: - layer = checkpoint_wrapper(layer) - """ - self.layers.append(layer) - - self.num_layers = len(self.layers) - - if config.decoder_normalize_before: - self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm - else: - self.layer_norm = None - - self.retnet_rel_pos = RetNetRelPos(config) - self.chunkwise_recurrent = config.chunkwise_recurrent - self.recurrent_chunk_size = config.recurrent_chunk_size - - if config.deepnorm: - init_scale = math.pow(8.0 * config.decoder_layers, 0.25) - for name, p in self.named_parameters(): - if ( - "fc1" in name - or "fc2" in name - or "out_proj" in name - or "v_proj" in name - ): - p.data.div_(init_scale) - - if config.subln and not config.use_glu: - init_scale = math.sqrt(math.log(config.decoder_layers * 2)) - for name, p in self.named_parameters(): - if ( - "fc1" in name - or "fc2" in name - or "out_proj" in name - or "v_proj" in name - ): - p.data.mul_(init_scale) - - self.gradient_checkpointing = True - -RetNetDecoder.__init__ = RetNetModel_init - -# restores bias in our FFNs -def RetNetDecoderLayer_init( - self, - config, - depth, - use_bias=True -): - super(RetNetDecoderLayer, self).__init__() - self.config = config - self.embed_dim = config.decoder_embed_dim - self.dropout_module = torch.nn.Dropout(config.dropout) - - if config.drop_path_rate > 0: - drop_path_prob = np.linspace( - 0, config.drop_path_rate, config.decoder_layers - )[depth] - self.drop_path = DropPath(drop_path_prob) - else: - self.drop_path = None - - self.retention = MultiScaleRetention( - config, - self.embed_dim, - config.decoder_value_embed_dim, - config.decoder_retention_heads, - use_bias=use_bias - ) - - self.normalize_before = config.decoder_normalize_before - - self.retention_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm - - self.ffn_dim = config.decoder_ffn_embed_dim - - self.ffn = self.build_ffn() - - self.final_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm - - if config.deepnorm: - self.alpha = math.pow(2.0 * config.decoder_layers, 0.25) - else: - self.alpha = 1.0 - -def RetNetDecoderLayer_forward( - self, - x, - incremental_state=None, - chunkwise_recurrent=False, - retention_rel_pos=None, -): - residual = x - if self.normalize_before: - x = self.retention_layer_norm(x) - - if x.requires_grad and self.config.checkpoint_activations: - x = checkpoint( - self.retention, - x, - use_reentrant=False, - incremental_state=incremental_state, - rel_pos=retention_rel_pos, - chunkwise_recurrent=chunkwise_recurrent, - ) - else: - x = self.retention( - x, - incremental_state=incremental_state, - rel_pos=retention_rel_pos, - chunkwise_recurrent=chunkwise_recurrent, - ) - x = self.dropout_module(x) - - if self.drop_path is not None: - x = self.drop_path(x) - - x = self.residual_connection(x, residual) - if not self.normalize_before: - x = self.retention_layer_norm(x) - - residual = x - if self.normalize_before: - x = self.final_layer_norm(x) - - x = self.ffn(x) - - if self.drop_path is not None: - x = self.drop_path(x) - - x = self.residual_connection(x, residual) - if not self.normalize_before: - x = self.final_layer_norm(x) - - return x - -RetNetDecoderLayer.__init__ = RetNetDecoderLayer_init -RetNetDecoderLayer.forward = RetNetDecoderLayer_forward -# fixes backwards when using te's autocast -def MultiScaleRetention_init( - self, - config, - embed_dim, - value_dim, - num_heads, - gate_fn="swish", - use_bias=True, 
-): - super(MultiScaleRetention, self).__init__() - self.config = config - self.embed_dim = embed_dim - self.value_dim = value_dim - self.num_heads = num_heads - self.head_dim = self.value_dim // num_heads - self.key_dim = self.embed_dim // num_heads - self.scaling = self.key_dim**-0.5 - - self.gate_fn = get_activation_fn(activation=str(gate_fn)) - - self.q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=use_bias) - self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=use_bias) - self.v_proj = torch.nn.Linear(embed_dim, value_dim, bias=use_bias) - self.g_proj = torch.nn.Linear(embed_dim, value_dim, bias=use_bias) - - self.out_proj = torch.nn.Linear(value_dim, embed_dim, bias=use_bias) - - self.group_norm = RMSNorm(self.head_dim, eps=config.layernorm_eps, elementwise_affine=False) - self.reset_parameters() - -def MultiScaleRetention_forward( - self, x, rel_pos, chunkwise_recurrent=False, incremental_state=None -) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]: - bsz, tgt_len, _ = x.size() - (sin, cos), inner_mask = rel_pos - - q = self.q_proj(x) - k = self.k_proj(x) * self.scaling - v = self.v_proj(x) - g = self.g_proj(x) - - q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2) - k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2) - - qr = theta_shift(q, sin, cos) - kr = theta_shift(k, sin, cos) - - if incremental_state is not None: - output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state) - elif chunkwise_recurrent: - output = self.chunk_recurrent_forward(qr, kr, v, inner_mask) - else: - output = self.parallel_forward(qr, kr, v, inner_mask) - - output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads) - - output = self.gate_fn(g) * output - - output = self.out_proj(output) - - return output - -MultiScaleRetention.__init__ = MultiScaleRetention_init -MultiScaleRetention.forward = MultiScaleRetention_forward \ No newline at end of file diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 30e475d..d2fefbf 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -704,87 +704,26 @@ class Base(nn.Module): )) self.model = RetNetDecoder(RetNetConfig(**kwargs)) - elif self.arch_type == "retnet-hf": - kwargs = dict( + elif self.arch_type in ["mamba2"]: + self.model = Mamba2Model(Mamba2Config( vocab_size=n_resp_tokens, - decoder_embed_dim=d_model, - decoder_value_embed_dim =d_model * 2, - decoder_retention_heads=n_heads, - decoder_ffn_embed_dim=d_model * 4, - decoder_layers=n_layers, - dropout=p_dropout if training else 0.0, - checkpoint_activations=self.gradient_checkpointing, - activation_fn="gelu", - use_glu=False, # self.version >= 3, - - recurrent_chunk_size=self.causal_size if self.causal else 0, - decoder_normalize_before=True, - - deepnorm=False, - subln=True, - ) - - self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs)) - - if self.gradient_checkpointing and not self.model.gradient_checkpointing: - self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( - use_reentrant=False - )) - elif self.arch_type == "bitnet": - self.model = BitNetTransformer( - num_tokens=n_resp_tokens, - dim=d_model, - depth=n_layers, - heads=n_heads, - ff_mult=4, - gradient_checkpointing=self.gradient_checkpointing, - ) - elif self.arch_type in ["mamba","mamba2"]: - self.model = MambaMixelModel( - vocab_size=n_resp_tokens, - d_model=d_model, - n_layer=n_layers*2, - d_intermediate=0, #d_model*2, - ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True} if 
self.arch_type == "mamba2" else {}, - rms_norm=True, - fused_add_norm=True, + hidden_size=d_model, + expand=2, + num_hidden_layers=n_layers*2, residual_in_fp32=True, - #attn_layer_idx=attn_layer_idx, - #attn_cfg=attn_cfg, - #initializer_cfg=initializer_cfg, - ) - self.model.gradient_checkpointing = self.gradient_checkpointing - elif self.arch_type in ["mamba2-hf"]: - self.model = Mamba2Model_HF(Mamba2Config_HF( - vocab_size=n_resp_tokens, - hidden_size=d_model, - max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds - expand=4, - num_hidden_layers=n_layers, - is_encoder_decoder=False, - is_decoder=True, - use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it) - residual_in_fp32=True, # breaks for AMP inference )) if self.gradient_checkpointing and not self.model.gradient_checkpointing: self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( use_reentrant=False )) - elif self.arch_type == "mmfreelm": - self.model = HGRNBitModel(HGRNBitConfig( + elif self.arch_type in ["mamba"]: + self.model = MambaModel(MambaConfig( vocab_size=n_resp_tokens, hidden_size=d_model, - max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds - intermediate_size=d_model*4, - num_hidden_layers=n_layers, - num_heads=n_heads, - #hidden_act="gelu", - #is_encoder_decoder=False, - #is_decoder=True, - attn_mode=hf_attention, - #gradient_checkpointing=self.gradient_checkpointing, + expand=2, + num_hidden_layers=n_layers*2, + residual_in_fp32=True, )) - if self.gradient_checkpointing and not self.model.gradient_checkpointing: self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( use_reentrant=False @@ -795,7 +734,6 @@ class Base(nn.Module): if hasattr( self.model, "embeddings" ): del self.model.embeddings - if not split_classifiers: self.classifier = nn.Linear(d_model, n_resp_tokens) self.classifiers = None @@ -854,8 +792,8 @@ class Base(nn.Module): # HF transformer derived model if self.arch_type in ["llama", "mistral", "mixtral"]: kwargs = dict( - #attention_mask=m, inputs_embeds=x, + attention_mask=m, past_key_values=state, position_ids=position_ids, use_cache=False, # not self.training, @@ -901,46 +839,31 @@ class Base(nn.Module): x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True) if _ is not None and "l_aux" in _ and self.n_experts > 1: aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001 - elif self.arch_type == "retnet-hf": - first = state is None or len(state) == 0 - - kwargs = dict( - attention_mask=m, - inputs_embeds=x if first else x[:, -1, :].unsqueeze(1), - past_key_values=None if first else state, - use_cache=True, - forward_impl='parallel' if first else 'recurrent', - return_dict=True, - ) - - out = self.model(**kwargs) - x = out.last_hidden_state - if state is not None: - state = out.past_key_values elif self.arch_type in ["mamba","mamba2"]: - x = self.model( hidden_states=x ) - elif self.arch_type == "mamba2-hf": - first = state is None or len(state) == 0 - kwargs = dict( + #attention_mask=m, inputs_embeds=x, - cache_params=state, + #cache_params=state, + use_cache=False, # not self.training, + #position_ids=position_ids, + #output_attentions=output_attentions, + output_hidden_states=output_hidden_states, return_dict=True, ) - out = self.model(**kwargs) - x = out.last_hidden_state + output = self.model(**kwargs) + x = output["last_hidden_state"] + + # to-do: figure out why KV caching doesn't work + #if not self.training: if state is not None: - 
state = out.cache_params - elif self.arch_type == "bitnet": - x = self.model(x) - elif self.arch_type == "mmfreelm": - x = self.model( - attention_mask=m, - inputs_embeds=x, - ) + state = output["cache_params"] - x = x[0] + if output_attentions: + attentions = output["attentions"] + + if output_hidden_states: + hidden_states = output["hidden_states"] # process it into a format that I like if output_hidden_states: @@ -1559,7 +1482,6 @@ class Base(nn.Module): x_list = self.inputs_to_embeddings( inputs, quant_levels ) x, mask = list_to_tensor(x_list) - m = mask.unsqueeze(dim=-1) training = self.training device = x.device @@ -1584,8 +1506,10 @@ class Base(nn.Module): # pad mask shape[2] = 1 - padding = torch.zeros(shape, dtype=x.dtype, device=x.device) + padding = torch.zeros(shape[:2], dtype=x.dtype, device=x.device) mask = torch.cat([mask, padding], dim=1) + + m = mask.unsqueeze(dim=-1) # needs to be done here as we still have our raw inputs position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None diff --git a/vall_e/webui.py b/vall_e/webui.py index 6fb666b..a9b45c3 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -99,8 +99,7 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data continue configs.append( sft ) - if is_windows: - configs = [ str(p) for p in configs ] + configs = [ str(p) for p in configs ] return configs
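For reference, a minimal standalone sketch of the `mamba` code path that base.py switches to above: `MambaModel`/`MambaConfig` come straight from `transformers` (replacing the removed `mamba_ssm` and `mamba2-torch` wrappers), the model consumes pre-computed embeddings via `inputs_embeds`, and caching stays disabled. The hyperparameter values below are placeholders; the real ones are derived inside `Base.__init__`.

# Standalone sketch of the new `mamba` arch path (placeholder sizes, not the real config plumbing).
import torch
from transformers.models.mamba.configuration_mamba import MambaConfig
from transformers.models.mamba.modeling_mamba import MambaModel

n_resp_tokens, d_model, n_layers = 1025, 1024, 12   # placeholders

model = MambaModel(MambaConfig(
    vocab_size=n_resp_tokens,
    hidden_size=d_model,
    expand=2,
    num_hidden_layers=n_layers * 2,   # the patch doubles the layer count
    residual_in_fp32=True,
))

# embeddings are computed by the caller, so the model's own embedding table goes unused
# (base.py even deletes `self.model.embeddings` after construction)
x = torch.randn(1, 75, d_model)       # (batch, seq_len, d_model)

output = model(
    inputs_embeds=x,
    use_cache=False,                  # caching is still a to-do per the patch comment
    output_hidden_states=True,
    return_dict=True,
)
x = output["last_hidden_state"]       # (1, 75, d_model), later fed to the classifier
hidden_states = output["hidden_states"]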
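The mask change in `Base.forward` above appears to hinge on the mask from `list_to_tensor` being 2-D (batch, seq): the zeros appended when padding the sequence must then be 2-D as well (`shape[:2]`), and the broadcastable `(batch, seq, 1)` view `m` has to be taken only after padding so it covers the new position. A toy reproduction with made-up shapes, assuming that 2-D mask layout:

import torch

batch, seq, d_model, n_pad = 2, 5, 8, 1

x = torch.randn(batch, seq, d_model)
mask = torch.ones(batch, seq, dtype=x.dtype)          # 2-D mask, one entry per token

# pad the embeddings and the mask by the same number of positions
x = torch.cat([x, torch.zeros(batch, n_pad, d_model, dtype=x.dtype)], dim=1)
padding = torch.zeros(batch, n_pad, dtype=x.dtype)    # 2-D, same rank as the mask
mask = torch.cat([mask, padding], dim=1)              # (batch, seq + n_pad)

# only now take the broadcastable view used to zero out padded embeddings
m = mask.unsqueeze(dim=-1)                            # (batch, seq + n_pad, 1)
assert (x * m).shape == x.shape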