backwards compat for old YAMLs that still use a models list, option to set flash_attention_2 for Llama (and derivatives), included syncdoth/RetNet's torchscale-based RetNet implementation for shits and grins, etc.

mrq 2024-04-16 10:02:31 -05:00
parent 545162195b
commit aa1e25fbf5
10 changed files with 1125 additions and 5 deletions
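
To illustrate the two config shapes this commit now accepts (a sketch; field values here are placeholders, not project defaults):

# sketch: both shapes map 1:1 onto the YAML keys that Config.format() parses
new_style = { "model": { "attention": "flash_attention_2" } }       # current: a single "model" block
old_style = { "models": [ { "attention": "flash_attention_2" } ] }  # deprecated: the old "models" list
# with the old list form only the first entry is used: Model(**next(iter(self.models)))
# "attention" defaults to "eager"; "flash_attention_2" is forwarded as attn_implementation
# for the Llama/Mixtral code paths further down in this commit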

View File

@@ -179,6 +179,7 @@ class Model:
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
attention: str = "eager" # or flash_attention_2
def get(self, name=None):
return [ self ] if not name or self.name == name else []
@@ -528,6 +529,7 @@ class Config(_Config):
dataset: Dataset = field(default_factory=lambda: Dataset)
model: Model = field(default_factory=lambda: Model)
models: dict | list | None = None # deprecated
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
trainer: Trainer = field(default_factory=lambda: Trainer)
@@ -576,7 +578,10 @@ class Config(_Config):
def format( self ):
self.dataset = Dataset(**self.dataset)
self.model = Model(**self.model)
if self.models is not None:
self.model = Model(**next(iter(self.models)))
else:
self.model = Model(**self.model)
self.hyperparameters = Hyperparameters(**self.hyperparameters)
self.evaluation = Evaluation(**self.evaluation)
self.trainer = Trainer(**self.trainer)

View File

@@ -0,0 +1,4 @@
# from https://github.com/syncdoth/RetNet/
# there is no proper build system and I can't be assed to fork it or make it a submodule that plays nicely with python's import system
# this is included because torchscale's implementation recently changed and I don't want to keep maintaining a fork
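
For reference (not part of this file), the vendored copy is imported with package-relative paths, matching what vall_e/models/retnet_ts.py does further down in this commit:

from ..ext.retnet_ts.config import RetNetConfig
from ..ext.retnet_ts.retnet import RetNetModel as RetNetDecoder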

View File

@@ -9,7 +9,14 @@ import torch.utils.checkpoint
from timm.models.layers import drop_path
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import top_k_top_p_filtering
# depending on the transformers version, top_k_top_p_filtering may only be importable
# from transformers.generation.utils (or may be absent entirely)
try:
    from transformers import top_k_top_p_filtering
except ImportError:
    try:
        from transformers.generation.utils import top_k_top_p_filtering
    except ImportError:
        pass
from transformers.modeling_outputs import ModelOutput, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

View File

View File

@@ -0,0 +1,74 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
class RetNetConfig(object):
def __init__(self, **kwargs):
self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
self.decoder_value_embed_dim = kwargs.pop("decoder_value_embed_dim", 1280)
self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3)
self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1280)
self.decoder_layers = kwargs.pop("decoder_layers", 12)
self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
self.activation_fn = kwargs.pop("activation_fn", "gelu")
self.dropout = kwargs.pop("dropout", 0.0)
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
self.moe_freq = kwargs.pop("moe_freq", 0)
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
"moe_normalize_gate_prob_before_dropping", False)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
self.deepnorm = kwargs.pop("deepnorm", False)
self.subln = kwargs.pop("subln", True)
self.use_ffn_rms_norm = kwargs.pop("use_ffn_rms_norm", False)
self.use_glu = kwargs.pop("use_glu", True)
self.use_lm_decay = kwargs.pop("use_lm_decay", False)
self.z_loss_coeff = kwargs.pop("z_loss_coeff", 0.0) # TODO: 1e-4
self.multiway = kwargs.pop("multiway", False)
self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed",
False)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", True)
self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-6)
# Blockwise
self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False)
self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Fairscale
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
self.fsdp = kwargs.pop("fsdp", False)
self.ddp_rank = kwargs.pop("ddp_rank", 0)
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
# token id
self.pad_token_id = kwargs.pop("pad_token_id", 0)
self.postprocessing()
def postprocessing(self):
if self.deepnorm:
self.decoder_normalize_before = False
self.subln = False
if self.subln:
self.decoder_normalize_before = True
self.deepnorm = False
if self.use_xmoe:
self.moe_normalize_gate_prob_before_dropping = True
self.moe_second_expert_policy = "random"
assert self.moe_freq > 0 and self.moe_expert_count > 0
def override(self, args):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
self.__dict__[hp] = getattr(args, hp, None)
self.postprocessing()
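
For orientation, a minimal construction sketch (not part of the file; the sizes below are illustrative, not what this project trains with):

# illustrative only: build a config and let postprocessing() reconcile the norm flags
config = RetNetConfig(
    vocab_size=1024,
    decoder_embed_dim=1024,
    decoder_value_embed_dim=2048,
    decoder_retention_heads=16,
    decoder_ffn_embed_dim=4096,
    decoder_layers=12,
    no_output_layer=False,     # emit logits over vocab_size
)
# subln=True (the default) forces decoder_normalize_before=True and deepnorm=False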

View File

@@ -0,0 +1,746 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# from fairscale.nn import checkpoint_wrapper, wrap
from timm.models.layers import drop_path
from transformers.modeling_outputs import CausalLMOutputWithPast
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\
def theta_shift(x, sin, cos):
return (x * cos) + (rotate_every_two(x) * sin)
def get_activation_fn(activation):
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
elif activation == "swish":
return F.silu
else:
raise NotImplementedError
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
super().__init__()
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.register_parameter('weight', None)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if self.weight is not None:
output = output * self.weight
return output
class RetNetRelPos(nn.Module):
def __init__(self, config):
super().__init__()
num_heads = config.decoder_retention_heads
angle = 1.0 / (10000**torch.linspace(0, 1, config.decoder_embed_dim // num_heads // 2))
angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
if config.use_lm_decay:
# NOTE: alternative way described in the paper
s = torch.log(torch.tensor(1 / 32))
e = torch.log(torch.tensor(1 / 512))
decay = torch.log(1 - torch.exp(torch.linspace(s, e, num_heads))) # [h,]
else:
decay = torch.log(1 - 2**(-5 - torch.arange(num_heads, dtype=torch.float)))
self.register_buffer("angle", angle)
self.register_buffer("decay", decay)
self.recurrent_chunk_size = config.recurrent_chunk_size
def forward(self, slen, activate_recurrent=False, chunkwise_recurrent=False):
if activate_recurrent:
sin = torch.sin(self.angle * (slen - 1))
cos = torch.cos(self.angle * (slen - 1))
retention_rel_pos = ((sin, cos), self.decay.exp())
elif chunkwise_recurrent:
index = torch.arange(slen).to(self.decay)
sin = torch.sin(index[:, None] * self.angle[None, :])
cos = torch.cos(index[:, None] * self.angle[None, :])
block_index = torch.arange(self.recurrent_chunk_size).to(self.decay)
mask = torch.tril(torch.ones(self.recurrent_chunk_size,
self.recurrent_chunk_size)).to(self.decay)
mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(),
float("inf"))
mask = torch.exp(mask * self.decay[:, None, None])
mask = torch.nan_to_num(mask)
value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)
value_inner_decay = value_inner_decay.unsqueeze(-1)
scale = mask.sum(dim=-1, keepdim=True).sqrt()
inner_mask = mask / scale
cross_decay = torch.exp(self.decay * self.recurrent_chunk_size)
query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
query_inner_decay = query_inner_decay[:, :, None] / (
scale / mask[:, -1].sum(dim=-1)[:, None, None])
cross_decay = cross_decay[:, None, None]
retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay,
value_inner_decay))
else:
index = torch.arange(slen).to(self.decay)
sin = torch.sin(index[:, None] * self.angle[None, :])
cos = torch.cos(index[:, None] * self.angle[None, :])
mask = torch.tril(torch.ones(slen, slen)).to(self.decay)
mask = torch.masked_fill(index[:, None] - index[None, :], ~mask.bool(), float("inf"))
mask = torch.exp(mask * self.decay[:, None, None])
mask = torch.nan_to_num(mask)
mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
retention_rel_pos = ((sin, cos), mask)
return retention_rel_pos
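The second element of the returned tuple differs per mode; a sketch of the three call patterns (not part of the file):

# what MultiScaleRetention.forward() receives as rel_pos in each mode
rel_pos = RetNetRelPos(RetNetConfig())   # the defaults from config.py are enough for shapes
parallel  = rel_pos(slen=128)                              # ((sin, cos), decay_mask)  -- full-sequence training
recurrent = rel_pos(slen=37, activate_recurrent=True)      # ((sin, cos), decay)       -- one token at a time
chunked   = rel_pos(slen=1024, chunkwise_recurrent=True)   # ((sin, cos), (inner_mask, cross_decay, query_inner_decay, value_inner_decay))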
class MultiScaleRetention(nn.Module):
def __init__(
self,
config,
embed_dim,
value_dim,
num_heads,
gate_fn="swish",
):
super().__init__()
self.config = config
self.embed_dim = embed_dim
self.value_dim = value_dim
self.num_heads = num_heads
self.head_dim = self.value_dim // num_heads
self.key_dim = self.embed_dim // num_heads
self.scaling = self.key_dim**-0.5
self.gate_fn = get_activation_fn(activation=str(gate_fn))
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, value_dim, bias=False)
self.g_proj = nn.Linear(embed_dim, value_dim, bias=False)
self.out_proj = nn.Linear(value_dim, embed_dim, bias=False)
self.group_norm = RMSNorm(self.head_dim, eps=config.layernorm_eps, elementwise_affine=False)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-2.5)
nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-2.5)
nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-2.5)
nn.init.xavier_uniform_(self.g_proj.weight, gain=2**-2.5)
nn.init.xavier_uniform_(self.out_proj.weight, gain=2**-1)
def parallel_forward(self, qr, kr, v, mask):
bsz, tgt_len, embed_dim = v.size()
vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len
qk_mat = qk_mat * mask
# invariant after normalization
qk_mat = qk_mat / qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1, max=5e4)
output = torch.matmul(qk_mat, vr)
output = output.transpose(1, 2)
return output
def recurrent_forward(self, qr, kr, v, decay, incremental_state):
bsz = v.size(0)
v = v.view(bsz, self.num_heads, self.head_dim, 1)
kv = kr * v
if "prev_key_value" in incremental_state:
prev_kv = incremental_state["prev_key_value"]
prev_scale = incremental_state["scale"]
scale = prev_scale * decay + 1
kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(
self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
# kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
else:
scale = torch.ones_like(decay)
incremental_state["prev_key_value"] = kv
incremental_state["scale"] = scale
output = torch.sum(qr * kv, dim=3)
return output
def chunk_recurrent_forward(self, qr, kr, v, inner_mask):
mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
bsz, tgt_len, embed_dim = v.size()
chunk_len = mask.size(1)
num_chunks = tgt_len // chunk_len
assert tgt_len % chunk_len == 0
qr = qr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
kr = kr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
v = v.view(bsz, num_chunks, chunk_len, self.num_heads, self.head_dim).transpose(2, 3)
kr_t = kr.transpose(-1, -2)
qk_mat = qr @ kr_t # bsz * num_heads * chunk_len * chunk_len
qk_mat = qk_mat * mask
inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)
qk_mat = qk_mat / inner_scale
inner_output = torch.matmul(qk_mat,
v) # bsz * num_heads * num_value_heads * chunk_len * head_dim
# reduce kv in one chunk
kv = kr_t @ (v * value_inner_decay)
kv_recurrent = []
cross_scale = []
kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)
# accumulate kv by loop
for i in range(num_chunks):
kv_recurrent.append(kv_state / kv_scale)
cross_scale.append(kv_scale)
kv_state = kv_state * cross_decay + kv[:, i]
kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max(
dim=-1, keepdim=True).values.clamp(min=1)
kv_recurrent = torch.stack(kv_recurrent, dim=1)
cross_scale = torch.stack(cross_scale, dim=1)
all_scale = torch.maximum(inner_scale, cross_scale)
align_inner_scale = all_scale / inner_scale
align_cross_scale = all_scale / cross_scale
cross_output = (qr * query_inner_decay) @ kv_recurrent
output = inner_output / align_inner_scale + cross_output / align_cross_scale
# output = inner_output / cross_scale + cross_output / inner_scale
output = output.transpose(2, 3)
return output
def forward(self, x, rel_pos, chunkwise_recurrent=False, incremental_state=None):
bsz, tgt_len, _ = x.size()
(sin, cos), inner_mask = rel_pos
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
g = self.g_proj(x)
k *= self.scaling
q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
qr = theta_shift(q, sin, cos)
kr = theta_shift(k, sin, cos)
if incremental_state is not None:
output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
elif chunkwise_recurrent:
output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
else:
output = self.parallel_forward(qr, kr, v, inner_mask)
output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
output = self.gate_fn(g) * output
output = self.out_proj(output)
return output
class FeedForwardNetwork(nn.Module):
def __init__(
self,
embed_dim,
ffn_dim,
activation_fn,
dropout,
activation_dropout,
layernorm_eps,
subln=False,
use_rms_norm=False,
):
super().__init__()
self.embed_dim = embed_dim
self.activation_fn = get_activation_fn(activation=str(activation_fn))
self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
self.dropout_module = torch.nn.Dropout(dropout)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
if subln:
if use_rms_norm:
self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps)
else:
self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps)
else:
self.ffn_layernorm = None
def reset_parameters(self):
self.fc1.reset_parameters()
self.fc2.reset_parameters()
if self.ffn_layernorm is not None:
self.ffn_layernorm.reset_parameters()
def forward(self, x):
x_shape = x.shape
x = x.reshape(-1, x.size(-1))
x = self.fc1(x)
x = self.activation_fn(x.float()).type_as(x)
x = self.activation_dropout_module(x)
if self.ffn_layernorm is not None:
x = self.ffn_layernorm(x)
x = self.fc2(x)
x = x.view(x_shape)
x = self.dropout_module(x)
return x
class GLU(nn.Module):
def __init__(
self,
embed_dim,
ffn_dim,
activation_fn,
dropout,
activation_dropout,
):
super().__init__()
self.embed_dim = embed_dim
self.activation_fn = get_activation_fn(activation=str(activation_fn))
self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
self.dropout_module = torch.nn.Dropout(dropout)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)
def reset_parameters(self):
self.fc1.reset_parameters()
self.fc2.reset_parameters()
self.gate.reset_parameters()
def forward(self, x):
x_shape = x.shape
x = x.reshape(-1, x.size(-1))
g = self.gate(x)
x = self.fc1(x)
x = self.activation_fn(x.float()).type_as(x) * g
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = x.view(x_shape)
x = self.dropout_module(x)
return x
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self):
return "p={}".format(self.drop_prob)
class RetNetDecoderLayer(nn.Module):
def __init__(
self,
config,
depth,
):
super().__init__()
self.config = config
self.embed_dim = config.decoder_embed_dim
self.dropout_module = torch.nn.Dropout(config.dropout)
if config.drop_path_rate > 0:
drop_path_prob = np.linspace(0, config.drop_path_rate, config.decoder_layers)[depth]
self.drop_path = DropPath(drop_path_prob)
else:
self.drop_path = None
self.retention = MultiScaleRetention(
config,
self.embed_dim,
config.decoder_value_embed_dim,
config.decoder_retention_heads,
)
self.normalize_before = config.decoder_normalize_before
self.retention_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps)
self.ffn_dim = config.decoder_ffn_embed_dim
self.ffn = self.build_ffn()
self.final_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps)
if config.deepnorm:
self.alpha = math.pow(2.0 * config.decoder_layers, 0.25)
else:
self.alpha = 1.0
def build_ffn(self):
if self.config.use_glu:
return GLU(
self.embed_dim,
self.ffn_dim,
self.config.activation_fn,
self.config.dropout,
self.config.activation_dropout,
)
else:
return FeedForwardNetwork(
self.embed_dim,
self.ffn_dim,
self.config.activation_fn,
self.config.dropout,
self.config.activation_dropout,
self.config.layernorm_eps,
self.config.subln,
self.config.use_ffn_rms_norm,
)
def residual_connection(self, x, residual):
return residual * self.alpha + x
def forward(
self,
x,
incremental_state=None,
chunkwise_recurrent=False,
retention_rel_pos=None,
):
residual = x
if self.normalize_before:
x = self.retention_layer_norm(x)
x = self.retention(
x,
incremental_state=incremental_state,
rel_pos=retention_rel_pos,
chunkwise_recurrent=chunkwise_recurrent,
)
x = self.dropout_module(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.retention_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.ffn(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x
class RetNetModel(nn.Module):
def __init__(self, config, embed_tokens=None, output_projection=None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.dropout_module = torch.nn.Dropout(config.dropout)
embed_dim = config.decoder_embed_dim
self.embed_dim = embed_dim
self.embed_scale = 1.0 if config.no_scale_embedding else math.sqrt(embed_dim)
self.embed_tokens = embed_tokens
if (output_projection is None and not config.no_output_layer and config.vocab_size > 0):
self.output_projection = self.build_output_projection(config)
else:
self.output_projection = output_projection
if config.layernorm_embedding:
self.layernorm_embedding = RMSNorm(embed_dim, eps=config.layernorm_eps)
else:
self.layernorm_embedding = None
self.layers = nn.ModuleList([])
for i in range(config.decoder_layers):
self.layers.append(self.build_decoder_layer(
config,
depth=i,
))
self.num_layers = len(self.layers)
if config.decoder_normalize_before:
self.layer_norm = RMSNorm(embed_dim, eps=config.layernorm_eps)
else:
self.layer_norm = None
self.retnet_rel_pos = RetNetRelPos(config)
self.chunkwise_recurrent = config.chunkwise_recurrent
self.recurrent_chunk_size = config.recurrent_chunk_size
if config.deepnorm:
init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
for name, p in self.named_parameters():
if ("fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name):
p.data.div_(init_scale)
if config.subln and not config.use_glu:
init_scale = math.sqrt(math.log(config.decoder_layers * 2))
for name, p in self.named_parameters():
if ("fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name):
p.data.mul_(init_scale)
def build_output_projection(
self,
config,
):
if config.share_decoder_input_output_embed:
output_projection = torch.nn.Linear(
self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0],
bias=False,
)
output_projection.weight = self.embed_tokens.weight
else:
output_projection = torch.nn.Linear(config.decoder_embed_dim,
config.vocab_size,
bias=False)
torch.nn.init.normal_(output_projection.weight,
mean=0,
std=config.decoder_embed_dim**-0.5)
return output_projection
def build_decoder_layer(self, config, depth):
layer = RetNetDecoderLayer(
config,
depth,
)
# if config.checkpoint_activations:
# layer = checkpoint_wrapper(layer)
# if config.fsdp:
# layer = wrap(layer)
return layer
def forward_embedding(
self,
tokens,
token_embedding=None,
incremental_state=None,
):
if incremental_state is not None and not self.is_first_step(incremental_state):
tokens = tokens[:, -1:]
if token_embedding is None:
token_embedding = self.embed_tokens(tokens)
x = embed = self.embed_scale * token_embedding
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
return x, embed
def is_first_step(self, incremental_state):
if incremental_state is None:
return False
return incremental_state.get("is_first_step", False)
def forward(self,
prev_output_tokens,
incremental_state=None,
features_only=False,
token_embeddings=None):
# embed tokens
x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)
is_first_step = self.is_first_step(incremental_state)
if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
padding_len = self.recurrent_chunk_size - prev_output_tokens.size(
1) % self.recurrent_chunk_size
slen = prev_output_tokens.size(1) + padding_len
x = F.pad(x, (0, 0, 0, padding_len))
else:
slen = prev_output_tokens.size(1)
# relative position
retention_rel_pos = self.retnet_rel_pos(slen,
incremental_state is not None and not is_first_step,
chunkwise_recurrent=self.chunkwise_recurrent)
# decoder layers
inner_states = [x]
for idx, layer in enumerate(self.layers):
if incremental_state is None or is_first_step:
if is_first_step and incremental_state is not None:
if idx not in incremental_state:
incremental_state[idx] = {}
else:
if idx not in incremental_state:
incremental_state[idx] = {}
x = layer(
x,
incremental_state[idx] if incremental_state is not None else None,
retention_rel_pos=retention_rel_pos,
chunkwise_recurrent=self.chunkwise_recurrent,
)
inner_states.append(x)
if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
x = x[:, :prev_output_tokens.size(1), :]
if self.layer_norm is not None:
x = self.layer_norm(x)
if not features_only:
x = self.output_layer(x)
return x, {
"inner_states": inner_states,
"attn": None,
}
def output_layer(self, features):
return self.output_projection(features)
class RetNetForCausalLM(nn.Module):
def __init__(self, config, embed_tokens=None, output_projection=None, **kwargs):
super().__init__(**kwargs)
assert config.vocab_size > 0, "you must specify vocab size"
if output_projection is None:
config.no_output_layer = False
if embed_tokens is None:
embed_tokens = nn.Embedding(config.vocab_size, config.decoder_embed_dim,
config.pad_token_id)
self.config = config
self.model = RetNetModel(config,
embed_tokens=embed_tokens,
output_projection=output_projection,
**kwargs)
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.model.output_projection
def set_output_embeddings(self, new_embeddings):
self.model.output_projection = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
retention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_retentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
recurrent_chunk_size: Optional[int] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
outputs = self.model(
input_ids,
incremental_state=past_key_values,
features_only=False,
token_embeddings=inputs_embeds,
)
logits, inner_hidden_states = outputs[0], outputs[1]['inner_states']
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if self.config.z_loss_coeff > 0:
# z_loss from PaLM paper
# z_loss = 1e-4 * log(log(z)), where z = sum(exp(logits))
z_loss = torch.logsumexp(shift_logits, dim=-1).log().mean()
loss += self.config.z_loss_coeff * z_loss
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=past_key_values,
hidden_states=inner_hidden_states,
attentions=None,
)
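A minimal end-to-end sketch of the causal LM wrapper above (not part of the file; vocab size and dimensions are illustrative):

import torch
config = RetNetConfig(vocab_size=256, decoder_embed_dim=128, decoder_value_embed_dim=256,
                      decoder_retention_heads=4, decoder_ffn_embed_dim=256, decoder_layers=2)
lm = RetNetForCausalLM(config)          # builds its own nn.Embedding and output projection
tokens = torch.randint(0, 256, (1, 16))
out = lm(input_ids=tokens, labels=tokens, return_dict=True)
print(out.logits.shape, out.loss)       # torch.Size([1, 16, 256]) and a scalar loss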

View File

@@ -351,7 +351,7 @@ def example_usage():
'n_tokens': 1024,
'd_model': 1024, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
'n_layers': 16, # 32
'n_layers': 12, # 32
'n_experts': 1,
'l_padding': 8 if cfg.fp8.enabled else 0,

View File

@@ -24,7 +24,8 @@ except Exception as e:
pass
try:
from .retnet import RetNetDecoder, RetNetConfig
#from .retnet import RetNetDecoder, RetNetConfig
from .retnet_ts import RetNetDecoder, RetNetConfig
except Exception as e:
print("Error importing `retnet` arch:", e)
pass
@@ -394,6 +395,7 @@ class Base(nn.Module):
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
))
else:
self.model = MixtralModel(MixtralConfig(
@@ -412,6 +414,7 @@ class Base(nn.Module):
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
))
elif self.arch_type == "llama":
if n_experts <= 1:
@@ -428,6 +431,7 @@ class Base(nn.Module):
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
))
else:
self.model = MixtralModel(MixtralConfig(
@@ -446,6 +450,7 @@ class Base(nn.Module):
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
))
elif self.arch_type == "retnet":

vall_e/models/retnet_ts.py Normal file (279 lines)
View File

@@ -0,0 +1,279 @@
# https://github.com/syncdoth/RetNet/
from ..ext.retnet_ts.config import RetNetConfig
from ..ext.retnet_ts.retnet import RetNetModel as RetNetDecoder
# things we're overriding or required to override
from ..ext.retnet_ts.retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, RMSNorm, FeedForwardNetwork, DropPath, get_activation_fn, LayerNorm, RetNetRelPos
import torch
import math
import numpy as np # needed (along with DropPath above) by the overridden RetNetDecoderLayer __init__ when drop_path_rate > 0
from typing import Dict, List, Optional, Tuple, Union
from torch.utils.checkpoint import checkpoint
# required to have a compatible LayerNorm
def FeedForwardNetwork_init(
self,
embed_dim,
ffn_dim,
activation_fn,
dropout,
activation_dropout,
layernorm_eps,
subln=True,
use_rms_norm=False,
):
super(FeedForwardNetwork, self).__init__()
self.embed_dim = embed_dim
self.activation_fn = get_activation_fn(activation=str(activation_fn))
self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
self.dropout_module = torch.nn.Dropout(dropout)
self.fc1 = torch.nn.Linear(self.embed_dim, ffn_dim)
self.fc2 = torch.nn.Linear(ffn_dim, self.embed_dim)
self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
FeedForwardNetwork.__init__ = FeedForwardNetwork_init
# removes embed_tokens
def RetNetModel_init(
self, config, embed_tokens=None, output_projection=None, **kwargs
):
super(RetNetDecoder, self).__init__(**kwargs)
self.config = config
self.dropout_module = torch.nn.Dropout(config.dropout)
self.embed_dim = config.decoder_embed_dim
self.embed_scale = (
1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
)
"""
if embed_tokens is None:
embed_tokens = torch.nn.Embedding(
config.vocab_size, config.decoder_embed_dim, config.pad_token_id
)
"""
self.embed_tokens = None
if (output_projection is None and not config.no_output_layer and config.vocab_size > 0):
self.output_projection = self.build_output_projection(config)
else:
self.output_projection = output_projection
if config.layernorm_embedding:
self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
else:
self.layernorm_embedding = None
self.layers = torch.nn.ModuleList([])
for i in range(config.decoder_layers):
layer = self.build_decoder_layer(
config,
depth=i,
)
"""
if config.checkpoint_activations:
layer = checkpoint_wrapper(layer)
"""
self.layers.append(layer)
self.num_layers = len(self.layers)
if config.decoder_normalize_before:
self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
else:
self.layer_norm = None
self.retnet_rel_pos = RetNetRelPos(config)
self.chunkwise_recurrent = config.chunkwise_recurrent
self.recurrent_chunk_size = config.recurrent_chunk_size
if config.deepnorm:
init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale)
if config.subln and not config.use_glu:
init_scale = math.sqrt(math.log(config.decoder_layers * 2))
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.mul_(init_scale)
self.gradient_checkpointing = True
RetNetDecoder.__init__ = RetNetModel_init
# restores bias in the retention's projection layers (q/k/v/g/out)
def RetNetDecoderLayer_init(
self,
config,
depth,
use_bias=True
):
super(RetNetDecoderLayer, self).__init__()
self.config = config
self.embed_dim = config.decoder_embed_dim
self.dropout_module = torch.nn.Dropout(config.dropout)
if config.drop_path_rate > 0:
drop_path_prob = np.linspace(
0, config.drop_path_rate, config.decoder_layers
)[depth]
self.drop_path = DropPath(drop_path_prob)
else:
self.drop_path = None
self.retention = MultiScaleRetention(
config,
self.embed_dim,
config.decoder_value_embed_dim,
config.decoder_retention_heads,
use_bias=use_bias
)
self.normalize_before = config.decoder_normalize_before
self.retention_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
self.ffn_dim = config.decoder_ffn_embed_dim
self.ffn = self.build_ffn()
self.final_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
if config.deepnorm:
self.alpha = math.pow(2.0 * config.decoder_layers, 0.25)
else:
self.alpha = 1.0
def RetNetDecoderLayer_forward(
self,
x,
incremental_state=None,
chunkwise_recurrent=False,
retention_rel_pos=None,
):
residual = x
if self.normalize_before:
x = self.retention_layer_norm(x)
if x.requires_grad and self.config.checkpoint_activations:
x = checkpoint(
self.retention,
x,
use_reentrant=False,
incremental_state=incremental_state,
rel_pos=retention_rel_pos,
chunkwise_recurrent=chunkwise_recurrent,
)
else:
x = self.retention(
x,
incremental_state=incremental_state,
rel_pos=retention_rel_pos,
chunkwise_recurrent=chunkwise_recurrent,
)
x = self.dropout_module(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.retention_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.ffn(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x
RetNetDecoderLayer.__init__ = RetNetDecoderLayer_init
RetNetDecoderLayer.forward = RetNetDecoderLayer_forward
# fixes the backward pass when using te's autocast
def MultiScaleRetention_init(
self,
config,
embed_dim,
value_dim,
num_heads,
gate_fn="swish",
use_bias=True,
):
super(MultiScaleRetention, self).__init__()
self.config = config
self.embed_dim = embed_dim
self.value_dim = value_dim
self.num_heads = num_heads
self.head_dim = self.value_dim // num_heads
self.key_dim = self.embed_dim // num_heads
self.scaling = self.key_dim**-0.5
self.gate_fn = get_activation_fn(activation=str(gate_fn))
self.q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=use_bias)
self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=use_bias)
self.v_proj = torch.nn.Linear(embed_dim, value_dim, bias=use_bias)
self.g_proj = torch.nn.Linear(embed_dim, value_dim, bias=use_bias)
self.out_proj = torch.nn.Linear(value_dim, embed_dim, bias=use_bias)
self.group_norm = RMSNorm(self.head_dim, eps=config.layernorm_eps, elementwise_affine=False)
self.reset_parameters()
def MultiScaleRetention_forward(
self, x, rel_pos, chunkwise_recurrent=False, incremental_state=None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
bsz, tgt_len, _ = x.size()
(sin, cos), inner_mask = rel_pos
q = self.q_proj(x)
k = self.k_proj(x) * self.scaling
v = self.v_proj(x)
g = self.g_proj(x)
q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
qr = theta_shift(q, sin, cos)
kr = theta_shift(k, sin, cos)
if incremental_state is not None:
output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
elif chunkwise_recurrent:
output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
else:
output = self.parallel_forward(qr, kr, v, inner_mask)
output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
output = self.gate_fn(g) * output
output = self.out_proj(output)
return output
MultiScaleRetention.__init__ = MultiScaleRetention_init
MultiScaleRetention.forward = MultiScaleRetention_forward
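Since the patches are applied at import time, downstream code just constructs through the re-exported names; a construction sketch (sizes illustrative, not the project's actual hyperparameters):

from vall_e.models.retnet_ts import RetNetConfig, RetNetDecoder
config = RetNetConfig(
    vocab_size=1024, decoder_embed_dim=1024, decoder_value_embed_dim=2048,
    decoder_retention_heads=16, decoder_ffn_embed_dim=4096, decoder_layers=12,
    checkpoint_activations=True,
)
decoder = RetNetDecoder(config)   # embed_tokens stays None; the caller passes token_embeddings into forward()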

View File

@@ -203,7 +203,7 @@ with ui:
layout["inference"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
with gr.Row():
with gr.Column(scale=1):
layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", source="upload", type="filepath", info="Reference audio for TTS")
layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS")
# layout["inference"]["stop"] = gr.Button(value="Stop")
layout["inference"]["outputs"]["output"] = gr.Audio(label="Output")
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")