diff --git a/vall_e/config.py b/vall_e/config.py index 383a93b..5c47122 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -176,6 +176,11 @@ class Model: p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training + @property + # required for fp8 as the lengths needs to be divisible by 8 + def input_alignment(self): + return 8 if cfg.fp8.enabled else 0 + @property def full_name(self): name = [ self.name ] @@ -503,6 +508,10 @@ class Trainer: return torch.float16 if self.weight_dtype == "bfloat16": return torch.bfloat16 + if self.weight_dtype == "float8_e5m2": + return torch.float8_e5m2 + if self.weight_dtype == "float8_e4m3fn": + return torch.float8_e4m3fn return torch.float32 @@ -527,6 +536,10 @@ class Inference: return torch.bfloat16 if self.weight_dtype == "int8": return torch.int8 + if self.weight_dtype == "float8_e5m2": + return torch.float8_e5m2 + if self.weight_dtype == "float8_e4m3fn": + return torch.float8_e4m3fn return torch.float32 @dataclass() @@ -540,6 +553,11 @@ class BitsAndBytes: bitnet: bool = False +@dataclass() +class FP8: + enabled: bool = False + backend: str = "te" + @dataclass() class Config(_Config): device: str = "cuda" @@ -553,6 +571,8 @@ class Config(_Config): trainer: Trainer = field(default_factory=lambda: Trainer) inference: Inference = field(default_factory=lambda: Inference) bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes) + + fp8: FP8 = field(default_factory=lambda: FP8) @property def sample_rate(self): @@ -620,6 +640,7 @@ try: except Exception as e: + print(e) pass diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index 7793868..fa2f86d 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -42,6 +42,7 @@ from typing import Any, Protocol from functools import cached_property from .base import TrainFeeder +from ..utils import wrapper as ml _logger = logging.getLogger(__name__) @@ -222,10 +223,11 @@ class Engine(): return self._global_grad_norm def traverse(self, *args, **kwargs): - with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): + with ml.autocast(): self.forward(*args, **kwargs) - losses = self.gather_attribute("loss") - loss = torch.stack([*losses.values()]).sum() + + losses = self.gather_attribute("loss") + loss = torch.stack([*losses.values()]).sum() stats = {} stats |= {k: v.item() for k, v in losses.items()} diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py index 59acdea..585eced 100755 --- a/vall_e/engines/deepspeed.py +++ b/vall_e/engines/deepspeed.py @@ -25,6 +25,7 @@ from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distr from deepspeed.accelerator import get_accelerator from ..utils.distributed import init_distributed, distributed_initialized +from ..utils import wrapper as ml if not distributed_initialized() and cfg.trainer.backend == "deepspeed": init_distributed(init_deepspeed_dist) @@ -106,10 +107,11 @@ class Engine(DeepSpeedEngine): print(str(e)) def traverse(self, *args, **kwargs): - with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): + with ml.autocast(): self.forward(*args, **kwargs) - losses = self.gather_attribute("loss") - loss = torch.stack([*losses.values()]).sum() + + losses = self.gather_attribute("loss") + loss = torch.stack([*losses.values()]).sum() stats = {} stats |= {k: v.item() for k, v in 
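# `ml.autocast()` used in the traverse() hunks above is imported via `from ..utils import wrapper as ml`,
# i.e. vall_e/utils/wrapper.py, which is not shown in this diff. A minimal sketch of what such a wrapper
# could look like, assuming the new `FP8.backend = "te"` option maps to NVIDIA's transformer_engine and
# that everything else falls back to the old torch.autocast call; illustrative only, not the repository's
# actual implementation:
from contextlib import contextmanager

import torch

from ..config import cfg  # same global config object the engines import

@contextmanager
def autocast():
    if cfg.fp8.enabled and cfg.fp8.backend == "te":
        # assumption: fp8 training is delegated to transformer_engine's autocast;
        # HYBRID uses float8_e4m3fn for forward tensors and float8_e5m2 for gradients
        import transformer_engine.pytorch as te
        from transformer_engine.common import recipe
        fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            yield
    else:
        with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
            yield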
losses.items()} diff --git a/vall_e/ext/__init__.py b/vall_e/ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vall_e/ext/retnet_hf/__init__.py b/vall_e/ext/retnet_hf/__init__.py new file mode 100644 index 0000000..fcac9b7 --- /dev/null +++ b/vall_e/ext/retnet_hf/__init__.py @@ -0,0 +1,3 @@ +# from https://github.com/syncdoth/RetNet/ + +# there is no proper build system and I can't be assed to fork it or make it a submodule that plays nicely with python's import system \ No newline at end of file diff --git a/vall_e/ext/retnet_hf/configuration_retnet.py b/vall_e/ext/retnet_hf/configuration_retnet.py new file mode 100644 index 0000000..b842606 --- /dev/null +++ b/vall_e/ext/retnet_hf/configuration_retnet.py @@ -0,0 +1,117 @@ +from dataclasses import dataclass +import json + +from transformers.configuration_utils import PretrainedConfig + + +def load_config_from_json(config_file): + with open(config_file, 'r') as f: + config = json.load(f) + config = RetNetConfig.from_dict(config) + return config + + +@dataclass +class RetNetConfig(PretrainedConfig): + model_type = "retnet" + initializer_range: float = 0.02 + activation_fn: str = "gelu" + dropout: float = 0.0 # dropout probability + activation_dropout: float = 0.0 # dropout probability after activation in FFN. + drop_path_rate: float = 0.0 + decoder_embed_dim: int = 768 # decoder embedding dimension + decoder_value_embed_dim: int = 1280 # decoder value embedding dimension + decoder_ffn_embed_dim: int = 1280 # decoder embedding dimension for FFN + decoder_layers: int = 12 # num decoder layers + decoder_retention_heads: int = 3 # num decoder retention heads + decoder_normalize_before: bool = True # apply layernorm before each decoder block + layernorm_embedding: bool = False # add layernorm to embedding + no_scale_embedding: bool = True # if True, dont scale embeddings + recurrent_chunk_size: int = 512 + use_lm_decay: bool = False + use_glu: bool = True # use GLU instead of FFN + z_loss_coeff: float = 0.0 # coefficient for z loss: TODO: 1e-4 + deepnorm: bool = False + subln: bool = True + use_ffn_rms_norm: bool = False + layernorm_eps: float = 1e-6 + tie_word_embeddings: bool = False + + def __init__( + self, + vocab_size: int = 50257, + initializer_range: float = 0.02, + is_decoder: bool = True, + pad_token_id: int = 0, + eos_token_id: int = 0, + output_retentions: bool = False, + use_cache: bool = True, + forward_impl: str = 'parallel', + activation_fn: str = "gelu", + dropout: float = 0.0, # dropout probability + activation_dropout: float = 0.0, # dropout probability after activation in FFN. 
+ drop_path_rate: float = 0.0, + decoder_embed_dim: int = 768, # decoder embedding dimension + decoder_value_embed_dim: int = 1280, # decoder value embedding dimension + decoder_ffn_embed_dim: int = 1280, # decoder embedding dimension for FFN + decoder_layers: int = 12, # num decoder layers + decoder_retention_heads: int = 3, # num decoder retention heads + decoder_normalize_before: bool = True, # apply layernorm before each decoder block + layernorm_embedding: bool = False, # add layernorm to embedding + no_scale_embedding: bool = True, # if True, dont scale embeddings + recurrent_chunk_size: int = 512, + use_glu: bool = True, # use GLU instead of FFN + z_loss_coeff: float = 0.0, # coefficient for z loss: TODO: 1e-4 + use_lm_decay: bool = False, + deepnorm: bool = True, + subln: bool = True, + use_ffn_rms_norm: bool = False, # use RMSNorm instead of LayerNorm in FFN + layernorm_eps: float = 1e-6, + tie_word_embeddings: bool = False, + **kwargs): + self.vocab_size = vocab_size + self.initializer_range = initializer_range + self.output_retentions = output_retentions + self.use_lm_decay = use_lm_decay + self.use_glu = use_glu + self.z_loss_coeff = z_loss_coeff + # size related + self.decoder_embed_dim = decoder_embed_dim + self.decoder_value_embed_dim = decoder_value_embed_dim + self.decoder_retention_heads = decoder_retention_heads + self.decoder_ffn_embed_dim = decoder_ffn_embed_dim + self.decoder_layers = decoder_layers + # normalization related + self.decoder_normalize_before = decoder_normalize_before + self.activation_fn = activation_fn + self.dropout = dropout + self.drop_path_rate = drop_path_rate + self.activation_dropout = activation_dropout + self.no_scale_embedding = no_scale_embedding + self.layernorm_embedding = layernorm_embedding + self.deepnorm = deepnorm + self.subln = subln + self.use_ffn_rms_norm = use_ffn_rms_norm + self.layernorm_eps = layernorm_eps + # Blockwise + self.recurrent_chunk_size = recurrent_chunk_size + self.forward_impl = forward_impl + + if self.deepnorm: + self.decoder_normalize_before = False + self.subln = False + if self.subln: + self.decoder_normalize_before = True + self.deepnorm = False + + super().__init__(is_decoder=is_decoder, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + use_cache=use_cache, + tie_word_embeddings=tie_word_embeddings, + **kwargs) + + def override(self, args): + for hp in self.__dict__.keys(): + if getattr(args, hp, None) is not None: + self.__dict__[hp] = getattr(args, hp, None) diff --git a/vall_e/ext/retnet_hf/modeling_retnet.py b/vall_e/ext/retnet_hf/modeling_retnet.py new file mode 100644 index 0000000..2a2a580 --- /dev/null +++ b/vall_e/ext/retnet_hf/modeling_retnet.py @@ -0,0 +1,1455 @@ +import math +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from timm.models.layers import drop_path +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import top_k_top_p_filtering +from transformers.modeling_outputs import ModelOutput, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +try: + from apex.normalization import FusedLayerNorm as LayerNorm +except ModuleNotFoundError: + from torch.nn import LayerNorm + +from .configuration_retnet import RetNetConfig + +logger = logging.get_logger(__name__) + + +# helper functions +def split_heads(tensors, 
bsz, seqlen, num_heads): + assert isinstance(tensors, (tuple, list)) + return [x.view(bsz, seqlen, num_heads, -1).transpose(1, 2) for x in tensors] + + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ + + +def theta_shift(x, sin, cos): + return (x * cos) + (rotate_every_two(x) * sin) + + +def get_activation_fn(activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + elif activation == "swish": + return F.silu + else: + raise NotImplementedError + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True): + super().__init__() + self.normalized_shape = dim + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_parameter("weight", None) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + if self.weight is not None: + output = output * self.weight + return output + + +class RetNetRelPos(nn.Module): + def __init__(self, config: RetNetConfig): + super().__init__() + self.config = config + num_heads = config.decoder_retention_heads + + angle = 1.0 / ( + 10000 ** torch.linspace(0, 1, config.decoder_embed_dim // num_heads // 2) + ) + angle = angle.unsqueeze(-1).repeat(1, 2).flatten() + # decay (gamma) + if config.use_lm_decay: + # NOTE: alternative way described in the paper + s = torch.log(torch.tensor(1 / 32)) + e = torch.log(torch.tensor(1 / 512)) + decay = torch.log(1 - torch.exp(torch.linspace(s, e, num_heads))) # [h,] + else: + decay = torch.log( + 1 - 2 ** (-5 - torch.arange(num_heads, dtype=torch.float)) + ) + self.register_buffer("angle", angle) + self.register_buffer("decay", decay) + self.recurrent_chunk_size = config.recurrent_chunk_size + + def forward( + self, + slen, + forward_impl="parallel", + recurrent_chunk_size=None, + retention_mask=None, + get_decay_scale=True, + ): + if forward_impl == "recurrent": + sin = torch.sin(self.angle * (slen - 1)) + cos = torch.cos(self.angle * (slen - 1)) + retention_rel_pos = ((sin, cos), self.decay.view(1, -1, 1, 1).exp()) + elif forward_impl == "chunkwise": + if recurrent_chunk_size is None: + recurrent_chunk_size = self.recurrent_chunk_size + index = torch.arange(slen).to(self.decay) + sin = torch.sin(index[:, None] * self.angle[None, :]) + cos = torch.cos(index[:, None] * self.angle[None, :]) + + block_index = torch.arange(recurrent_chunk_size).to(self.decay) + mask = torch.tril( + torch.ones(recurrent_chunk_size, recurrent_chunk_size) + ).to(self.decay) + mask = torch.masked_fill( + block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf") + ) + mask = torch.exp(mask * self.decay[:, None, None]) + mask = torch.nan_to_num(mask) + mask = mask.unsqueeze(0) # [1, h, t, t] + # TODO: need to handle retention_mask + # scaling + value_inner_decay = mask[:, :, -1] / mask[:, :, -1].sum( + dim=-1, keepdim=True + ) + value_inner_decay = value_inner_decay.unsqueeze(-1) + scale = mask.sum(dim=-1, keepdim=True).sqrt() + inner_mask = mask / scale + + cross_decay = torch.exp(self.decay * recurrent_chunk_size) + query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1)) + cross_decay = cross_decay[None, :, None, None] + query_inner_decay = query_inner_decay[None, :, :, None] / ( + 
scale / mask[:, :, -1].sum(dim=-1)[:, :, None, None] + ) + # decay_scale (used for kv cache) + if get_decay_scale: + decay_scale = self.compute_decay_scale(slen, retention_mask) + else: + decay_scale = None + retention_rel_pos = ( + (sin, cos), + ( + inner_mask, + cross_decay, + query_inner_decay, + value_inner_decay, + decay_scale, + ), + ) + else: # parallel + index = torch.arange(slen).to(self.decay) + sin = torch.sin(index[:, None] * self.angle[None, :]) + cos = torch.cos(index[:, None] * self.angle[None, :]) + mask = torch.tril(torch.ones(slen, slen)).to(self.decay) + mask = torch.masked_fill( + index[:, None] - index[None, :], ~mask.bool(), float("inf") + ) + mask = torch.exp(mask * self.decay[:, None, None]) + mask = torch.nan_to_num(mask) + mask = mask.unsqueeze(0) # [1, h, t, t] + if retention_mask is not None: + # this is required for left padding + mask = mask * retention_mask.float().view(-1, 1, 1, slen).to(mask) + + # scaling + mask = mask / mask.sum(dim=-1, keepdim=True).sqrt() + mask = torch.nan_to_num(mask, nan=0.0) + # decay_scale (used for kv cache) + if get_decay_scale: + decay_scale = self.compute_decay_scale(slen, retention_mask) + else: + decay_scale = None + # mask processing for intra decay + if retention_mask is not None: + max_non_zero = ( + torch.cumsum(retention_mask, dim=-1).max(dim=-1).indices + ) # [b,] + intra_decay = mask[range(mask.shape[0]), :, max_non_zero] + else: + intra_decay = mask[:, :, -1] + + retention_rel_pos = ((sin, cos), (mask, intra_decay, decay_scale)) + + return retention_rel_pos + + def compute_decay_scale(self, slen, retention_mask=None): + exponent = torch.arange(slen, device=self.decay.device).float() + decay_scale = self.decay.exp().view(-1, 1) ** exponent.view(1, -1) # [h, t] + if retention_mask is not None: + seqlen = retention_mask.sum(dim=-1) # [b,] + bsz = seqlen.size(0) + decay_scale = decay_scale.unsqueeze(0).repeat(bsz, 1, 1) # [b, h, t] + for i, pos in enumerate(seqlen): + # the formula for decay_scale is `sum(gamma^i) for i in [0, slen).` + # Since the retention_mask is 0 for padding, we can set the decay_scale + # to 0 for the padding positions. 
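# For reference, the default (non-use_lm_decay) per-head retention decay built in
# RetNetRelPos.__init__ above works out to the following concrete values; this tiny
# snippet only illustrates that formula and is not part of the vendored file:
import torch

num_heads = 3  # decoder_retention_heads in the default RetNetConfig
gamma = 1 - 2 ** (-5 - torch.arange(num_heads, dtype=torch.float))
print(gamma)  # tensor([0.9688, 0.9844, 0.9922])
# the "decay" buffer stores log(gamma); the masks above exponentiate (i - j) * log(gamma),
# so position j contributes to position i with weight gamma ** (i - j) per head.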
+ decay_scale[i, :, pos.item() :] = 0 + else: + bsz = 1 + decay_scale = decay_scale.sum(-1).view(bsz, -1, 1, 1) # [b, h, 1, 1] + return decay_scale + + +class MultiScaleRetention(nn.Module): + def __init__( + self, + config: RetNetConfig, + gate_fn="swish", + use_bias=False, + tensor_parallel=False, + ): + super().__init__() + self.config = config + self.embed_dim = config.decoder_embed_dim + self.value_dim = config.decoder_value_embed_dim + self.num_heads = config.decoder_retention_heads + self.head_dim = self.value_dim // self.num_heads + self.key_dim = self.embed_dim // self.num_heads + self.scaling = self.key_dim**-0.5 + + self.gate_fn = get_activation_fn(activation=str(gate_fn)) + + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias) + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias) + self.v_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias) + self.g_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias) + + self.out_proj = nn.Linear(self.value_dim, self.embed_dim, bias=use_bias) + + self.group_norm = RMSNorm( + self.head_dim, eps=config.layernorm_eps, elementwise_affine=False + ) + self.reset_parameters() + + if tensor_parallel: + self.decay_proj = nn.Linear(self.num_heads, self.num_heads, bias=False) + else: + self.decay_proj = None + + def reset_parameters(self): + nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-2.5) + nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-2.5) + nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-2.5) + nn.init.xavier_uniform_(self.g_proj.weight, gain=2**-2.5) + nn.init.xavier_uniform_(self.out_proj.weight, gain=2**-1) + + def parallel_retention(self, q, k, v, decay_mask): + """ + q, # bsz * num_head * len * qk_dim + k, # bsz * num_head * len * qk_dim + v, # bsz * num_head * len * v_dim + decay_mask, # (1 or bsz) * num_head * len * len + """ + decay_mask, intra_decay, scale = decay_mask + # just return retention_rel_pos projected + # TODO: for shardformer + if self.decay_proj is not None: + decay_mask = self.decay_proj(decay_mask.transpose(-1, -3)).transpose(-3, -1) + + # [b, h, t, t] + retention = q @ k.transpose(-1, -2) # (scaled dot-product) + retention = retention * decay_mask + + # invariant after normalization + retention = retention / retention.detach().abs().sum( + dim=-1, keepdim=True + ).clamp(min=1, max=5e4) + + output = retention @ v # [b, h, t, v_dim / h] + output = output.transpose(1, 2) # [b, t, h, v_dim / h] + + if self.training: # skip cache + return output, None, retention + + if self.decay_proj is not None: + intra_decay = self.decay_proj(intra_decay.transpose(-1, -2)).transpose( + -2, -1 + ) + + # kv cache: [b, h, t, v_dim, qk_dim] + current_kv = k.unsqueeze(-2) * v.unsqueeze(-1) + intra_decay = intra_decay[:, :, :, None, None] # [b, h, t, 1, 1] + current_kv = (current_kv * intra_decay).sum(2) # [b, h, v_dim, qk_dim] + + cache = {"prev_key_value": current_kv, "scale": scale} + return output, cache, retention + + def recurrent_retention( + self, q, k, v, decay, past_key_value=None, retention_mask=None + ): + """ + q, k, v, # bsz * num_head * 1 * qkv_dim + past_key_value: + - "prev_key_value" # bsz * num_head * v_dim * qk_dim + - "scale" # (1 or bsz) * num_head * 1 * 1 + decay # (1 or bsz) * num_head * 1 * 1 + retention_mask # bsz * 1 + """ + if retention_mask is not None: + retention_mask = retention_mask.float().view(-1, 1, 1, 1).to(decay) + else: + retention_mask = torch.ones(k.size(0), 1, 1, 1).to(decay) + # (b, h, v_dim, qk_dim) + current_kv = k * 
v.transpose(-1, -2) * retention_mask + + if past_key_value is not None and "prev_key_value" in past_key_value: + prev_kv = past_key_value["prev_key_value"] + prev_scale = past_key_value["scale"] + scale = torch.where(retention_mask == 0, prev_scale, prev_scale * decay + 1) + # connect prev_kv and current_kv + # how much to decay prev_kv + decay_amount = prev_scale.sqrt() * decay / scale.sqrt() + decay_amount = torch.where(retention_mask == 0, 1, decay_amount) + prev_kv = prev_kv * decay_amount # decay prev_kv + current_kv = current_kv / scale.sqrt() # scale current_kv + current_kv = torch.nan_to_num( + current_kv, nan=0.0 + ) # remove nan, scale might be 0 + + current_kv = prev_kv + current_kv + else: + scale = torch.ones_like(decay) + # when retention_mask is 0 at the beginning, setting scale to 1 will + # make the first retention to use the padding incorrectly. Hence, + # setting it to 0 here. This is a little ugly, so we might want to + # change this later. TODO: improve + scale = torch.where(retention_mask == 0, torch.zeros_like(decay), scale) + + output = torch.sum(q * current_kv, dim=3).unsqueeze(1) # (b, 1, h, d_v) + + cache = {"prev_key_value": current_kv, "scale": scale} + return output, cache + + def chunkwise_retention(self, q, k, v, decay_mask): + """ + q, k, v, # bsz * num_head * seqlen * qkv_dim + past_key_value: + - "prev_key_value" # bsz * num_head * v_dim * qk_dim + - "scale" # (1 or bsz) * num_head * 1 * 1 + decay_mask, # 1 * num_head * chunk_size * chunk_size + cross_decay, # 1 * num_head * 1 * 1 + inner_decay, # 1 * num_head * chunk_size * 1 + """ + # TODO: not working properly + ( + decay_mask, + cross_decay, + query_inner_decay, + value_inner_decay, + decay_scale, + ) = decay_mask + bsz, _, tgt_len, _ = v.size() + chunk_len = decay_mask.size(-1) + assert tgt_len % chunk_len == 0 + num_chunks = tgt_len // chunk_len + + # [b, n_c, h, t_c, qkv_dim] + q = q.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose( + 1, 2 + ) + k = k.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose( + 1, 2 + ) + v = v.view(bsz, self.num_heads, num_chunks, chunk_len, self.head_dim).transpose( + 1, 2 + ) + + k_t = k.transpose(-1, -2) + + qk_mat = q @ k_t # [b, n_c, h, t_c, t_c] + qk_mat = qk_mat * decay_mask.unsqueeze(1) + inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1) + qk_mat = qk_mat / inner_scale + # [b, n_c, h, t_c, v_dim] + inner_output = torch.matmul(qk_mat, v) + + # reduce kv in one chunk + # [b, n_c, h, qk_dim, v_dim] + kv = k_t @ (v * value_inner_decay) + # kv = kv.view(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim) + + kv_recurrent = [] + cross_scale = [] + kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v) + kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v) + + # accumulate kv by loop + for i in range(num_chunks): + kv_recurrent.append(kv_state / kv_scale) + cross_scale.append(kv_scale) + kv_state = kv_state * cross_decay + kv[:, i] + kv_scale = ( + kv_state.detach() + .abs() + .sum(dim=-2, keepdim=True) + .max(dim=-1, keepdim=True) + .values.clamp(min=1) + ) + + kv_recurrent = torch.stack(kv_recurrent, dim=1) + cross_scale = torch.stack(cross_scale, dim=1) + + all_scale = torch.maximum(inner_scale, cross_scale) + align_inner_scale = all_scale / inner_scale + align_cross_scale = all_scale / cross_scale + + cross_output = (q * query_inner_decay.unsqueeze(1)) @ kv_recurrent + output = inner_output / align_inner_scale + cross_output / align_cross_scale + output = 
output.transpose(2, 3) # [b, n_c, t_c, h, v_dim] + + cache = {"prev_key_value": kv_state.transpose(-2, -1), "scale": decay_scale} + return output, cache + + def forward( + self, + hidden_states: torch.Tensor, + rel_pos: Tuple[Tuple[torch.Tensor]], + retention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + forward_impl: str = "parallel", + output_retentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]: + B, T, H = hidden_states.size() + (sin, cos), decay_mask = rel_pos + # projections + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + g = self.g_proj(hidden_states) + # multi-head + q, k, v = split_heads((q, k, v), B, T, self.num_heads) + k *= self.scaling # for scaled dot product + # rotate + # NOTE: theta_shift has bug with mps device. + qr = theta_shift(q, sin, cos) + kr = theta_shift(k, sin, cos) + + # retention + if forward_impl == "parallel": + retention_out, curr_kv, retention_weights = self.parallel_retention( + qr, kr, v, decay_mask + ) + elif forward_impl == "recurrent": + retention_out, curr_kv = self.recurrent_retention( + qr, + kr, + v, + decay_mask, + past_key_value=past_key_value, + retention_mask=retention_mask, + ) + elif forward_impl == "chunkwise": + retention_out, curr_kv = self.chunkwise_retention(qr, kr, v, decay_mask) + else: + raise ValueError(f"forward_impl {forward_impl} not supported.") + + # concaat heads + normed = self.group_norm(retention_out).reshape(B, T, self.value_dim) + # out gate & proj + out = self.gate_fn(g) * normed + out = self.out_proj(out) + + outputs = (out, curr_kv) + if output_retentions: + outputs += (retention_weights,) if forward_impl == "parallel" else (None,) + return outputs + + +class FeedForwardNetwork(nn.Module): + def __init__( + self, + embed_dim, + ffn_dim, + activation_fn, + dropout, + activation_dropout, + layernorm_eps, + subln=False, + use_rms_norm=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.activation_fn = get_activation_fn(activation=str(activation_fn)) + self.activation_dropout_module = torch.nn.Dropout(activation_dropout) + self.dropout_module = torch.nn.Dropout(dropout) + self.fc1 = nn.Linear(self.embed_dim, ffn_dim) + self.fc2 = nn.Linear(ffn_dim, self.embed_dim) + if subln: + if use_rms_norm: + self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps) + else: + self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps) + else: + self.ffn_layernorm = None + + def reset_parameters(self): + self.fc1.reset_parameters() + self.fc2.reset_parameters() + if self.ffn_layernorm is not None: + self.ffn_layernorm.reset_parameters() + + def forward(self, x): + x_shape = x.shape + x = x.reshape(-1, x.size(-1)) + x = self.fc1(x) + x = self.activation_fn(x.float()).type_as(x) + x = self.activation_dropout_module(x) + if self.ffn_layernorm is not None: + x = self.ffn_layernorm(x) + x = self.fc2(x) + x = x.view(x_shape) + x = self.dropout_module(x) + return x + + +class GLU(nn.Module): + def __init__( + self, + embed_dim, + ffn_dim, + activation_fn, + dropout, + activation_dropout, + ): + super().__init__() + self.embed_dim = embed_dim + self.activation_fn = get_activation_fn(activation=str(activation_fn)) + self.activation_dropout_module = torch.nn.Dropout(activation_dropout) + self.dropout_module = torch.nn.Dropout(dropout) + self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False) + self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False) + 
self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False) + + def reset_parameters(self): + self.fc1.reset_parameters() + self.fc2.reset_parameters() + self.gate.reset_parameters() + + def forward(self, x): + x_shape = x.shape + x = x.reshape(-1, x.size(-1)) + g = self.gate(x) + x = self.fc1(x) + x = self.activation_fn(x.float()).type_as(x) * g + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = x.view(x_shape) + x = self.dropout_module(x) + return x + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return "p={}".format(self.drop_prob) + + +class RetNetDecoderLayer(nn.Module): + def __init__(self, config: RetNetConfig, depth: int, tensor_parallel: bool = False): + super().__init__() + self.config = config + self.embed_dim = config.decoder_embed_dim + self.dropout_module = torch.nn.Dropout(config.dropout) + + if config.drop_path_rate > 0: + drop_path_prob = np.linspace( + 0, config.drop_path_rate, config.decoder_layers + )[depth] + self.drop_path = DropPath(drop_path_prob) + else: + self.drop_path = None + + self.retention = MultiScaleRetention( + config, use_bias=False, tensor_parallel=tensor_parallel + ) + + self.normalize_before = config.decoder_normalize_before + + self.retention_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps) + + self.ffn_dim = config.decoder_ffn_embed_dim + + self.ffn = self.build_ffn() + + self.final_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps) + + if config.deepnorm: + self.alpha = math.pow(2.0 * config.decoder_layers, 0.25) + else: + self.alpha = 1.0 + + def build_ffn(self): + if self.config.use_glu: + return GLU( + self.embed_dim, + self.ffn_dim, + self.config.activation_fn, + self.config.dropout, + self.config.activation_dropout, + ) + else: + return FeedForwardNetwork( + self.embed_dim, + self.ffn_dim, + self.config.activation_fn, + self.config.dropout, + self.config.activation_dropout, + self.config.layernorm_eps, + self.config.subln, + self.config.use_ffn_rms_norm, + ) + + def residual_connection(self, x, residual): + return residual * self.alpha + x + + def forward( + self, + hidden_states: torch.Tensor, + retention_rel_pos: Tuple[Tuple[torch.Tensor]], + retention_mask: Optional[torch.Tensor] = None, + forward_impl: str = "parallel", + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_retentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]: + residual = hidden_states + if self.normalize_before: + hidden_states = self.retention_layer_norm(hidden_states) + + msr_outs = self.retention( + hidden_states, + retention_rel_pos, + retention_mask=retention_mask, + past_key_value=past_key_value, + forward_impl=forward_impl, + output_retentions=output_retentions, + ) + hidden_states = msr_outs[0] + curr_kv = msr_outs[1] + + hidden_states = self.dropout_module(hidden_states) + + if self.drop_path is not None: + hidden_states = self.drop_path(hidden_states) + + hidden_states = self.residual_connection(hidden_states, residual) + if not self.normalize_before: + hidden_states = self.retention_layer_norm(hidden_states) + + residual = hidden_states + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.ffn(hidden_states) + + 
if self.drop_path is not None: + hidden_states = self.drop_path(hidden_states) + + hidden_states = self.residual_connection(hidden_states, residual) + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states, curr_kv) + + if output_retentions: + outputs += (msr_outs[2],) + return outputs + + +class RetNetPreTrainedModel(PreTrainedModel): + # copied from LlamaPretrainedModel + config_class = RetNetConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["RetNetDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + """ + Following original retnet, weights are already initialized in their own + ways within their own init. + """ + pass + # below is copied from LlamaPretrainedModel + # std = self.config.initializer_range + # if isinstance(module, nn.Linear): + # module.weight.data.normal_(mean=0.0, std=std) + # if module.bias is not None: + # module.bias.data.zero_() + # elif isinstance(module, nn.Embedding): + # module.weight.data.normal_(mean=0.0, std=std) + # if module.padding_idx is not None: + # module.weight.data[module.padding_idx].zero_() + + +@dataclass +class RetNetOutputWithPast(ModelOutput): + """ + class for RetNet model's outputs that may also contain a past key/values (to speed up sequential decoding). + + config: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, decoder_embed_dim)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + decoder_embed_dim)` is output. + past_key_values (`List(Dict(str, torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + - "prev_key_value": shape=(bsz * num_head * v_dim * qk_dim) + - "scale": shape=((1 or bsz) * num_head * 1 * 1) + + Contains pre-computed hidden-states (key and values in the multi-scale retention blocks) + that can be used (see `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, decoder_embed_dim)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + retentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_retentions=True` is passed or when `config.output_retentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Retentions weights, used for visualization. + + attentions (`tuple(torch.FloatTensor)`, *optional*, for backward compatibility. Same as retentions. 
+ """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + retentions: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class RetNetModel(RetNetPreTrainedModel): + def __init__( + self, + config: RetNetConfig, + embed_tokens: nn.Embedding = None, + tensor_parallel: bool = False, + ): + super().__init__(config) + self.config = config + + self.dropout_module = torch.nn.Dropout(config.dropout) + + self.embed_dim = config.decoder_embed_dim + self.embed_scale = ( + 1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim) + ) + + if embed_tokens is None: + embed_tokens = nn.Embedding( + config.vocab_size, config.decoder_embed_dim, config.pad_token_id + ) + self.embed_tokens = embed_tokens + + if config.layernorm_embedding: + self.layernorm_embedding = RMSNorm(self.embed_dim, eps=config.layernorm_eps) + else: + self.layernorm_embedding = None + + self.layers = nn.ModuleList([]) + + for i in range(config.decoder_layers): + self.layers.append( + RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel) + ) + + self.decoder_layers = len(self.layers) + + if config.decoder_normalize_before: + self.layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps) + else: + self.layer_norm = None + + self.retnet_rel_pos = RetNetRelPos(config) + self.recurrent_chunk_size = config.recurrent_chunk_size + + if config.deepnorm: + init_scale = math.pow(8.0 * config.decoder_layers, 0.25) + for name, p in self.named_parameters(): + if ( + "fc1" in name + or "fc2" in name + or "out_proj" in name + or "v_proj" in name + ): + p.data.div_(init_scale) + + if config.subln and not config.use_glu: + init_scale = math.sqrt(math.log(config.decoder_layers * 2)) + for name, p in self.named_parameters(): + if ( + "fc1" in name + or "fc2" in name + or "out_proj" in name + or "v_proj" in name + ): + p.data.mul_(init_scale) + + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward_embedding( + self, + input_ids, + forward_impl, + inputs_embeds=None, + past_key_values=None, + ): + # if past_key_values is not None: + if forward_impl == "recurrent": + input_ids = input_ids[:, -1:] + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed = self.embed_scale * inputs_embeds + + if self.layernorm_embedding is not None: + embed = self.layernorm_embedding(embed) + + embed = self.dropout_module(embed) + + return embed + + def forward( + self, + input_ids: torch.LongTensor = None, + retention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_retentions: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + forward_impl: Optional[str] = "parallel", + recurrent_chunk_size: Optional[int] = None, + retention_rel_pos: Optional[Tuple[torch.Tensor]] = None, + ) -> Union[Tuple, RetNetOutputWithPast]: + if output_retentions is None and output_attentions is not None: + output_retentions = output_attentions + output_retentions = ( + output_retentions + if output_retentions is 
not None + else self.config.output_retentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # embed tokens + if inputs_embeds is None: + inputs_embeds = self.forward_embedding( + input_ids, forward_impl, inputs_embeds, past_key_values + ) + + if retention_mask is None and attention_mask is not None: + retention_mask = attention_mask + if retention_mask is not None and forward_impl == "recurrent": + retention_mask = retention_mask[:, -1:] + + hidden_states = inputs_embeds + + # handling chunking here + if recurrent_chunk_size is None: + recurrent_chunk_size = self.recurrent_chunk_size + need_pad_for_chunkwise = ( + forward_impl == "chunkwise" and seq_length % recurrent_chunk_size != 0 + ) + if need_pad_for_chunkwise: + padding_len = recurrent_chunk_size - seq_length % recurrent_chunk_size + slen = seq_length + padding_len + hidden_states = F.pad(hidden_states, (0, 0, 0, padding_len)) + else: + slen = seq_length + # relative position + if retention_rel_pos is None: + retention_rel_pos = self.retnet_rel_pos( + slen, + forward_impl=forward_impl, + recurrent_chunk_size=recurrent_chunk_size, + retention_mask=retention_mask, + get_decay_scale=not self.training, + ) + + # start running through the decoder layers + all_hidden_states = () if output_hidden_states else None + all_retentions = () if output_retentions else None + # layers * [bsz, num_head, qk_dim, decoder_embed_dim] + next_decoder_cache = () if use_cache else None + + for idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_retentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + retention_rel_pos, + retention_mask, + forward_impl, + past_key_value, + ) + else: + layer_outputs = layer( + hidden_states, + retention_rel_pos, + retention_mask=retention_mask, + forward_impl=forward_impl, + past_key_value=past_key_value, + output_retentions=output_retentions, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[1],) + + if output_retentions: + all_retentions += (layer_outputs[2],) + + next_cache = next_decoder_cache if use_cache else None + + if need_pad_for_chunkwise: + hidden_states = hidden_states[:, :seq_length, :] + + if self.layer_norm is not None: + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, 
all_retentions] + if v is not None + ) + return RetNetOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + retentions=all_retentions, + attentions=all_retentions, + ) + + +@dataclass +class RetNetCausalLMOutputWithPast(ModelOutput): + """ + class for RetNet causal language model (or autoregressive) outputs. + + config: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`List(Dict(str, torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + - "prev_key_value": shape=(bsz * num_head * v_dim * qk_dim) + - "scale": shape=((1 or bsz) * num_head * 1 * 1) + + Contains pre-computed hidden-states (key and values in the multi-scale retention blocks) + that can be used (see `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, decoder_embed_dim)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + retentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_retentions=True` is passed or when `config.output_retentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Retentions weights, used for visualization. + + attentions (`tuple(torch.FloatTensor)`, *optional*, for backward compatibility. Same as retentions. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + retentions: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class RetNetForCausalLM(RetNetPreTrainedModel): + def __init__( + self, + config: RetNetConfig, + embed_tokens: nn.Embedding = None, + tensor_parallel: bool = False, + ) -> None: + super().__init__(config) + self.model = RetNetModel( + config, embed_tokens=embed_tokens, tensor_parallel=tensor_parallel + ) + self.lm_head = nn.Linear( + config.decoder_embed_dim, config.vocab_size, bias=False + ) + # init here + torch.nn.init.normal_( + self.lm_head.weight, mean=0, std=config.decoder_embed_dim**-0.5 + ) + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + retention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_retentions: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + forward_impl: Optional[str] = None, + recurrent_chunk_size: Optional[int] = None, + retention_rel_pos: Optional[Tuple[torch.Tensor]] = None, + ) -> Union[Tuple, RetNetCausalLMOutputWithPast]: + if output_retentions is None and output_attentions is not None: + output_retentions = output_attentions + output_retentions = ( + output_retentions + if output_retentions is not None + else self.config.output_retentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + forward_impl = ( + forward_impl if forward_impl is not None else self.config.forward_impl + ) + recurrent_chunk_size = ( + recurrent_chunk_size + if recurrent_chunk_size is not None + else self.config.recurrent_chunk_size + ) + + if retention_mask is None and attention_mask is not None: + retention_mask = attention_mask + + outputs = self.model( + input_ids, + retention_mask=retention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_retentions=output_retentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + forward_impl=forward_impl, + use_cache=use_cache, + recurrent_chunk_size=recurrent_chunk_size, + retention_rel_pos=retention_rel_pos, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + 
# Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if self.config.z_loss_coeff > 0: + # z_loss from PaLM paper + # z_loss = 1e-4 * log(log(z)), where z = sum(exp(logits)) + z_loss = torch.logsumexp(shift_logits, dim=-1).log().mean() + loss += self.config.z_loss_coeff * z_loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return RetNetCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + retentions=outputs.retentions, + attentions=outputs.retentions, + ) + + def _crop_past_key_values(model, past_key_values, maximum_length): + """Since retnet's kv do not have length, no need to crop. Just return""" + return past_key_values + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + forward_impl = kwargs.get("forward_impl", "parallel") + if past_key_values is not None: + forward_impl = "recurrent" + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "forward_impl": forward_impl, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: # dict + layer_past_kv = layer_past["prev_key_value"] # [b, h, v_dim / h, qk_dim] + layer_past_scale = layer_past["scale"] # [b, h, 1, 1] + if layer_past_scale.size(0) > 1: + # this means that retention_mask is not None, so the scale for + # each batch is different. We need to select the correct scale then. + # NOTE: during huggingface generate, it will generate attention_mask + # if it is None, so this linke will always be true. Still, having + # this line here for safety. 
+ layer_past_scale = layer_past_scale.index_select(0, beam_idx) + reordered_past += ( + { + "prev_key_value": layer_past_kv.index_select(0, beam_idx), + "scale": layer_past_scale, + }, + ) + return reordered_past + + def sample_token(self, logit, do_sample=False, top_k=1, top_p=1.0, temperature=1.0): + if not do_sample: + return torch.argmax(logit, dim=-1, keepdim=True) + filtered = top_k_top_p_filtering(logit / temperature, top_k=top_k, top_p=top_p) + return torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1) + + @torch.inference_mode() + def custom_generate( + self, + input_ids: torch.LongTensor = None, + retention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + parallel_compute_prompt=True, + max_new_tokens=20, + bos_token_id=0, + eos_token_id=0, + do_sample=False, + top_k=0, + top_p=1.0, + temperature=1.0, + early_stopping=True, + ): + if retention_mask is None and attention_mask is not None: + retention_mask = attention_mask + + if input_ids is not None: + if input_ids.shape[1] == 1: + past_key_values = None + elif parallel_compute_prompt: + ret_mask = ( + retention_mask[:, :-1] if retention_mask is not None else None + ) + outputs = self( + input_ids[:, :-1], + retention_mask=ret_mask, + forward_impl="parallel", + return_dict=True, + use_cache=True, + ) + past_key_values = outputs.past_key_values + else: + past_key_values = None + for p_i in range(input_ids.shape[1] - 1): + ret_mask = ( + retention_mask[:, : p_i + 1] + if retention_mask is not None + else None + ) + outputs = self( + input_ids[:, : p_i + 1], + retention_mask=ret_mask, + forward_impl="recurrent", + past_key_values=past_key_values, + return_dict=True, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + generated = input_ids + else: + generated = torch.tensor([[bos_token_id]]).to(self.lm_head.weight.device) + past_key_values = None + + for i in range(max_new_tokens): + outputs = self( + generated, + retention_mask=retention_mask, + forward_impl="recurrent", + past_key_values=past_key_values, + use_cache=True, + return_dict=True, + ) + logit = outputs.logits[:, -1, :] # [batch_size, vocab_size] + past_key_values = outputs.past_key_values + token = self.sample_token( + logit, + do_sample=do_sample, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + generated = torch.cat([generated, token], dim=-1) + if retention_mask is not None: + retention_mask = torch.cat( + [retention_mask, torch.ones_like(token)], dim=-1 + ) + if early_stopping and (token == eos_token_id).all(): + break + return generated + + +class RetNetForSequenceClassification(RetNetPreTrainedModel): + def __init__(self, config, tensor_parallel=False): + super().__init__(config) + self.num_labels = config.num_labels + self.model = RetNetModel(config, tensor_parallel=tensor_parallel) + self.score = nn.Linear(config.decoder_embed_dim, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + retention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_retentions: Optional[bool] = None, + 
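# A small smoke test of the vendored model above, driven the way custom_generate()
# expects (parallel prefill of the prompt, then recurrent decoding). Purely illustrative:
# the tiny config values are made up, and this is not how vall_e itself calls the model
# (base.py feeds inputs_embeds directly, see the hunks further below):
import torch

from vall_e.ext.retnet_hf.configuration_retnet import RetNetConfig
from vall_e.ext.retnet_hf.modeling_retnet import RetNetForCausalLM

config = RetNetConfig(
    vocab_size=128,
    decoder_embed_dim=64,
    decoder_value_embed_dim=128,
    decoder_ffn_embed_dim=128,
    decoder_layers=2,
    decoder_retention_heads=2,
)
model = RetNetForCausalLM(config).eval()
prompt = torch.randint(0, 128, (1, 16))
out = model.custom_generate(prompt, max_new_tokens=8, do_sample=False, early_stopping=False)
print(out.shape)  # torch.Size([1, 24]): the prompt plus 8 greedily decoded tokens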
output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + forward_impl: Optional[str] = None, + recurrent_chunk_size: Optional[int] = None, + retention_rel_pos: Optional[Tuple[torch.Tensor]] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + if output_retentions is None and output_attentions is not None: + output_retentions = output_attentions + output_retentions = ( + output_retentions + if output_retentions is not None + else self.config.output_retentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + forward_impl = ( + forward_impl if forward_impl is not None else self.config.forward_impl + ) + recurrent_chunk_size = ( + recurrent_chunk_size + if recurrent_chunk_size is not None + else self.config.recurrent_chunk_size + ) + + if retention_mask is None and attention_mask is not None: + retention_mask = attention_mask + + outputs = self.model( + input_ids, + retention_mask=retention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_retentions=output_retentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + forward_impl=forward_impl, + use_cache=use_cache, + recurrent_chunk_size=recurrent_chunk_size, + retention_rel_pos=retention_rel_pos, + ) + + hidden_states = outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." 
+ ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py index 8874d0b..0e4566e 100755 --- a/vall_e/models/__init__.py +++ b/vall_e/models/__init__.py @@ -20,7 +20,9 @@ def get_model(cfg, training=True): n_layers=cfg.layers, n_experts=cfg.experts, - training=training, + l_padding = cfg.input_alignment, + + training = training, config = cfg, ) model._cfg = cfg diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 3691eca..06dc3e4 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -300,7 +300,7 @@ class AR_NAR(Base): def example_usage(): - cfg.trainer.backend = "local" + #cfg.trainer.backend = "local" from functools import partial from einops import repeat @@ -317,7 +317,7 @@ def example_usage(): def tokenize(content, lang_marker="en"): split = content.split(" ") phones = [f""] + [ " " if not p else p for p in split ] + [f""] - return torch.tensor([*map(symmap.get, phones)]).to() + return torch.tensor([*map(symmap.get, phones)]) qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device) @@ -344,6 +344,8 @@ def example_usage(): 'n_heads': 16, # 4, # 16, # 24 'n_layers': 12, # 32 'n_experts': 1, + + 'l_padding': 8, } """ kwargs = { @@ -366,6 +368,7 @@ def example_usage(): steps = 500 optimizer = ml.Prodigy(model.parameters(), lr=1.0) #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4) + engine = Engine(model=model, optimizer=optimizer) # copy embeddings if requested @@ -392,15 +395,15 @@ def example_usage(): param.requires_grad_(False) engine._frozen_params.add(param) - if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace: - model.model = ml.replace_linear( model.model ) +# if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace: + model.model = ml.replace_linear( model.model ) torch.save( { 'module': model.state_dict() }, "./data/test.pth" ) print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") - + 
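# `ml.replace_linear()` called above is likewise defined in vall_e/utils/wrapper.py and not
# shown here. A rough, hypothetical sketch of what such a helper might do, namely recursively
# swapping nn.Linear modules for the active backend's linear layer (for example
# transformer_engine's te.Linear on the fp8 path, or bitsandbytes' Linear8bitLt when
# quantizing) while keeping the existing weights; none of this is the repository's actual code:
import torch.nn as nn
import transformer_engine.pytorch as te  # assumption: the "te" fp8 backend

def replace_linear(model: nn.Module) -> nn.Module:
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            new = te.Linear(child.in_features, child.out_features, bias=child.bias is not None)
            new.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                new.bias.data.copy_(child.bias.data)
            setattr(model, name, new)
        else:
            replace_linear(child)
    return model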
@torch.inference_mode() def sample( name, steps=600 ): engine.eval() diff --git a/vall_e/models/base.py b/vall_e/models/base.py index e008d91..de164e0 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -29,6 +29,14 @@ except Exception as e: print("Error importing `retnet` arch:", e) pass +from .retnet_hf import RetNetDecoder as RetNetDecoder_HF, RetNetConfig as RetNetConfig_HF +""" +try: +except Exception as e: + print("Error importing `retnet-hf` arch:", e) + pass +""" + try: from transformers import LlamaModel, LlamaConfig except Exception as e: @@ -44,6 +52,7 @@ except Exception as e: try: from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm + # override because bitnet's BitNetTransformer includes an embedding input / classifier output layers inside of it, which isn't favorable class BitNetTransformer(nn.Module): def __init__( self, @@ -159,7 +168,6 @@ class Embedding(nn.Embedding): def forward(self, x_list: list[Tensor]) -> list[Tensor]: if len(x_list) == 0: return [] - return super().forward(torch.cat(x_list)).split([*map(len, x_list)]) class MultiEmbedding(nn.Module): @@ -308,7 +316,9 @@ class Base(nn.Module): n_layers: int = 12, p_dropout: float = 0.1, - n_experts: int=1, + n_experts: int = 1, + + l_padding: int = 0, training = True, config = None, @@ -323,6 +333,8 @@ class Base(nn.Module): self.n_heads = n_heads self.n_layers = n_layers self.n_experts = n_experts + + self.l_padding = l_padding # +1 to include the stop token # to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding @@ -460,6 +472,27 @@ class Base(nn.Module): )) self.model = RetNetDecoder(RetNetConfig(**kwargs)) + elif self.arch_type == "retnet-hf": + kwargs = dict( + vocab_size=n_resp_tokens, + decoder_embed_dim=d_model, + decoder_value_embed_dim =d_model * 2, + decoder_retention_heads=n_heads, + decoder_ffn_embed_dim=d_model * 4, + decoder_layers=n_layers, + dropout=p_dropout if training else 0.0, + checkpoint_activations=self.activation_checkpointing, + activation_fn="gelu", + use_glu=False, # self.version >= 3, + + recurrent_chunk_size=self.recurrent_chunk_size if self.causal else 0, + decoder_normalize_before=True, + + deepnorm=False, + subln=True, + ) + + self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs)) elif self.arch_type == "bitnet": self.model = BitNetTransformer( num_tokens=n_resp_tokens, @@ -514,19 +547,50 @@ class Base(nn.Module): sep=self.sep, ) + x, m = list_to_tensor(x_list) aux_loss = None device = x.device + + # pad our input and mask, but retain the original length by doing it after + if self.l_padding and x.shape[1] % self.l_padding != 0: + # pad input + shape = list(x.shape) + shape[1] = self.l_padding - shape[1] % self.l_padding + + padding = torch.zeros(shape, dtype=x.dtype, device=x.device) + x = torch.cat([x, padding], dim=1) + + # pad mask + shape[2] = 1 + padding = torch.zeros(shape, dtype=x.dtype, device=x.device) + m = torch.cat([m, padding], dim=1) if state is not None and self.arch_type == "retnet": # prefill if len(state) == 0: prefill_size = x.shape[1] + # run the initial prompt to fill the KV cache - for n in range(prefill_size): - xi = x[:, n, :].unsqueeze(1) - self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True) + if self.arch_type == "retnet": + for n in range(prefill_size): + xi = x[:, n, :].unsqueeze(1) + self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True) + elif self.arch_type == "retnet-hf": + for n in range(prefill_size): 
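# The l_padding branch added to Base.forward above pads both the packed input and its mask
# so the sequence length becomes a multiple of the model's new input_alignment property
# (8 when fp8 is enabled, per the config.py hunk). A quick numeric illustration of that
# alignment, with made-up sizes:
seq_len, align = 123, 8
pad = align - seq_len % align if seq_len % align else 0
print(seq_len + pad)  # 128, now divisible by 8 as the fp8 path requires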
@@ -514,19 +547,50 @@ class Base(nn.Module):
             sep=self.sep,
         )
 
+
         x, m = list_to_tensor(x_list)
         aux_loss = None
 
         device = x.device
+
+        # pad our input and mask, but retain the original length by doing it after
+        if self.l_padding and x.shape[1] % self.l_padding != 0:
+            # pad input
+            shape = list(x.shape)
+            shape[1] = self.l_padding - shape[1] % self.l_padding
+
+            padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
+            x = torch.cat([x, padding], dim=1)
+
+            # pad mask
+            shape[2] = 1
+            padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
+            m = torch.cat([m, padding], dim=1)
 
         if state is not None and self.arch_type == "retnet":
             # prefill
             if len(state) == 0:
                 prefill_size = x.shape[1]
+
                 # run the initial prompt to fill the KV cache
-                for n in range(prefill_size):
-                    xi = x[:, n, :].unsqueeze(1)
-                    self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
+                if self.arch_type == "retnet":
+                    for n in range(prefill_size):
+                        xi = x[:, n, :].unsqueeze(1)
+                        self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
+                elif self.arch_type == "retnet-hf":
+                    for n in range(prefill_size):
+                        xi = x[:, n, :].unsqueeze(1)
+
+                        kwargs = dict(
+                            #attention_mask=m,
+                            inputs_embeds=xi,
+                            past_key_values=state[-1] if len(state) else None,
+                            use_cache=state is not None,
+                            # return_dict=True,
+                        )
+
+                        out = self.model(**kwargs)
+                        state.append(out.past_key_values)
 
             # grab last token(s)
             x = x[:, -1, :].unsqueeze(1)
@@ -566,6 +630,21 @@ class Base(nn.Module):
             x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
             if _ is not None and "l_aux" in _ and self.n_experts > 1:
                 aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
+        elif self.arch_type == "retnet-hf":
+            kwargs = dict(
+                #attention_mask=m,
+                inputs_embeds=x,
+                past_key_values=state,
+                use_cache=False, #state is not None,
+                # return_dict=True,
+            )
+
+            t = self.model(**kwargs)
+
+            x = t[0]
+
+            if state is not None:
+                state = t[1]
 
         elif self.arch_type == "bitnet":
             x = self.model(x)
         # output projection layer with masking
diff --git a/vall_e/models/retnet.py b/vall_e/models/retnet.py
index a0c4f7c..00b9ff1 100755
--- a/vall_e/models/retnet.py
+++ b/vall_e/models/retnet.py
@@ -1,3 +1,46 @@
+# https://github.com/microsoft/torchscale
+
 from torchscale.architecture.config import RetNetConfig
 from torchscale.architecture.retnet import RetNetDecoder
-# from retnet import RetNet
\ No newline at end of file
+# from retnet import RetNet
+
+# override MultiScaleRetention's forward because training with te throws an error
+from torchscale.component.multiscale_retention import MultiScaleRetention, theta_shift
+
+def MultiScaleRetention_forward(
+    self,
+    x,
+    rel_pos,
+    chunkwise_recurrent=False,
+    incremental_state=None
+    ):
+    bsz, tgt_len, _ = x.size()
+    (sin, cos), inner_mask = rel_pos
+
+    q = self.q_proj(x)
+    k = self.k_proj(x) * self.scaling
+    v = self.v_proj(x)
+    g = self.g_proj(x)
+
+    q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
+    k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
+
+    qr = theta_shift(q, sin, cos)
+    kr = theta_shift(k, sin, cos)
+
+    if incremental_state is not None:
+        output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
+    elif chunkwise_recurrent:
+        output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
+    else:
+        output = self.parallel_forward(qr, kr, v, inner_mask)
+
+    output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
+
+    output = self.gate_fn(g) * output
+
+    output = self.out_proj(output)
+
+    return output
+
+MultiScaleRetention.forward = MultiScaleRetention_forward
\ No newline at end of file
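
Both this override and the retnet_hf shim that follows patch third-party classes in place at import time rather than forking them. The pattern is plain attribute assignment on the class; a tiny self-contained sketch, where SomeLayer and the replacement body are made up purely for illustration:

class SomeLayer:
    def forward(self, x):
        return x * 2  # pretend this implementation misbehaves under fp8 autocast

def patched_forward(self, x):
    # keep the original signature and return contract, swap only the problematic internals
    return x + x

# rebinding the class attribute affects every existing and future instance
SomeLayer.forward = patched_forward

assert SomeLayer().forward(3) == 6  # same result as before, different implementation
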
diff --git a/vall_e/models/retnet_hf.py b/vall_e/models/retnet_hf.py
new file mode 100644
index 0000000..12e0589
--- /dev/null
+++ b/vall_e/models/retnet_hf.py
@@ -0,0 +1,199 @@
+# https://github.com/syncdoth/RetNet/
+from ..ext.retnet_hf.configuration_retnet import RetNetConfig
+from ..ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder
+
+# things we're overriding or required to override
+from ..ext.retnet_hf.modeling_retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, split_heads, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos
+
+import torch
+import math
+from typing import Dict, List, Optional, Tuple, Union
+
+# required to have a compatible LayerNorm
+def FeedForwardNetwork_init(
+    self,
+    embed_dim,
+    ffn_dim,
+    activation_fn,
+    dropout,
+    activation_dropout,
+    layernorm_eps,
+    subln=True,
+    use_rms_norm=False,
+):
+    super(FeedForwardNetwork, self).__init__()
+    self.embed_dim = embed_dim
+    self.activation_fn = get_activation_fn(activation=str(activation_fn))
+    self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
+    self.dropout_module = torch.nn.Dropout(dropout)
+    self.fc1 = torch.nn.Linear(self.embed_dim, ffn_dim)
+    self.fc2 = torch.nn.Linear(ffn_dim, self.embed_dim)
+    self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
+
+FeedForwardNetwork.__init__ = FeedForwardNetwork_init
+
+# removes embed_tokens
+def RetNetModel_init(
+    self,
+    config: RetNetConfig,
+    embed_tokens: torch.nn.Embedding = None,
+    tensor_parallel: bool = False,
+    ):
+    super(RetNetDecoder, self).__init__(config)
+    self.config = config
+
+    self.dropout_module = torch.nn.Dropout(config.dropout)
+
+    self.embed_dim = config.decoder_embed_dim
+    self.embed_scale = (
+        1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
+    )
+
+    """
+    if embed_tokens is None:
+        embed_tokens = torch.nn.Embedding(
+            config.vocab_size, config.decoder_embed_dim, config.pad_token_id
+        )
+    """
+    self.embed_tokens = None
+
+    if config.layernorm_embedding:
+        self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
+    else:
+        self.layernorm_embedding = None
+
+    self.layers = torch.nn.ModuleList([])
+
+    for i in range(config.decoder_layers):
+        self.layers.append(
+            RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel)
+        )
+
+    self.decoder_layers = len(self.layers)
+
+    if config.decoder_normalize_before:
+        self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
+    else:
+        self.layer_norm = None
+
+    self.retnet_rel_pos = RetNetRelPos(config)
+    self.recurrent_chunk_size = config.recurrent_chunk_size
+
+    if config.deepnorm:
+        init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
+        for name, p in self.named_parameters():
+            if (
+                "fc1" in name
+                or "fc2" in name
+                or "out_proj" in name
+                or "v_proj" in name
+            ):
+                p.data.div_(init_scale)
+
+    if config.subln and not config.use_glu:
+        init_scale = math.sqrt(math.log(config.decoder_layers * 2))
+        for name, p in self.named_parameters():
+            if (
+                "fc1" in name
+                or "fc2" in name
+                or "out_proj" in name
+                or "v_proj" in name
+            ):
+                p.data.mul_(init_scale)
+
+    self.gradient_checkpointing = True
+    self.post_init()
+
+RetNetDecoder.__init__ = RetNetModel_init
+
+# restores bias in our FFNs
+def RetNetDecoderLayer_init(self, config: RetNetConfig, depth: int, tensor_parallel: bool = False):
+    super(RetNetDecoderLayer, self).__init__()
+    self.config = config
+    self.embed_dim = config.decoder_embed_dim
+    self.dropout_module = torch.nn.Dropout(config.dropout)
+
+    if config.drop_path_rate > 0:
+        drop_path_prob = np.linspace(
+            0, config.drop_path_rate, config.decoder_layers
+        )[depth]
+        self.drop_path = DropPath(drop_path_prob)
+    else:
+        self.drop_path = None
+
+    self.retention = MultiScaleRetention(
+        config, use_bias=True, tensor_parallel=tensor_parallel
+    )
+
+    self.normalize_before = config.decoder_normalize_before
+
+    self.retention_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
+
+    self.ffn_dim = config.decoder_ffn_embed_dim
+
+    self.ffn = self.build_ffn()
+
+    self.final_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
+
+    if config.deepnorm:
+        self.alpha = math.pow(2.0 * config.decoder_layers, 0.25)
+    else:
+        self.alpha = 1.0
+
+RetNetDecoderLayer.__init__ = RetNetDecoderLayer_init
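
For a concrete feel for the deepnorm/subln rescaling in RetNetModel_init above, here is the same arithmetic evaluated once outside the model; the 12-layer depth is only an assumed example value:

import math

decoder_layers = 12  # assumed example depth

# deepnorm: fc1/fc2/out_proj/v_proj weights are divided by (8 * L) ** 0.25
deepnorm_scale = math.pow(8.0 * decoder_layers, 0.25)   # ~3.13

# subln without GLU: the same projections are multiplied by sqrt(log(2 * L))
subln_scale = math.sqrt(math.log(decoder_layers * 2))   # ~1.78

print(deepnorm_scale, subln_scale)
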
+# fixes the backward pass when using te's autocast
+def MultiScaleRetention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    rel_pos: Tuple[Tuple[torch.Tensor]],
+    retention_mask: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    forward_impl: str = "parallel",
+    output_retentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
+    B, T, H = hidden_states.size()
+    (sin, cos), decay_mask = rel_pos
+    # projections
+    q = self.q_proj(hidden_states)
+    k = self.k_proj(hidden_states) * self.scaling  # for scaled dot product
+    v = self.v_proj(hidden_states)
+    g = self.g_proj(hidden_states)
+    # multi-head
+    q, k, v = split_heads((q, k, v), B, T, self.num_heads)
+
+    # rotate
+    # NOTE: theta_shift has bug with mps device.
+    qr = theta_shift(q, sin, cos)
+    kr = theta_shift(k, sin, cos)
+
+    # retention
+    if forward_impl == "parallel":
+        retention_out, curr_kv, retention_weights = self.parallel_retention(
+            qr, kr, v, decay_mask
+        )
+    elif forward_impl == "recurrent":
+        retention_out, curr_kv = self.recurrent_retention(
+            qr,
+            kr,
+            v,
+            decay_mask,
+            past_key_value=past_key_value,
+            retention_mask=retention_mask,
+        )
+    elif forward_impl == "chunkwise":
+        retention_out, curr_kv = self.chunkwise_retention(qr, kr, v, decay_mask)
+    else:
+        raise ValueError(f"forward_impl {forward_impl} not supported.")
+
+    # concat heads
+    normed = self.group_norm(retention_out).reshape(B, T, self.value_dim)
+    # out gate & proj
+    out = self.gate_fn(g) * normed
+    out = self.out_proj(out)
+
+    outputs = (out, curr_kv)
+    if output_retentions:
+        outputs += (retention_weights,) if forward_impl == "parallel" else (None,)
+    return outputs
+
+MultiScaleRetention.forward = MultiScaleRetention_forward
\ No newline at end of file
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index 6ca48e2..3f92f29 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -75,6 +75,21 @@ def autocast_forward( func ):
     return wrapper
 
 Embedding.forward = autocast_forward(Embedding.forward)
+if cfg.fp8.enabled:
+    import transformer_engine.pytorch as te
+
+    Linear = te.Linear
+
+    @contextmanager
+    def autocast():
+        with te.fp8_autocast(enabled=True):
+            yield
+else:
+    @contextmanager
+    def autocast():
+        with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+            yield
+
 if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
     torch.nn.Linear = Linear
     torch.nn.Embedding = Embedding
@@ -83,6 +98,7 @@
     torch.optim.AdamW = AdamW
     torch.optim.SGD = SGD
 
+# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
 def replace_linear( model ):
     device = next(model.parameters()).device