diff --git a/vall_e/config.py b/vall_e/config.py index aa8d075..d41930f 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -179,6 +179,7 @@ class Model: interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results) p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training + attention: str = "eager" # or flash_attention_2 def get(self, name=None): return [ self ] if not name or self.name == name else [] @@ -528,6 +529,7 @@ class Config(_Config): dataset: Dataset = field(default_factory=lambda: Dataset) model: Model = field(default_factory=lambda: Model) + models: dict | list | None = None # deprecated hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters) evaluation: Evaluation = field(default_factory=lambda: Evaluation) trainer: Trainer = field(default_factory=lambda: Trainer) @@ -576,7 +578,10 @@ class Config(_Config): def format( self ): self.dataset = Dataset(**self.dataset) - self.model = Model(**self.model) + if self.models is not None: + self.model = Model(**next(iter(self.models))) + else: + self.model = Model(**self.model) self.hyperparameters = Hyperparameters(**self.hyperparameters) self.evaluation = Evaluation(**self.evaluation) self.trainer = Trainer(**self.trainer) diff --git a/vall_e/ext/__init__.py b/vall_e/ext/__init__.py index e69de29..03ca57d 100644 --- a/vall_e/ext/__init__.py +++ b/vall_e/ext/__init__.py @@ -0,0 +1,4 @@ +# from https://github.com/syncdoth/RetNet/ + +# there is no proper build system and I can't be assed to fork it or make it a submodule that plays nicely with python's import system +# this is included because torchscale's implementation recently changed and I don't want to keep maintaining a fork \ No newline at end of file diff --git a/vall_e/ext/retnet_hf/modeling_retnet.py b/vall_e/ext/retnet_hf/modeling_retnet.py index ade5bfb..4a7fcdc 100644 --- a/vall_e/ext/retnet_hf/modeling_retnet.py +++ b/vall_e/ext/retnet_hf/modeling_retnet.py @@ -9,7 +9,14 @@ import torch.utils.checkpoint from timm.models.layers import drop_path from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers import top_k_top_p_filtering +try: + from transformers import top_k_top_p_filtering +except Exception as e: + pass +try: + from transformers.generation.utils import top_k_top_p_filtering +except Exception as e: + pass from transformers.modeling_outputs import ModelOutput, SequenceClassifierOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging diff --git a/vall_e/ext/retnet_ts/__init__.py b/vall_e/ext/retnet_ts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vall_e/ext/retnet_ts/config.py b/vall_e/ext/retnet_ts/config.py new file mode 100644 index 0000000..2deea15 --- /dev/null +++ b/vall_e/ext/retnet_ts/config.py @@ -0,0 +1,74 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + + +class RetNetConfig(object): + + def __init__(self, **kwargs): + self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768) + self.decoder_value_embed_dim = kwargs.pop("decoder_value_embed_dim", 1280) + self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3) + self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1280) + self.decoder_layers = kwargs.pop("decoder_layers", 12) + self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True) + self.activation_fn = kwargs.pop("activation_fn", "gelu") + self.dropout = kwargs.pop("dropout", 0.0) + self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0) + self.activation_dropout = kwargs.pop("activation_dropout", 0.0) + self.no_scale_embedding = kwargs.pop("no_scale_embedding", True) + self.layernorm_embedding = kwargs.pop("layernorm_embedding", False) + self.moe_freq = kwargs.pop("moe_freq", 0) + self.moe_top1_expert = kwargs.pop("moe_top1_expert", False) + self.moe_expert_count = kwargs.pop("moe_expert_count", 0) + self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True) + self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25) + self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random") + self.moe_normalize_gate_prob_before_dropping = kwargs.pop( + "moe_normalize_gate_prob_before_dropping", False) + self.use_xmoe = kwargs.pop("use_xmoe", False) + self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0) + self.max_rel_pos = kwargs.pop("max_rel_pos", 0) + self.deepnorm = kwargs.pop("deepnorm", False) + self.subln = kwargs.pop("subln", True) + self.use_ffn_rms_norm = kwargs.pop("use_ffn_rms_norm", False) + self.use_glu = kwargs.pop("use_glu", True) + self.use_lm_decay = kwargs.pop("use_lm_decay", False) + self.z_loss_coeff = kwargs.pop("z_loss_coeff", 0.0) # TODO: 1e-4 + self.multiway = kwargs.pop("multiway", False) + self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", + False) + self.max_target_positions = kwargs.pop("max_target_positions", 1024) + self.no_output_layer = kwargs.pop("no_output_layer", True) + self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-6) + # Blockwise + self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False) + self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512) + # Text + self.vocab_size = kwargs.pop("vocab_size", -1) + # Fairscale + self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) + self.fsdp = kwargs.pop("fsdp", False) + self.ddp_rank = kwargs.pop("ddp_rank", 0) + self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) + self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) + # token id + self.pad_token_id = kwargs.pop("pad_token_id", 0) + self.postprocessing() + + def postprocessing(self): + if self.deepnorm: + self.decoder_normalize_before = False + self.subln = False + if self.subln: + self.decoder_normalize_before = True + self.deepnorm = False + if self.use_xmoe: + self.moe_normalize_gate_prob_before_dropping = True + self.moe_second_expert_policy = "random" + assert self.moe_freq > 0 and self.moe_expert_count > 0 + + def override(self, args): + for hp in self.__dict__.keys(): + if getattr(args, hp, None) is not None: + self.__dict__[hp] = getattr(args, hp, None) + self.postprocessing() \ No newline at end of file diff --git a/vall_e/ext/retnet_ts/retnet.py b/vall_e/ext/retnet_ts/retnet.py new file mode 100644 index 0000000..d4d194b --- /dev/null +++ b/vall_e/ext/retnet_ts/retnet.py @@ -0,0 +1,746 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +# from fairscale.nn import checkpoint_wrapper, wrap +from timm.models.layers import drop_path +from transformers.modeling_outputs import CausalLMOutputWithPast + +try: + from apex.normalization import FusedLayerNorm as LayerNorm +except ModuleNotFoundError: + from torch.nn import LayerNorm + + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ + + +def theta_shift(x, sin, cos): + return (x * cos) + (rotate_every_two(x) * sin) + + +def get_activation_fn(activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + elif activation == "swish": + return F.silu + else: + raise NotImplementedError + + +class RMSNorm(nn.Module): + + def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True): + super().__init__() + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_parameter('weight', None) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + if self.weight is not None: + output = output * self.weight + return output + + +class RetNetRelPos(nn.Module): + + def __init__(self, config): + super().__init__() + num_heads = config.decoder_retention_heads + + angle = 1.0 / (10000**torch.linspace(0, 1, config.decoder_embed_dim // num_heads // 2)) + angle = angle.unsqueeze(-1).repeat(1, 2).flatten() + if config.use_lm_decay: + # NOTE: alternative way described in the paper + s = torch.log(torch.tensor(1 / 32)) + e = torch.log(torch.tensor(1 / 512)) + decay = torch.log(1 - torch.exp(torch.linspace(s, e, num_heads))) # [h,] + else: + decay = torch.log(1 - 2**(-5 - torch.arange(num_heads, dtype=torch.float))) + self.register_buffer("angle", angle) + self.register_buffer("decay", decay) + self.recurrent_chunk_size = config.recurrent_chunk_size + + def forward(self, slen, activate_recurrent=False, chunkwise_recurrent=False): + if activate_recurrent: + sin = torch.sin(self.angle * (slen - 1)) + cos = torch.cos(self.angle * (slen - 1)) + retention_rel_pos = ((sin, cos), self.decay.exp()) + elif chunkwise_recurrent: + index = torch.arange(slen).to(self.decay) + sin = torch.sin(index[:, None] * self.angle[None, :]) + cos = torch.cos(index[:, None] * self.angle[None, :]) + + block_index = torch.arange(self.recurrent_chunk_size).to(self.decay) + mask = torch.tril(torch.ones(self.recurrent_chunk_size, + self.recurrent_chunk_size)).to(self.decay) + mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(), + float("inf")) + mask = torch.exp(mask * self.decay[:, None, None]) + mask = torch.nan_to_num(mask) + + value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True) + value_inner_decay = value_inner_decay.unsqueeze(-1) + scale = mask.sum(dim=-1, keepdim=True).sqrt() + inner_mask = mask / scale + + cross_decay = torch.exp(self.decay * self.recurrent_chunk_size) + query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1)) + query_inner_decay = query_inner_decay[:, :, None] / ( + scale / mask[:, -1].sum(dim=-1)[:, None, None]) + cross_decay = cross_decay[:, None, None] + retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay, + value_inner_decay)) + else: + index = torch.arange(slen).to(self.decay) + sin = torch.sin(index[:, None] * self.angle[None, :]) + cos = torch.cos(index[:, None] * self.angle[None, :]) + mask = torch.tril(torch.ones(slen, slen)).to(self.decay) + mask = torch.masked_fill(index[:, None] - index[None, :], ~mask.bool(), float("inf")) + mask = torch.exp(mask * self.decay[:, None, None]) + mask = torch.nan_to_num(mask) + mask = mask / mask.sum(dim=-1, keepdim=True).sqrt() + retention_rel_pos = ((sin, cos), mask) + + return retention_rel_pos + + +class MultiScaleRetention(nn.Module): + + def __init__( + self, + config, + embed_dim, + value_dim, + num_heads, + gate_fn="swish", + ): + super().__init__() + self.config = config + self.embed_dim = embed_dim + self.value_dim = value_dim + self.num_heads = num_heads + self.head_dim = self.value_dim // num_heads + self.key_dim = self.embed_dim // num_heads + self.scaling = self.key_dim**-0.5 + + self.gate_fn = get_activation_fn(activation=str(gate_fn)) + + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(embed_dim, value_dim, bias=False) + self.g_proj = nn.Linear(embed_dim, value_dim, bias=False) + + self.out_proj = nn.Linear(value_dim, embed_dim, bias=False) + + self.group_norm = RMSNorm(self.head_dim, eps=config.layernorm_eps, elementwise_affine=False) + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-2.5) + nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-2.5) + nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-2.5) + nn.init.xavier_uniform_(self.g_proj.weight, gain=2**-2.5) + nn.init.xavier_uniform_(self.out_proj.weight, gain=2**-1) + + def parallel_forward(self, qr, kr, v, mask): + bsz, tgt_len, embed_dim = v.size() + + vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + + qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len + qk_mat = qk_mat * mask + # invariant after normalization + qk_mat = qk_mat / qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1, max=5e4) + output = torch.matmul(qk_mat, vr) + output = output.transpose(1, 2) + return output + + def recurrent_forward(self, qr, kr, v, decay, incremental_state): + bsz = v.size(0) + + v = v.view(bsz, self.num_heads, self.head_dim, 1) + kv = kr * v + if "prev_key_value" in incremental_state: + prev_kv = incremental_state["prev_key_value"] + prev_scale = incremental_state["scale"] + scale = prev_scale * decay + 1 + kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view( + self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1) + # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv + else: + scale = torch.ones_like(decay) + + incremental_state["prev_key_value"] = kv + incremental_state["scale"] = scale + + output = torch.sum(qr * kv, dim=3) + return output + + def chunk_recurrent_forward(self, qr, kr, v, inner_mask): + mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask + bsz, tgt_len, embed_dim = v.size() + chunk_len = mask.size(1) + num_chunks = tgt_len // chunk_len + + assert tgt_len % chunk_len == 0 + + qr = qr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2) + kr = kr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2) + v = v.view(bsz, num_chunks, chunk_len, self.num_heads, self.head_dim).transpose(2, 3) + + kr_t = kr.transpose(-1, -2) + + qk_mat = qr @ kr_t # bsz * num_heads * chunk_len * chunk_len + qk_mat = qk_mat * mask + inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1) + qk_mat = qk_mat / inner_scale + inner_output = torch.matmul(qk_mat, + v) # bsz * num_heads * num_value_heads * chunk_len * head_dim + + # reduce kv in one chunk + kv = kr_t @ (v * value_inner_decay) + + kv_recurrent = [] + cross_scale = [] + kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v) + kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v) + + # accumulate kv by loop + for i in range(num_chunks): + kv_recurrent.append(kv_state / kv_scale) + cross_scale.append(kv_scale) + kv_state = kv_state * cross_decay + kv[:, i] + kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max( + dim=-1, keepdim=True).values.clamp(min=1) + + kv_recurrent = torch.stack(kv_recurrent, dim=1) + cross_scale = torch.stack(cross_scale, dim=1) + + all_scale = torch.maximum(inner_scale, cross_scale) + align_inner_scale = all_scale / inner_scale + align_cross_scale = all_scale / cross_scale + + cross_output = (qr * query_inner_decay) @ kv_recurrent + output = inner_output / align_inner_scale + cross_output / align_cross_scale + # output = inner_output / cross_scale + cross_output / inner_scale + + output = output.transpose(2, 3) + return output + + def forward(self, x, rel_pos, chunkwise_recurrent=False, incremental_state=None): + bsz, tgt_len, _ = x.size() + (sin, cos), inner_mask = rel_pos + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + g = self.g_proj(x) + + k *= self.scaling + q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2) + k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2) + + qr = theta_shift(q, sin, cos) + kr = theta_shift(k, sin, cos) + + if incremental_state is not None: + output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state) + elif chunkwise_recurrent: + output = self.chunk_recurrent_forward(qr, kr, v, inner_mask) + else: + output = self.parallel_forward(qr, kr, v, inner_mask) + + output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads) + + output = self.gate_fn(g) * output + + output = self.out_proj(output) + + return output + + +class FeedForwardNetwork(nn.Module): + + def __init__( + self, + embed_dim, + ffn_dim, + activation_fn, + dropout, + activation_dropout, + layernorm_eps, + subln=False, + use_rms_norm=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.activation_fn = get_activation_fn(activation=str(activation_fn)) + self.activation_dropout_module = torch.nn.Dropout(activation_dropout) + self.dropout_module = torch.nn.Dropout(dropout) + self.fc1 = nn.Linear(self.embed_dim, ffn_dim) + self.fc2 = nn.Linear(ffn_dim, self.embed_dim) + if subln: + if use_rms_norm: + self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps) + else: + self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps) + else: + self.ffn_layernorm = None + + def reset_parameters(self): + self.fc1.reset_parameters() + self.fc2.reset_parameters() + if self.ffn_layernorm is not None: + self.ffn_layernorm.reset_parameters() + + def forward(self, x): + x_shape = x.shape + x = x.reshape(-1, x.size(-1)) + x = self.fc1(x) + x = self.activation_fn(x.float()).type_as(x) + x = self.activation_dropout_module(x) + if self.ffn_layernorm is not None: + x = self.ffn_layernorm(x) + x = self.fc2(x) + x = x.view(x_shape) + x = self.dropout_module(x) + return x + + +class GLU(nn.Module): + + def __init__( + self, + embed_dim, + ffn_dim, + activation_fn, + dropout, + activation_dropout, + ): + super().__init__() + self.embed_dim = embed_dim + self.activation_fn = get_activation_fn(activation=str(activation_fn)) + self.activation_dropout_module = torch.nn.Dropout(activation_dropout) + self.dropout_module = torch.nn.Dropout(dropout) + self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False) + self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False) + self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False) + + def reset_parameters(self): + self.fc1.reset_parameters() + self.fc2.reset_parameters() + self.gate.reset_parameters() + + def forward(self, x): + x_shape = x.shape + x = x.reshape(-1, x.size(-1)) + g = self.gate(x) + x = self.fc1(x) + x = self.activation_fn(x.float()).type_as(x) * g + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = x.view(x_shape) + x = self.dropout_module(x) + return x + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return "p={}".format(self.drop_prob) + + +class RetNetDecoderLayer(nn.Module): + + def __init__( + self, + config, + depth, + ): + super().__init__() + self.config = config + self.embed_dim = config.decoder_embed_dim + self.dropout_module = torch.nn.Dropout(config.dropout) + + if config.drop_path_rate > 0: + drop_path_prob = np.linspace(0, config.drop_path_rate, config.decoder_layers)[depth] + self.drop_path = DropPath(drop_path_prob) + else: + self.drop_path = None + + self.retention = MultiScaleRetention( + config, + self.embed_dim, + config.decoder_value_embed_dim, + config.decoder_retention_heads, + ) + + self.normalize_before = config.decoder_normalize_before + + self.retention_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps) + + self.ffn_dim = config.decoder_ffn_embed_dim + + self.ffn = self.build_ffn() + + self.final_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps) + + if config.deepnorm: + self.alpha = math.pow(2.0 * config.decoder_layers, 0.25) + else: + self.alpha = 1.0 + + def build_ffn(self): + if self.config.use_glu: + return GLU( + self.embed_dim, + self.ffn_dim, + self.config.activation_fn, + self.config.dropout, + self.config.activation_dropout, + ) + else: + return FeedForwardNetwork( + self.embed_dim, + self.ffn_dim, + self.config.activation_fn, + self.config.dropout, + self.config.activation_dropout, + self.config.layernorm_eps, + self.config.subln, + self.config.use_ffn_rms_norm, + ) + + def residual_connection(self, x, residual): + return residual * self.alpha + x + + def forward( + self, + x, + incremental_state=None, + chunkwise_recurrent=False, + retention_rel_pos=None, + ): + residual = x + if self.normalize_before: + x = self.retention_layer_norm(x) + + x = self.retention( + x, + incremental_state=incremental_state, + rel_pos=retention_rel_pos, + chunkwise_recurrent=chunkwise_recurrent, + ) + x = self.dropout_module(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.retention_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.ffn(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + + return x + + +class RetNetModel(nn.Module): + + def __init__(self, config, embed_tokens=None, output_projection=None, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.dropout_module = torch.nn.Dropout(config.dropout) + + embed_dim = config.decoder_embed_dim + self.embed_dim = embed_dim + self.embed_scale = 1.0 if config.no_scale_embedding else math.sqrt(embed_dim) + + self.embed_tokens = embed_tokens + + if (output_projection is None and not config.no_output_layer and config.vocab_size > 0): + self.output_projection = self.build_output_projection(config) + else: + self.output_projection = output_projection + + if config.layernorm_embedding: + self.layernorm_embedding = RMSNorm(embed_dim, eps=config.layernorm_eps) + else: + self.layernorm_embedding = None + + self.layers = nn.ModuleList([]) + + for i in range(config.decoder_layers): + self.layers.append(self.build_decoder_layer( + config, + depth=i, + )) + + self.num_layers = len(self.layers) + + if config.decoder_normalize_before: + self.layer_norm = RMSNorm(embed_dim, eps=config.layernorm_eps) + else: + self.layer_norm = None + + self.retnet_rel_pos = RetNetRelPos(config) + self.chunkwise_recurrent = config.chunkwise_recurrent + self.recurrent_chunk_size = config.recurrent_chunk_size + + if config.deepnorm: + init_scale = math.pow(8.0 * config.decoder_layers, 0.25) + for name, p in self.named_parameters(): + if ("fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name): + p.data.div_(init_scale) + + if config.subln and not config.use_glu: + init_scale = math.sqrt(math.log(config.decoder_layers * 2)) + for name, p in self.named_parameters(): + if ("fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name): + p.data.mul_(init_scale) + + def build_output_projection( + self, + config, + ): + if config.share_decoder_input_output_embed: + output_projection = torch.nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + output_projection.weight = self.embed_tokens.weight + else: + output_projection = torch.nn.Linear(config.decoder_embed_dim, + config.vocab_size, + bias=False) + torch.nn.init.normal_(output_projection.weight, + mean=0, + std=config.decoder_embed_dim**-0.5) + return output_projection + + def build_decoder_layer(self, config, depth): + layer = RetNetDecoderLayer( + config, + depth, + ) + # if config.checkpoint_activations: + # layer = checkpoint_wrapper(layer) + # if config.fsdp: + # layer = wrap(layer) + return layer + + def forward_embedding( + self, + tokens, + token_embedding=None, + incremental_state=None, + ): + if incremental_state is not None and not self.is_first_step(incremental_state): + tokens = tokens[:, -1:] + + if token_embedding is None: + token_embedding = self.embed_tokens(tokens) + + x = embed = self.embed_scale * token_embedding + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + return x, embed + + def is_first_step(self, incremental_state): + if incremental_state is None: + return False + return incremental_state.get("is_first_step", False) + + def forward(self, + prev_output_tokens, + incremental_state=None, + features_only=False, + token_embeddings=None): + # embed tokens + x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state) + is_first_step = self.is_first_step(incremental_state) + + if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0: + padding_len = self.recurrent_chunk_size - prev_output_tokens.size( + 1) % self.recurrent_chunk_size + slen = prev_output_tokens.size(1) + padding_len + x = F.pad(x, (0, 0, 0, padding_len)) + else: + slen = prev_output_tokens.size(1) + # relative position + retention_rel_pos = self.retnet_rel_pos(slen, + incremental_state is not None and not is_first_step, + chunkwise_recurrent=self.chunkwise_recurrent) + + # decoder layers + inner_states = [x] + + for idx, layer in enumerate(self.layers): + if incremental_state is None or is_first_step: + if is_first_step and incremental_state is not None: + if idx not in incremental_state: + incremental_state[idx] = {} + else: + if idx not in incremental_state: + incremental_state[idx] = {} + + x = layer( + x, + incremental_state[idx] if incremental_state is not None else None, + retention_rel_pos=retention_rel_pos, + chunkwise_recurrent=self.chunkwise_recurrent, + ) + inner_states.append(x) + + if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0: + x = x[:, :prev_output_tokens.size(1), :] + + if self.layer_norm is not None: + x = self.layer_norm(x) + + if not features_only: + x = self.output_layer(x) + + return x, { + "inner_states": inner_states, + "attn": None, + } + + def output_layer(self, features): + return self.output_projection(features) + + +class RetNetForCausalLM(nn.Module): + + def __init__(self, config, embed_tokens=None, output_projection=None, **kwargs): + super().__init__(**kwargs) + assert config.vocab_size > 0, "you must specify vocab size" + if output_projection is None: + config.no_output_layer = False + if embed_tokens is None: + embed_tokens = nn.Embedding(config.vocab_size, config.decoder_embed_dim, + config.pad_token_id) + + self.config = config + self.model = RetNetModel(config, + embed_tokens=embed_tokens, + output_projection=output_projection, + **kwargs) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.model.output_projection + + def set_output_embeddings(self, new_embeddings): + self.model.output_projection = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + retention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_retentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + recurrent_chunk_size: Optional[int] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + outputs = self.model( + input_ids, + incremental_state=past_key_values, + features_only=False, + token_embeddings=inputs_embeds, + ) + + logits, inner_hidden_states = outputs[0], outputs[1]['inner_states'] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if self.config.z_loss_coeff > 0: + # z_loss from PaLM paper + # z_loss = 1e-4 * log(log(z)), where z = sum(exp(logits)) + z_loss = torch.logsumexp(shift_logits, dim=-1).log().mean() + loss += self.config.z_loss_coeff * z_loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=past_key_values, + hidden_states=inner_hidden_states, + attentions=None, + ) \ No newline at end of file diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index fe0fa5e..e9c965a 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -351,7 +351,7 @@ def example_usage(): 'n_tokens': 1024, 'd_model': 1024, # 256, # 1024, # 1536 'n_heads': 16, # 4, # 16, # 24 - 'n_layers': 16, # 32 + 'n_layers': 12, # 32 'n_experts': 1, 'l_padding': 8 if cfg.fp8.enabled else 0, diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 1103e7f..e2cfa2b 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -24,7 +24,8 @@ except Exception as e: pass try: - from .retnet import RetNetDecoder, RetNetConfig + #from .retnet import RetNetDecoder, RetNetConfig + from .retnet_ts import RetNetDecoder, RetNetConfig except Exception as e: print("Error importing `retnet` arch:", e) pass @@ -394,6 +395,7 @@ class Base(nn.Module): hidden_act="gelu", is_encoder_decoder=False, is_decoder=True, + attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None )) else: self.model = MixtralModel(MixtralConfig( @@ -412,6 +414,7 @@ class Base(nn.Module): is_decoder=True, num_local_experts=n_experts, num_experts_per_tok=min(2, n_experts), + attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None )) elif self.arch_type == "llama": if n_experts <= 1: @@ -428,6 +431,7 @@ class Base(nn.Module): hidden_act="gelu", is_encoder_decoder=False, is_decoder=True, + attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None )) else: self.model = MixtralModel(MixtralConfig( @@ -446,6 +450,7 @@ class Base(nn.Module): is_decoder=True, num_local_experts=n_experts, num_experts_per_tok=min(2, n_experts), + attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None )) elif self.arch_type == "retnet": diff --git a/vall_e/models/retnet_ts.py b/vall_e/models/retnet_ts.py new file mode 100644 index 0000000..76a2aea --- /dev/null +++ b/vall_e/models/retnet_ts.py @@ -0,0 +1,279 @@ +# https://github.com/syncdoth/RetNet/ +from ..ext.retnet_ts.config import RetNetConfig +from ..ext.retnet_ts.retnet import RetNetModel as RetNetDecoder + +# things we're overriding or required to override +from ..ext.retnet_ts.retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos + +import torch +import math +from typing import Dict, List, Optional, Tuple, Union + +from torch.utils.checkpoint import checkpoint + +# required to have compatibile LayerNorm +def FeedForwardNetwork_init( + self, + embed_dim, + ffn_dim, + activation_fn, + dropout, + activation_dropout, + layernorm_eps, + subln=True, + use_rms_norm=False, +): + super(FeedForwardNetwork, self).__init__() + self.embed_dim = embed_dim + self.activation_fn = get_activation_fn(activation=str(activation_fn)) + self.activation_dropout_module = torch.nn.Dropout(activation_dropout) + self.dropout_module = torch.nn.Dropout(dropout) + self.fc1 = torch.nn.Linear(self.embed_dim, ffn_dim) + self.fc2 = torch.nn.Linear(ffn_dim, self.embed_dim) + self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None + +FeedForwardNetwork.__init__ = FeedForwardNetwork_init + +# removes embed_tokens +def RetNetModel_init( + self, config, embed_tokens=None, output_projection=None, **kwargs + ): + super(RetNetDecoder, self).__init__(**kwargs) + self.config = config + + self.dropout_module = torch.nn.Dropout(config.dropout) + + self.embed_dim = config.decoder_embed_dim + self.embed_scale = ( + 1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim) + ) + + """ + if embed_tokens is None: + embed_tokens = torch.nn.Embedding( + config.vocab_size, config.decoder_embed_dim, config.pad_token_id + ) + """ + self.embed_tokens = None + + if (output_projection is None and not config.no_output_layer and config.vocab_size > 0): + self.output_projection = self.build_output_projection(config) + else: + self.output_projection = output_projection + + if config.layernorm_embedding: + self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm + else: + self.layernorm_embedding = None + + self.layers = torch.nn.ModuleList([]) + + for i in range(config.decoder_layers): + layer = self.build_decoder_layer( + config, + depth=i, + ) + """ + if config.checkpoint_activations: + layer = checkpoint_wrapper(layer) + """ + self.layers.append(layer) + + self.num_layers = len(self.layers) + + if config.decoder_normalize_before: + self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm + else: + self.layer_norm = None + + self.retnet_rel_pos = RetNetRelPos(config) + self.chunkwise_recurrent = config.chunkwise_recurrent + self.recurrent_chunk_size = config.recurrent_chunk_size + + if config.deepnorm: + init_scale = math.pow(8.0 * config.decoder_layers, 0.25) + for name, p in self.named_parameters(): + if ( + "fc1" in name + or "fc2" in name + or "out_proj" in name + or "v_proj" in name + ): + p.data.div_(init_scale) + + if config.subln and not config.use_glu: + init_scale = math.sqrt(math.log(config.decoder_layers * 2)) + for name, p in self.named_parameters(): + if ( + "fc1" in name + or "fc2" in name + or "out_proj" in name + or "v_proj" in name + ): + p.data.mul_(init_scale) + + self.gradient_checkpointing = True + +RetNetDecoder.__init__ = RetNetModel_init + +# restores bias in our FFNs +def RetNetDecoderLayer_init( + self, + config, + depth, + use_bias=True +): + super(RetNetDecoderLayer, self).__init__() + self.config = config + self.embed_dim = config.decoder_embed_dim + self.dropout_module = torch.nn.Dropout(config.dropout) + + if config.drop_path_rate > 0: + drop_path_prob = np.linspace( + 0, config.drop_path_rate, config.decoder_layers + )[depth] + self.drop_path = DropPath(drop_path_prob) + else: + self.drop_path = None + + self.retention = MultiScaleRetention( + config, + self.embed_dim, + config.decoder_value_embed_dim, + config.decoder_retention_heads, + use_bias=use_bias + ) + + self.normalize_before = config.decoder_normalize_before + + self.retention_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm + + self.ffn_dim = config.decoder_ffn_embed_dim + + self.ffn = self.build_ffn() + + self.final_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm + + if config.deepnorm: + self.alpha = math.pow(2.0 * config.decoder_layers, 0.25) + else: + self.alpha = 1.0 + +def RetNetDecoderLayer_forward( + self, + x, + incremental_state=None, + chunkwise_recurrent=False, + retention_rel_pos=None, +): + residual = x + if self.normalize_before: + x = self.retention_layer_norm(x) + + if x.requires_grad and self.config.checkpoint_activations: + x = checkpoint( + self.retention, + x, + use_reentrant=False, + incremental_state=incremental_state, + rel_pos=retention_rel_pos, + chunkwise_recurrent=chunkwise_recurrent, + ) + else: + x = self.retention( + x, + incremental_state=incremental_state, + rel_pos=retention_rel_pos, + chunkwise_recurrent=chunkwise_recurrent, + ) + x = self.dropout_module(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.retention_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.ffn(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + + return x + +RetNetDecoderLayer.__init__ = RetNetDecoderLayer_init +RetNetDecoderLayer.forward = RetNetDecoderLayer_forward +# fixes backwards when using te's autocast +def MultiScaleRetention_init( + self, + config, + embed_dim, + value_dim, + num_heads, + gate_fn="swish", + use_bias=True, +): + super(MultiScaleRetention, self).__init__() + self.config = config + self.embed_dim = embed_dim + self.value_dim = value_dim + self.num_heads = num_heads + self.head_dim = self.value_dim // num_heads + self.key_dim = self.embed_dim // num_heads + self.scaling = self.key_dim**-0.5 + + self.gate_fn = get_activation_fn(activation=str(gate_fn)) + + self.q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=use_bias) + self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=use_bias) + self.v_proj = torch.nn.Linear(embed_dim, value_dim, bias=use_bias) + self.g_proj = torch.nn.Linear(embed_dim, value_dim, bias=use_bias) + + self.out_proj = torch.nn.Linear(value_dim, embed_dim, bias=use_bias) + + self.group_norm = RMSNorm(self.head_dim, eps=config.layernorm_eps, elementwise_affine=False) + self.reset_parameters() + +def MultiScaleRetention_forward( + self, x, rel_pos, chunkwise_recurrent=False, incremental_state=None +) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]: + bsz, tgt_len, _ = x.size() + (sin, cos), inner_mask = rel_pos + + q = self.q_proj(x) + k = self.k_proj(x) * self.scaling + v = self.v_proj(x) + g = self.g_proj(x) + + q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2) + k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2) + + qr = theta_shift(q, sin, cos) + kr = theta_shift(k, sin, cos) + + if incremental_state is not None: + output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state) + elif chunkwise_recurrent: + output = self.chunk_recurrent_forward(qr, kr, v, inner_mask) + else: + output = self.parallel_forward(qr, kr, v, inner_mask) + + output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads) + + output = self.gate_fn(g) * output + + output = self.out_proj(output) + + return output + +MultiScaleRetention.__init__ = MultiScaleRetention_init +MultiScaleRetention.forward = MultiScaleRetention_forward \ No newline at end of file diff --git a/vall_e/webui.py b/vall_e/webui.py index 580c331..0e1a7c2 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -203,7 +203,7 @@ with ui: layout["inference"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt") with gr.Row(): with gr.Column(scale=1): - layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", source="upload", type="filepath", info="Reference audio for TTS") + layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS") # layout["inference"]["stop"] = gr.Button(value="Stop") layout["inference"]["outputs"]["output"] = gr.Audio(label="Output") layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")