cleanup
parent 3fc0540f49
commit 2cef97e43f
@@ -1,4 +0,0 @@
# from https://github.com/syncdoth/RetNet/

# there is no proper build system and I can't be assed to fork it or make it a submodule that plays nicely with python's import system
# this is included because torchscale's implementation recently changed and I don't want to keep maintaining a fork
@@ -1,3 +0,0 @@
# from https://github.com/syncdoth/RetNet/

# there is no proper build system and I can't be assed to fork it or make it a submodule that plays nicely with python's import system
@@ -1,117 +0,0 @@
from dataclasses import dataclass
import json

from transformers.configuration_utils import PretrainedConfig


def load_config_from_json(config_file):
    with open(config_file, 'r') as f:
        config = json.loads(f.read())
    config = RetNetConfig.from_dict(config)
    return config


@dataclass
class RetNetConfig(PretrainedConfig):
    model_type = "retnet"
    initializer_range: float = 0.02
    activation_fn: str = "gelu"
    dropout: float = 0.0  # dropout probability
    activation_dropout: float = 0.0  # dropout probability after activation in FFN.
    drop_path_rate: float = 0.0
    decoder_embed_dim: int = 768  # decoder embedding dimension
    decoder_value_embed_dim: int = 1280  # decoder value embedding dimension
    decoder_ffn_embed_dim: int = 1280  # decoder embedding dimension for FFN
    decoder_layers: int = 12  # num decoder layers
    decoder_retention_heads: int = 3  # num decoder retention heads
    decoder_normalize_before: bool = True  # apply layernorm before each decoder block
    layernorm_embedding: bool = False  # add layernorm to embedding
    no_scale_embedding: bool = True  # if True, don't scale embeddings
    recurrent_chunk_size: int = 512
    use_lm_decay: bool = False
    use_glu: bool = True  # use GLU instead of FFN
    z_loss_coeff: float = 0.0  # coefficient for z loss: TODO: 1e-4
    deepnorm: bool = False
    subln: bool = True
    use_ffn_rms_norm: bool = False
    layernorm_eps: float = 1e-6
    tie_word_embeddings: bool = False

    def __init__(
        self,
        vocab_size: int = 50257,
        initializer_range: float = 0.02,
        is_decoder: bool = True,
        pad_token_id: int = 0,
        eos_token_id: int = 0,
        output_retentions: bool = False,
        use_cache: bool = True,
        forward_impl: str = 'parallel',
        activation_fn: str = "gelu",
        dropout: float = 0.0,  # dropout probability
        activation_dropout: float = 0.0,  # dropout probability after activation in FFN.
        drop_path_rate: float = 0.0,
        decoder_embed_dim: int = 768,  # decoder embedding dimension
        decoder_value_embed_dim: int = 1280,  # decoder value embedding dimension
        decoder_ffn_embed_dim: int = 1280,  # decoder embedding dimension for FFN
        decoder_layers: int = 12,  # num decoder layers
        decoder_retention_heads: int = 3,  # num decoder retention heads
        decoder_normalize_before: bool = True,  # apply layernorm before each decoder block
        layernorm_embedding: bool = False,  # add layernorm to embedding
        no_scale_embedding: bool = True,  # if True, don't scale embeddings
        recurrent_chunk_size: int = 512,
        use_glu: bool = True,  # use GLU instead of FFN
        z_loss_coeff: float = 0.0,  # coefficient for z loss: TODO: 1e-4
        use_lm_decay: bool = False,
        deepnorm: bool = True,
        subln: bool = True,
        use_ffn_rms_norm: bool = False,  # use RMSNorm instead of LayerNorm in FFN
        layernorm_eps: float = 1e-6,
        tie_word_embeddings: bool = False,
        **kwargs):
        self.vocab_size = vocab_size
        self.initializer_range = initializer_range
        self.output_retentions = output_retentions
        self.use_lm_decay = use_lm_decay
        self.use_glu = use_glu
        self.z_loss_coeff = z_loss_coeff
        # size related
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_value_embed_dim = decoder_value_embed_dim
        self.decoder_retention_heads = decoder_retention_heads
        self.decoder_ffn_embed_dim = decoder_ffn_embed_dim
        self.decoder_layers = decoder_layers
        # normalization related
        self.decoder_normalize_before = decoder_normalize_before
        self.activation_fn = activation_fn
        self.dropout = dropout
        self.drop_path_rate = drop_path_rate
        self.activation_dropout = activation_dropout
        self.no_scale_embedding = no_scale_embedding
        self.layernorm_embedding = layernorm_embedding
        self.deepnorm = deepnorm
        self.subln = subln
        self.use_ffn_rms_norm = use_ffn_rms_norm
        self.layernorm_eps = layernorm_eps
        # Blockwise
        self.recurrent_chunk_size = recurrent_chunk_size
        self.forward_impl = forward_impl

        if self.deepnorm:
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.decoder_normalize_before = True
            self.deepnorm = False

        super().__init__(is_decoder=is_decoder,
                         pad_token_id=pad_token_id,
                         eos_token_id=eos_token_id,
                         use_cache=use_cache,
                         tie_word_embeddings=tie_word_embeddings,
                         **kwargs)

    def override(self, args):
        for hp in self.__dict__.keys():
            if getattr(args, hp, None) is not None:
                self.__dict__[hp] = getattr(args, hp, None)
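For reference, a minimal sketch (not part of the commit) of how this HF-style config was meant to be driven; it assumes the file above is importable as `configuration_retnet`, and the JSON file name is hypothetical:

```python
from configuration_retnet import RetNetConfig, load_config_from_json  # assumed import path

# Build a config directly; unknown kwargs fall through to PretrainedConfig.
config = RetNetConfig(vocab_size=1024, decoder_layers=4, decoder_embed_dim=256)
print(config.decoder_retention_heads)  # -> 3 (default)

# load_config_from_json() just reads a JSON dict and hands it to RetNetConfig.from_dict(),
# so a config serialized with PretrainedConfig.to_json_file() round-trips:
config.to_json_file("retnet_config.json")
config = load_config_from_json("retnet_config.json")
assert config.decoder_layers == 4
```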
File diff suppressed because it is too large
@@ -1,74 +0,0 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]


class RetNetConfig(object):

    def __init__(self, **kwargs):
        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
        self.decoder_value_embed_dim = kwargs.pop("decoder_value_embed_dim", 1280)
        self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3)
        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1280)
        self.decoder_layers = kwargs.pop("decoder_layers", 12)
        self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
        self.activation_fn = kwargs.pop("activation_fn", "gelu")
        self.dropout = kwargs.pop("dropout", 0.0)
        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
        self.moe_freq = kwargs.pop("moe_freq", 0)
        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
        self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
        self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
            "moe_normalize_gate_prob_before_dropping", False)
        self.use_xmoe = kwargs.pop("use_xmoe", False)
        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
        self.deepnorm = kwargs.pop("deepnorm", False)
        self.subln = kwargs.pop("subln", True)
        self.use_ffn_rms_norm = kwargs.pop("use_ffn_rms_norm", False)
        self.use_glu = kwargs.pop("use_glu", True)
        self.use_lm_decay = kwargs.pop("use_lm_decay", False)
        self.z_loss_coeff = kwargs.pop("z_loss_coeff", 0.0)  # TODO: 1e-4
        self.multiway = kwargs.pop("multiway", False)
        self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed",
                                                           False)
        self.max_target_positions = kwargs.pop("max_target_positions", 1024)
        self.no_output_layer = kwargs.pop("no_output_layer", True)
        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-6)
        # Blockwise
        self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False)
        self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512)
        # Text
        self.vocab_size = kwargs.pop("vocab_size", -1)
        # Fairscale
        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
        self.fsdp = kwargs.pop("fsdp", False)
        self.ddp_rank = kwargs.pop("ddp_rank", 0)
        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
        # token id
        self.pad_token_id = kwargs.pop("pad_token_id", 0)
        self.postprocessing()

    def postprocessing(self):
        if self.deepnorm:
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.decoder_normalize_before = True
            self.deepnorm = False
        if self.use_xmoe:
            self.moe_normalize_gate_prob_before_dropping = True
            self.moe_second_expert_policy = "random"
            assert self.moe_freq > 0 and self.moe_expert_count > 0

    def override(self, args):
        for hp in self.__dict__.keys():
            if getattr(args, hp, None) is not None:
                self.__dict__[hp] = getattr(args, hp, None)
        self.postprocessing()
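A small sketch (not from the commit) of how this kwargs-driven config behaves, in particular `override()` and the deepnorm/subln exclusivity enforced by `postprocessing()`; the import path and the argparse `Namespace` are stand-ins for whatever a caller would actually pass:

```python
from argparse import Namespace
from retnet_ts_config import RetNetConfig  # hypothetical import path for the file above

config = RetNetConfig(decoder_layers=6, deepnorm=True)
# postprocessing() runs in __init__: deepnorm forces post-norm and disables subln
assert config.decoder_normalize_before is False and config.subln is False

# override() copies every non-None attribute from an argparse-like namespace,
# then re-runs postprocessing()
config.override(Namespace(decoder_layers=12, deepnorm=False, subln=True))
assert config.decoder_layers == 12
assert config.decoder_normalize_before is True and config.deepnorm is False
```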
@@ -1,746 +0,0 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# from fairscale.nn import checkpoint_wrapper, wrap
from timm.models.layers import drop_path
from transformers.modeling_outputs import CausalLMOutputWithPast

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm


def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')\


def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)


def get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    elif activation == "swish":
        return F.silu
    else:
        raise NotImplementedError


class RMSNorm(nn.Module):

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        return output


class RetNetRelPos(nn.Module):

    def __init__(self, config):
        super().__init__()
        num_heads = config.decoder_retention_heads

        angle = 1.0 / (10000**torch.linspace(0, 1, config.decoder_embed_dim // num_heads // 2))
        angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
        if config.use_lm_decay:
            # NOTE: alternative way described in the paper
            s = torch.log(torch.tensor(1 / 32))
            e = torch.log(torch.tensor(1 / 512))
            decay = torch.log(1 - torch.exp(torch.linspace(s, e, num_heads)))  # [h,]
        else:
            decay = torch.log(1 - 2**(-5 - torch.arange(num_heads, dtype=torch.float)))
        self.register_buffer("angle", angle)
        self.register_buffer("decay", decay)
        self.recurrent_chunk_size = config.recurrent_chunk_size

    def forward(self, slen, activate_recurrent=False, chunkwise_recurrent=False):
        if activate_recurrent:
            sin = torch.sin(self.angle * (slen - 1))
            cos = torch.cos(self.angle * (slen - 1))
            retention_rel_pos = ((sin, cos), self.decay.exp())
        elif chunkwise_recurrent:
            index = torch.arange(slen).to(self.decay)
            sin = torch.sin(index[:, None] * self.angle[None, :])
            cos = torch.cos(index[:, None] * self.angle[None, :])

            block_index = torch.arange(self.recurrent_chunk_size).to(self.decay)
            mask = torch.tril(torch.ones(self.recurrent_chunk_size,
                                         self.recurrent_chunk_size)).to(self.decay)
            mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(),
                                     float("inf"))
            mask = torch.exp(mask * self.decay[:, None, None])
            mask = torch.nan_to_num(mask)

            value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)
            value_inner_decay = value_inner_decay.unsqueeze(-1)
            scale = mask.sum(dim=-1, keepdim=True).sqrt()
            inner_mask = mask / scale

            cross_decay = torch.exp(self.decay * self.recurrent_chunk_size)
            query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
            query_inner_decay = query_inner_decay[:, :, None] / (
                scale / mask[:, -1].sum(dim=-1)[:, None, None])
            cross_decay = cross_decay[:, None, None]
            retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay,
                                              value_inner_decay))
        else:
            index = torch.arange(slen).to(self.decay)
            sin = torch.sin(index[:, None] * self.angle[None, :])
            cos = torch.cos(index[:, None] * self.angle[None, :])
            mask = torch.tril(torch.ones(slen, slen)).to(self.decay)
            mask = torch.masked_fill(index[:, None] - index[None, :], ~mask.bool(), float("inf"))
            mask = torch.exp(mask * self.decay[:, None, None])
            mask = torch.nan_to_num(mask)
            mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
            retention_rel_pos = ((sin, cos), mask)

        return retention_rel_pos


class MultiScaleRetention(nn.Module):

    def __init__(
        self,
        config,
        embed_dim,
        value_dim,
        num_heads,
        gate_fn="swish",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = embed_dim
        self.value_dim = value_dim
        self.num_heads = num_heads
        self.head_dim = self.value_dim // num_heads
        self.key_dim = self.embed_dim // num_heads
        self.scaling = self.key_dim**-0.5

        self.gate_fn = get_activation_fn(activation=str(gate_fn))

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, value_dim, bias=False)
        self.g_proj = nn.Linear(embed_dim, value_dim, bias=False)

        self.out_proj = nn.Linear(value_dim, embed_dim, bias=False)

        self.group_norm = RMSNorm(self.head_dim, eps=config.layernorm_eps, elementwise_affine=False)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-2.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-2.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-2.5)
        nn.init.xavier_uniform_(self.g_proj.weight, gain=2**-2.5)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=2**-1)

    def parallel_forward(self, qr, kr, v, mask):
        bsz, tgt_len, embed_dim = v.size()

        vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)

        qk_mat = qr @ kr.transpose(-1, -2)  # bsz * m * tgt_len * tgt_len
        qk_mat = qk_mat * mask
        # invariant after normalization
        qk_mat = qk_mat / qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1, max=5e4)
        output = torch.matmul(qk_mat, vr)
        output = output.transpose(1, 2)
        return output

    def recurrent_forward(self, qr, kr, v, decay, incremental_state):
        bsz = v.size(0)

        v = v.view(bsz, self.num_heads, self.head_dim, 1)
        kv = kr * v
        if "prev_key_value" in incremental_state:
            prev_kv = incremental_state["prev_key_value"]
            prev_scale = incremental_state["scale"]
            scale = prev_scale * decay + 1
            kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(
                self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
            # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
        else:
            scale = torch.ones_like(decay)

        incremental_state["prev_key_value"] = kv
        incremental_state["scale"] = scale

        output = torch.sum(qr * kv, dim=3)
        return output

    def chunk_recurrent_forward(self, qr, kr, v, inner_mask):
        mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
        bsz, tgt_len, embed_dim = v.size()
        chunk_len = mask.size(1)
        num_chunks = tgt_len // chunk_len

        assert tgt_len % chunk_len == 0

        qr = qr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
        kr = kr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
        v = v.view(bsz, num_chunks, chunk_len, self.num_heads, self.head_dim).transpose(2, 3)

        kr_t = kr.transpose(-1, -2)

        qk_mat = qr @ kr_t  # bsz * num_heads * chunk_len * chunk_len
        qk_mat = qk_mat * mask
        inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)
        qk_mat = qk_mat / inner_scale
        inner_output = torch.matmul(qk_mat,
                                    v)  # bsz * num_heads * num_value_heads * chunk_len * head_dim

        # reduce kv in one chunk
        kv = kr_t @ (v * value_inner_decay)

        kv_recurrent = []
        cross_scale = []
        kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
        kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)

        # accumulate kv by loop
        for i in range(num_chunks):
            kv_recurrent.append(kv_state / kv_scale)
            cross_scale.append(kv_scale)
            kv_state = kv_state * cross_decay + kv[:, i]
            kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max(
                dim=-1, keepdim=True).values.clamp(min=1)

        kv_recurrent = torch.stack(kv_recurrent, dim=1)
        cross_scale = torch.stack(cross_scale, dim=1)

        all_scale = torch.maximum(inner_scale, cross_scale)
        align_inner_scale = all_scale / inner_scale
        align_cross_scale = all_scale / cross_scale

        cross_output = (qr * query_inner_decay) @ kv_recurrent
        output = inner_output / align_inner_scale + cross_output / align_cross_scale
        # output = inner_output / cross_scale + cross_output / inner_scale

        output = output.transpose(2, 3)
        return output

    def forward(self, x, rel_pos, chunkwise_recurrent=False, incremental_state=None):
        bsz, tgt_len, _ = x.size()
        (sin, cos), inner_mask = rel_pos

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        g = self.g_proj(x)

        k *= self.scaling
        q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
        k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)

        qr = theta_shift(q, sin, cos)
        kr = theta_shift(k, sin, cos)

        if incremental_state is not None:
            output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
        elif chunkwise_recurrent:
            output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
        else:
            output = self.parallel_forward(qr, kr, v, inner_mask)

        output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)

        output = self.gate_fn(g) * output

        output = self.out_proj(output)

        return output


class FeedForwardNetwork(nn.Module):

    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_fn,
        dropout,
        activation_dropout,
        layernorm_eps,
        subln=False,
        use_rms_norm=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.activation_fn = get_activation_fn(activation=str(activation_fn))
        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
        self.dropout_module = torch.nn.Dropout(dropout)
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
        if subln:
            if use_rms_norm:
                self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps)
            else:
                self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps)
        else:
            self.ffn_layernorm = None

    def reset_parameters(self):
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        if self.ffn_layernorm is not None:
            self.ffn_layernorm.reset_parameters()

    def forward(self, x):
        x_shape = x.shape
        x = x.reshape(-1, x.size(-1))
        x = self.fc1(x)
        x = self.activation_fn(x.float()).type_as(x)
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = x.view(x_shape)
        x = self.dropout_module(x)
        return x


class GLU(nn.Module):

    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_fn,
        dropout,
        activation_dropout,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.activation_fn = get_activation_fn(activation=str(activation_fn))
        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
        self.dropout_module = torch.nn.Dropout(dropout)
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
        self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)

    def reset_parameters(self):
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        self.gate.reset_parameters()

    def forward(self, x):
        x_shape = x.shape
        x = x.reshape(-1, x.size(-1))
        g = self.gate(x)
        x = self.fc1(x)
        x = self.activation_fn(x.float()).type_as(x) * g
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = x.view(x_shape)
        x = self.dropout_module(x)
        return x


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        return "p={}".format(self.drop_prob)


class RetNetDecoderLayer(nn.Module):

    def __init__(
        self,
        config,
        depth,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.decoder_embed_dim
        self.dropout_module = torch.nn.Dropout(config.dropout)

        if config.drop_path_rate > 0:
            drop_path_prob = np.linspace(0, config.drop_path_rate, config.decoder_layers)[depth]
            self.drop_path = DropPath(drop_path_prob)
        else:
            self.drop_path = None

        self.retention = MultiScaleRetention(
            config,
            self.embed_dim,
            config.decoder_value_embed_dim,
            config.decoder_retention_heads,
        )

        self.normalize_before = config.decoder_normalize_before

        self.retention_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps)

        self.ffn_dim = config.decoder_ffn_embed_dim

        self.ffn = self.build_ffn()

        self.final_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps)

        if config.deepnorm:
            self.alpha = math.pow(2.0 * config.decoder_layers, 0.25)
        else:
            self.alpha = 1.0

    def build_ffn(self):
        if self.config.use_glu:
            return GLU(
                self.embed_dim,
                self.ffn_dim,
                self.config.activation_fn,
                self.config.dropout,
                self.config.activation_dropout,
            )
        else:
            return FeedForwardNetwork(
                self.embed_dim,
                self.ffn_dim,
                self.config.activation_fn,
                self.config.dropout,
                self.config.activation_dropout,
                self.config.layernorm_eps,
                self.config.subln,
                self.config.use_ffn_rms_norm,
            )

    def residual_connection(self, x, residual):
        return residual * self.alpha + x

    def forward(
        self,
        x,
        incremental_state=None,
        chunkwise_recurrent=False,
        retention_rel_pos=None,
    ):
        residual = x
        if self.normalize_before:
            x = self.retention_layer_norm(x)

        x = self.retention(
            x,
            incremental_state=incremental_state,
            rel_pos=retention_rel_pos,
            chunkwise_recurrent=chunkwise_recurrent,
        )
        x = self.dropout_module(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.retention_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.ffn(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        return x


class RetNetModel(nn.Module):

    def __init__(self, config, embed_tokens=None, output_projection=None, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        self.dropout_module = torch.nn.Dropout(config.dropout)

        embed_dim = config.decoder_embed_dim
        self.embed_dim = embed_dim
        self.embed_scale = 1.0 if config.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_tokens = embed_tokens

        if (output_projection is None and not config.no_output_layer and config.vocab_size > 0):
            self.output_projection = self.build_output_projection(config)
        else:
            self.output_projection = output_projection

        if config.layernorm_embedding:
            self.layernorm_embedding = RMSNorm(embed_dim, eps=config.layernorm_eps)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])

        for i in range(config.decoder_layers):
            self.layers.append(self.build_decoder_layer(
                config,
                depth=i,
            ))

        self.num_layers = len(self.layers)

        if config.decoder_normalize_before:
            self.layer_norm = RMSNorm(embed_dim, eps=config.layernorm_eps)
        else:
            self.layer_norm = None

        self.retnet_rel_pos = RetNetRelPos(config)
        self.chunkwise_recurrent = config.chunkwise_recurrent
        self.recurrent_chunk_size = config.recurrent_chunk_size

        if config.deepnorm:
            init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
            for name, p in self.named_parameters():
                if ("fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name):
                    p.data.div_(init_scale)

        if config.subln and not config.use_glu:
            init_scale = math.sqrt(math.log(config.decoder_layers * 2))
            for name, p in self.named_parameters():
                if ("fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name):
                    p.data.mul_(init_scale)

    def build_output_projection(
        self,
        config,
    ):
        if config.share_decoder_input_output_embed:
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(config.decoder_embed_dim,
                                                config.vocab_size,
                                                bias=False)
            torch.nn.init.normal_(output_projection.weight,
                                  mean=0,
                                  std=config.decoder_embed_dim**-0.5)
        return output_projection

    def build_decoder_layer(self, config, depth):
        layer = RetNetDecoderLayer(
            config,
            depth,
        )
        # if config.checkpoint_activations:
        #     layer = checkpoint_wrapper(layer)
        # if config.fsdp:
        #     layer = wrap(layer)
        return layer

    def forward_embedding(
        self,
        tokens,
        token_embedding=None,
        incremental_state=None,
    ):
        if incremental_state is not None and not self.is_first_step(incremental_state):
            tokens = tokens[:, -1:]

        if token_embedding is None:
            token_embedding = self.embed_tokens(tokens)

        x = embed = self.embed_scale * token_embedding

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        return x, embed

    def is_first_step(self, incremental_state):
        if incremental_state is None:
            return False
        return incremental_state.get("is_first_step", False)

    def forward(self,
                prev_output_tokens,
                incremental_state=None,
                features_only=False,
                token_embeddings=None):
        # embed tokens
        x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)
        is_first_step = self.is_first_step(incremental_state)

        if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
            padding_len = self.recurrent_chunk_size - prev_output_tokens.size(
                1) % self.recurrent_chunk_size
            slen = prev_output_tokens.size(1) + padding_len
            x = F.pad(x, (0, 0, 0, padding_len))
        else:
            slen = prev_output_tokens.size(1)
        # relative position
        retention_rel_pos = self.retnet_rel_pos(slen,
                                                incremental_state is not None and not is_first_step,
                                                chunkwise_recurrent=self.chunkwise_recurrent)

        # decoder layers
        inner_states = [x]

        for idx, layer in enumerate(self.layers):
            if incremental_state is None or is_first_step:
                if is_first_step and incremental_state is not None:
                    if idx not in incremental_state:
                        incremental_state[idx] = {}
            else:
                if idx not in incremental_state:
                    incremental_state[idx] = {}

            x = layer(
                x,
                incremental_state[idx] if incremental_state is not None else None,
                retention_rel_pos=retention_rel_pos,
                chunkwise_recurrent=self.chunkwise_recurrent,
            )
            inner_states.append(x)

        if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
            x = x[:, :prev_output_tokens.size(1), :]

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if not features_only:
            x = self.output_layer(x)

        return x, {
            "inner_states": inner_states,
            "attn": None,
        }

    def output_layer(self, features):
        return self.output_projection(features)


class RetNetForCausalLM(nn.Module):

    def __init__(self, config, embed_tokens=None, output_projection=None, **kwargs):
        super().__init__(**kwargs)
        assert config.vocab_size > 0, "you must specify vocab size"
        if output_projection is None:
            config.no_output_layer = False
        if embed_tokens is None:
            embed_tokens = nn.Embedding(config.vocab_size, config.decoder_embed_dim,
                                        config.pad_token_id)

        self.config = config
        self.model = RetNetModel(config,
                                 embed_tokens=embed_tokens,
                                 output_projection=output_projection,
                                 **kwargs)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.model.output_projection

    def set_output_embeddings(self, new_embeddings):
        self.model.output_projection = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        retention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_retentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        recurrent_chunk_size: Optional[int] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        outputs = self.model(
            input_ids,
            incremental_state=past_key_values,
            features_only=False,
            token_embeddings=inputs_embeds,
        )

        logits, inner_hidden_states = outputs[0], outputs[1]['inner_states']

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

            if self.config.z_loss_coeff > 0:
                # z_loss from PaLM paper
                # z_loss = 1e-4 * log(log(z)), where z = sum(exp(logits))
                z_loss = torch.logsumexp(shift_logits, dim=-1).log().mean()
                loss += self.config.z_loss_coeff * z_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=inner_hidden_states,
            attentions=None,
        )
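A minimal sketch (not part of the commit) of driving the deleted model end to end, assuming the file above is importable as `modeling_retnet` and its companion config (the first deleted file) as `configuration_retnet`; sizes are arbitrary, and two flags the model reads but the config never declares are passed explicitly through the config's `**kwargs`:

```python
import torch
from configuration_retnet import RetNetConfig  # assumed import paths for the deleted files
from modeling_retnet import RetNetForCausalLM

config = RetNetConfig(
    vocab_size=256, decoder_layers=2, decoder_embed_dim=96,
    decoder_value_embed_dim=192, decoder_ffn_embed_dim=192, decoder_retention_heads=3,
    # RetNetModel reads these attributes, but the HF config above never sets them,
    # so they have to come in through PretrainedConfig's **kwargs:
    chunkwise_recurrent=False, share_decoder_input_output_embed=False,
)
model = RetNetForCausalLM(config)

tokens = torch.randint(0, 256, (1, 16))
out = model(input_ids=tokens, labels=tokens, return_dict=True)
print(out.logits.shape)  # (1, 16, 256) -- parallel retention path, since no incremental_state
print(out.loss)
```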
@@ -15,6 +15,7 @@ except Exception as e:
    ERROR_ARCHES["retnet"] = e
    pass

+"""
try:
    from .retnet_syncdoth.retnet_ts import RetNetDecoder as RetNetDecoder_TS, RetNetConfig as RetNetConfig_TS
    AVAILABLE_ARCHES.append("retnet-ts")
@@ -28,6 +29,7 @@ try:
except Exception as e:
    ERROR_ARCHES["retnet-hf"] = e
    pass
+"""

try:
    from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaDecoderLayer, LlamaDecoderLayer_Adapted, LlamaForCausalLM
@@ -50,6 +52,15 @@ try:
except Exception as e:
    ERROR_ARCHES["mixtral"] = e

+try:
+    from .mamba import MambaModel, Mamba2Model, MambaConfig, Mamba2Config
+    AVAILABLE_ARCHES.append("mamba")
+    AVAILABLE_ARCHES.append("mamba2")
+except Exception as e:
+    ERROR_ARCHES["mamba"] = e
+    ERROR_ARCHES["mamba2"] = e
+
+"""
try:
    from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig
    AVAILABLE_ARCHES.append("mamba")
@@ -62,4 +73,5 @@ try:
    from .mamba_vasqu import Mamba2Model_HF, Mamba2Config_HF
    AVAILABLE_ARCHES.append("mamba2-hf")
except Exception as e:
    ERROR_ARCHES["mamba2-hf"] = e
+"""
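The hunks above keep extending the same guarded-import pattern; a condensed, self-contained sketch of that pattern (the arch name and dependency here are made up, not from the repo):

```python
AVAILABLE_ARCHES = []
ERROR_ARCHES = {}

# Each backend is optional: a failed import is recorded instead of crashing the
# package import, and callers can inspect ERROR_ARCHES to see why an arch is missing.
try:
    import some_optional_backend  # hypothetical dependency
    AVAILABLE_ARCHES.append("example-arch")
except Exception as e:
    ERROR_ARCHES["example-arch"] = e

print(AVAILABLE_ARCHES, ERROR_ARCHES)
```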
@@ -1,3 +1,11 @@
+
+from transformers.models.mamba.modeling_mamba import MambaModel
+from transformers.models.mamba2.modeling_mamba2 import Mamba2Model
+
+from transformers.models.mamba.configuration_mamba import MambaConfig
+from transformers.models.mamba2.configuration_mamba2 import Mamba2Config
+
+"""
# https://github.com/state-spaces/mamba
from torch.utils.checkpoint import checkpoint

@@ -29,4 +37,5 @@ def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_
    )
    return hidden_states

MambaMixelModel.forward = MambaMixelModel_forward
+"""
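The new imports switch this file over to the Mamba classes shipped with `transformers`; a minimal sketch of driving those classes directly (sizes are arbitrary):

```python
import torch
from transformers.models.mamba.configuration_mamba import MambaConfig
from transformers.models.mamba.modeling_mamba import MambaModel

config = MambaConfig(vocab_size=256, hidden_size=64, num_hidden_layers=2)
model = MambaModel(config)  # falls back to the slow path if the mamba-ssm kernels are absent

tokens = torch.randint(0, 256, (1, 16))
hidden = model(input_ids=tokens).last_hidden_state
print(hidden.shape)  # (1, 16, 64)
```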
@@ -1 +0,0 @@
from .mamba2_hf import *
@@ -1,4 +0,0 @@
# https://github.com/vasqu/mamba2-torch
# NOTE: edit `src/mamba2_torch/__init__.py` to remove reference to .src. because of how pip treats packages

from mamba2_torch import Mamba2Model as Mamba2Model_HF, Mamba2Config as Mamba2Config_HF
@@ -1,196 +0,0 @@
# https://github.com/syncdoth/RetNet/
from ....ext.retnet_hf.configuration_retnet import RetNetConfig
from ....ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder, RetNetForCausalLM

# things we're overriding or required to override
from ....ext.retnet_hf.modeling_retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, split_heads, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos

import torch
import math
from typing import Dict, List, Optional, Tuple, Union

# required to have compatible LayerNorm
def FeedForwardNetwork_init(
    self,
    embed_dim,
    ffn_dim,
    activation_fn,
    dropout,
    activation_dropout,
    layernorm_eps,
    subln=True,
    use_rms_norm=False,
):
    super(FeedForwardNetwork, self).__init__()
    self.embed_dim = embed_dim
    self.activation_fn = get_activation_fn(activation=str(activation_fn))
    self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
    self.dropout_module = torch.nn.Dropout(dropout)
    self.fc1 = torch.nn.Linear(self.embed_dim, ffn_dim)
    self.fc2 = torch.nn.Linear(ffn_dim, self.embed_dim)
    self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None

FeedForwardNetwork.__init__ = FeedForwardNetwork_init

def RetNetModel_init(
    self,
    config: RetNetConfig,
    embed_tokens: torch.nn.Embedding = None,
    tensor_parallel: bool = False,
):
    super(RetNetDecoder, self).__init__(config)
    self.config = config

    self.dropout_module = torch.nn.Dropout(config.dropout)

    self.embed_dim = config.decoder_embed_dim
    self.embed_scale = (
        1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
    )

    if embed_tokens is None and config.vocab_size:
        embed_tokens = torch.nn.Embedding(
            config.vocab_size, config.decoder_embed_dim, config.pad_token_id
        )
    self.embed_tokens = embed_tokens

    if config.layernorm_embedding:
        self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps)  # RMSNorm
    else:
        self.layernorm_embedding = None

    self.layers = torch.nn.ModuleList([])

    for i in range(config.decoder_layers):
        self.layers.append(
            RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel)
        )

    self.decoder_layers = len(self.layers)

    if config.decoder_normalize_before:
        self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps)  # RMSNorm
    else:
        self.layer_norm = None

    self.retnet_rel_pos = RetNetRelPos(config)
    self.recurrent_chunk_size = config.recurrent_chunk_size

    if config.deepnorm:
        init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
        for name, p in self.named_parameters():
            if (
                "fc1" in name
                or "fc2" in name
                or "out_proj" in name
                or "v_proj" in name
            ):
                p.data.div_(init_scale)

    if config.subln and not config.use_glu:
        init_scale = math.sqrt(math.log(config.decoder_layers * 2))
        for name, p in self.named_parameters():
            if (
                "fc1" in name
                or "fc2" in name
                or "out_proj" in name
                or "v_proj" in name
            ):
                p.data.mul_(init_scale)

    self.gradient_checkpointing = True
    self.post_init()

RetNetDecoder.__init__ = RetNetModel_init

# restores bias in our FFNs
def RetNetDecoderLayer_init(self, config: RetNetConfig, depth: int, tensor_parallel: bool = False):
    super(RetNetDecoderLayer, self).__init__()
    self.config = config
    self.embed_dim = config.decoder_embed_dim
    self.dropout_module = torch.nn.Dropout(config.dropout)

    if config.drop_path_rate > 0:
        drop_path_prob = np.linspace(
            0, config.drop_path_rate, config.decoder_layers
        )[depth]
        self.drop_path = DropPath(drop_path_prob)
    else:
        self.drop_path = None

    self.retention = MultiScaleRetention(
        config, use_bias=True, tensor_parallel=tensor_parallel
    )

    self.normalize_before = config.decoder_normalize_before

    self.retention_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps)  # RMSNorm

    self.ffn_dim = config.decoder_ffn_embed_dim

    self.ffn = self.build_ffn()

    self.final_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps)  # RMSNorm

    if config.deepnorm:
        self.alpha = math.pow(2.0 * config.decoder_layers, 0.25)
    else:
        self.alpha = 1.0

RetNetDecoderLayer.__init__ = RetNetDecoderLayer_init
# fixes backwards when using te's autocast
def MultiScaleRetention_forward(
    self,
    hidden_states: torch.Tensor,
    rel_pos: Tuple[Tuple[torch.Tensor]],
    retention_mask: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    forward_impl: str = "parallel",
    output_retentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
    B, T, H = hidden_states.size()
    (sin, cos), decay_mask = rel_pos
    # projections
    q = self.q_proj(hidden_states)
    k = self.k_proj(hidden_states) * self.scaling  # for scaled dot product
    v = self.v_proj(hidden_states)
    g = self.g_proj(hidden_states)
    # multi-head
    q, k, v = split_heads((q, k, v), B, T, self.num_heads)

    # rotate
    # NOTE: theta_shift has bug with mps device.
    qr = theta_shift(q, sin, cos)
    kr = theta_shift(k, sin, cos)

    # retention
    if forward_impl == "parallel":
        retention_out, curr_kv, retention_weights = self.parallel_retention(
            qr, kr, v, decay_mask
        )
    elif forward_impl == "recurrent":
        retention_out, curr_kv = self.recurrent_retention(
            qr,
            kr,
            v,
            decay_mask,
            past_key_value=past_key_value,
            retention_mask=retention_mask,
        )
    elif forward_impl == "chunkwise":
        retention_out, curr_kv = self.chunkwise_retention(qr, kr, v, decay_mask)
    else:
        raise ValueError(f"forward_impl {forward_impl} not supported.")

    # concat heads
    normed = self.group_norm(retention_out).reshape(B, T, self.value_dim)
    # out gate & proj
    out = self.gate_fn(g) * normed
    out = self.out_proj(out)

    outputs = (out, curr_kv)
    if output_retentions:
        outputs += (retention_weights,) if forward_impl == "parallel" else (None,)
    return outputs

MultiScaleRetention.forward = MultiScaleRetention_forward
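This whole file works by monkey-patching the vendored classes at import time rather than subclassing them; a generic, self-contained sketch of that pattern (the class here is a stand-in, not one of the RetNet classes):

```python
class Vendored:
    def forward(self, x):
        return x + 1

# Replace the method on the class itself: every existing and future instance picks
# up the new behaviour, which is how the RetNet overrides above are applied.
def patched_forward(self, x):
    return x + 2

Vendored.forward = patched_forward

assert Vendored().forward(1) == 3
```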
@ -1,277 +0,0 @@
|
||||||
# https://github.com/syncdoth/RetNet/
|
|
||||||
from ....ext.retnet_ts.config import RetNetConfig
|
|
||||||
from ....ext.retnet_ts.retnet import RetNetModel as RetNetDecoder
|
|
||||||
|
|
||||||
# things we're overriding or required to override
|
|
||||||
from ....ext.retnet_ts.retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import math
|
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from torch.utils.checkpoint import checkpoint
|
|
||||||
|
|
||||||
# required to have compatibile LayerNorm
|
|
||||||
def FeedForwardNetwork_init(
|
|
||||||
self,
|
|
||||||
embed_dim,
|
|
||||||
ffn_dim,
|
|
||||||
activation_fn,
|
|
||||||
dropout,
|
|
||||||
activation_dropout,
|
|
||||||
layernorm_eps,
|
|
||||||
subln=True,
|
|
||||||
use_rms_norm=False,
|
|
||||||
):
|
|
||||||
super(FeedForwardNetwork, self).__init__()
|
|
||||||
self.embed_dim = embed_dim
|
|
||||||
self.activation_fn = get_activation_fn(activation=str(activation_fn))
|
|
||||||
self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
|
|
||||||
self.dropout_module = torch.nn.Dropout(dropout)
|
|
||||||
self.fc1 = torch.nn.Linear(self.embed_dim, ffn_dim)
|
|
||||||
self.fc2 = torch.nn.Linear(ffn_dim, self.embed_dim)
|
|
||||||
self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
|
|
||||||
|
|
||||||
FeedForwardNetwork.__init__ = FeedForwardNetwork_init
|
|
||||||
|
|
||||||
# removes embed_tokens
|
|
||||||
def RetNetModel_init(
|
|
||||||
self, config, embed_tokens=None, output_projection=None, **kwargs
|
|
||||||
):
|
|
||||||
super(RetNetDecoder, self).__init__(**kwargs)
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
self.dropout_module = torch.nn.Dropout(config.dropout)
|
|
||||||
|
|
||||||
self.embed_dim = config.decoder_embed_dim
|
|
||||||
self.embed_scale = (
|
|
||||||
1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
|
|
||||||
)
|
|
||||||
|
|
||||||
if embed_tokens is None and config.vocab_size:
|
|
||||||
embed_tokens = torch.nn.Embedding(
|
|
||||||
config.vocab_size, config.decoder_embed_dim, config.pad_token_id
|
|
||||||
)
|
|
||||||
self.embed_tokens = embed_tokens
|
|
||||||
|
|
||||||
if (output_projection is None and not config.no_output_layer and config.vocab_size > 0):
|
|
||||||
self.output_projection = self.build_output_projection(config)
|
|
||||||
else:
|
|
||||||
self.output_projection = output_projection
|
|
||||||
|
|
||||||
if config.layernorm_embedding:
|
|
||||||
self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
|
|
||||||
else:
|
|
||||||
self.layernorm_embedding = None
|
|
||||||
|
|
||||||
self.layers = torch.nn.ModuleList([])
|
|
||||||
|
|
||||||
for i in range(config.decoder_layers):
|
|
||||||
layer = self.build_decoder_layer(
|
|
||||||
config,
|
|
||||||
depth=i,
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
if config.checkpoint_activations:
|
|
||||||
layer = checkpoint_wrapper(layer)
|
|
||||||
"""
|
|
||||||
self.layers.append(layer)
|
|
||||||
|
|
||||||
self.num_layers = len(self.layers)
|
|
||||||
|
|
||||||
if config.decoder_normalize_before:
|
|
||||||
self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
|
|
||||||
else:
|
|
||||||
self.layer_norm = None
|
|
||||||
|
|
||||||
self.retnet_rel_pos = RetNetRelPos(config)
|
|
||||||
self.chunkwise_recurrent = config.chunkwise_recurrent
|
|
||||||
self.recurrent_chunk_size = config.recurrent_chunk_size
|
|
||||||
|
|
||||||
if config.deepnorm:
|
|
||||||
init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
|
|
||||||
for name, p in self.named_parameters():
|
|
||||||
if (
|
|
||||||
"fc1" in name
|
|
||||||
or "fc2" in name
|
|
||||||
or "out_proj" in name
|
|
||||||
or "v_proj" in name
|
|
||||||
):
|
|
||||||
p.data.div_(init_scale)
|
|
||||||
|
|
||||||
if config.subln and not config.use_glu:
|
|
||||||
init_scale = math.sqrt(math.log(config.decoder_layers * 2))
|
|
||||||
for name, p in self.named_parameters():
|
|
||||||
if (
|
|
||||||
"fc1" in name
|
|
||||||
or "fc2" in name
|
|
||||||
or "out_proj" in name
|
|
||||||
or "v_proj" in name
|
|
||||||
):
|
|
||||||
p.data.mul_(init_scale)
|
|
||||||
|
|
||||||
self.gradient_checkpointing = True
|
|
||||||
|
|
||||||
RetNetDecoder.__init__ = RetNetModel_init
|
|
||||||
|
|
||||||
# restores bias in our FFNs
|
|
||||||
def RetNetDecoderLayer_init(
|
|
||||||
self,
|
|
||||||
config,
|
|
||||||
depth,
|
|
||||||
use_bias=True
|
|
||||||
):
|
|
||||||
super(RetNetDecoderLayer, self).__init__()
|
|
||||||
self.config = config
|
|
||||||
self.embed_dim = config.decoder_embed_dim
|
|
||||||
self.dropout_module = torch.nn.Dropout(config.dropout)
|
|
||||||
|
|
||||||
if config.drop_path_rate > 0:
|
|
||||||
drop_path_prob = np.linspace(
|
|
||||||
0, config.drop_path_rate, config.decoder_layers
|
|
||||||
)[depth]
|
|
||||||
self.drop_path = DropPath(drop_path_prob)
|
|
||||||
else:
|
|
||||||
self.drop_path = None
|
|
||||||
|
|
||||||
self.retention = MultiScaleRetention(
|
|
||||||
config,
|
|
||||||
self.embed_dim,
|
|
||||||
config.decoder_value_embed_dim,
|
|
||||||
config.decoder_retention_heads,
|
|
||||||
use_bias=use_bias
|
|
||||||
)
|
|
||||||
|
|
||||||
self.normalize_before = config.decoder_normalize_before
|
|
||||||
|
|
||||||
self.retention_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
|
|
||||||
|
|
||||||
self.ffn_dim = config.decoder_ffn_embed_dim
|
|
||||||
|
|
||||||
self.ffn = self.build_ffn()
|
|
||||||
|
|
||||||
self.final_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
|
|
||||||
|
|
||||||
if config.deepnorm:
|
|
||||||
self.alpha = math.pow(2.0 * config.decoder_layers, 0.25)
|
|
||||||
else:
|
|
||||||
self.alpha = 1.0
|
|
||||||
def RetNetDecoderLayer_forward(
	self,
	x,
	incremental_state=None,
	chunkwise_recurrent=False,
	retention_rel_pos=None,
):
	residual = x
	if self.normalize_before:
		x = self.retention_layer_norm(x)

	if x.requires_grad and self.config.checkpoint_activations:
		x = checkpoint(
			self.retention,
			x,
			use_reentrant=False,
			incremental_state=incremental_state,
			rel_pos=retention_rel_pos,
			chunkwise_recurrent=chunkwise_recurrent,
		)
	else:
		x = self.retention(
			x,
			incremental_state=incremental_state,
			rel_pos=retention_rel_pos,
			chunkwise_recurrent=chunkwise_recurrent,
		)
	x = self.dropout_module(x)

	if self.drop_path is not None:
		x = self.drop_path(x)

	x = self.residual_connection(x, residual)
	if not self.normalize_before:
		x = self.retention_layer_norm(x)

	residual = x
	if self.normalize_before:
		x = self.final_layer_norm(x)

	x = self.ffn(x)

	if self.drop_path is not None:
		x = self.drop_path(x)

	x = self.residual_connection(x, residual)
	if not self.normalize_before:
		x = self.final_layer_norm(x)

	return x
RetNetDecoderLayer.__init__ = RetNetDecoderLayer_init
RetNetDecoderLayer.forward = RetNetDecoderLayer_forward

# fixes backwards when using te's autocast
def MultiScaleRetention_init(
	self,
	config,
	embed_dim,
	value_dim,
	num_heads,
	gate_fn="swish",
	use_bias=True,
):
	super(MultiScaleRetention, self).__init__()
	self.config = config
	self.embed_dim = embed_dim
	self.value_dim = value_dim
	self.num_heads = num_heads
	self.head_dim = self.value_dim // num_heads
	self.key_dim = self.embed_dim // num_heads
	self.scaling = self.key_dim**-0.5

	self.gate_fn = get_activation_fn(activation=str(gate_fn))

	self.q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=use_bias)
	self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=use_bias)
	self.v_proj = torch.nn.Linear(embed_dim, value_dim, bias=use_bias)
	self.g_proj = torch.nn.Linear(embed_dim, value_dim, bias=use_bias)

	self.out_proj = torch.nn.Linear(value_dim, embed_dim, bias=use_bias)

	self.group_norm = RMSNorm(self.head_dim, eps=config.layernorm_eps, elementwise_affine=False)
	self.reset_parameters()
def MultiScaleRetention_forward(
	self, x, rel_pos, chunkwise_recurrent=False, incremental_state=None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
	bsz, tgt_len, _ = x.size()
	(sin, cos), inner_mask = rel_pos

	q = self.q_proj(x)
	k = self.k_proj(x) * self.scaling
	v = self.v_proj(x)
	g = self.g_proj(x)

	q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
	k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)

	qr = theta_shift(q, sin, cos)
	kr = theta_shift(k, sin, cos)

	if incremental_state is not None:
		output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
	elif chunkwise_recurrent:
		output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
	else:
		output = self.parallel_forward(qr, kr, v, inner_mask)

	output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)

	output = self.gate_fn(g) * output

	output = self.out_proj(output)

	return output

MultiScaleRetention.__init__ = MultiScaleRetention_init
MultiScaleRetention.forward = MultiScaleRetention_forward
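# Taken together, the patched forward above reduces to a gated, per-head-normalized output path,
# roughly (names follow the code above; `retention_out` stands for whichever of the parallel /
# recurrent / chunkwise paths was taken):
#
#   normed = group_norm(retention_out)                       # affine-free RMSNorm over head_dim
#   output = out_proj( gate_fn(g_proj(x)) * normed.reshape(bsz, tgt_len, num_heads * head_dim) )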
@@ -704,87 +704,26 @@ class Base(nn.Module):
 				))
 
 			self.model = RetNetDecoder(RetNetConfig(**kwargs))
-		elif self.arch_type == "retnet-hf":
-			kwargs = dict(
-				vocab_size=n_resp_tokens,
-				decoder_embed_dim=d_model,
-				decoder_value_embed_dim =d_model * 2,
-				decoder_retention_heads=n_heads,
-				decoder_ffn_embed_dim=d_model * 4,
-				decoder_layers=n_layers,
-				dropout=p_dropout if training else 0.0,
-				checkpoint_activations=self.gradient_checkpointing,
-				activation_fn="gelu",
-				use_glu=False, # self.version >= 3,
-
-				recurrent_chunk_size=self.causal_size if self.causal else 0,
-				decoder_normalize_before=True,
-
-				deepnorm=False,
-				subln=True,
-			)
-
-			self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs))
-
-			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
-				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
-					use_reentrant=False
-				))
-		elif self.arch_type == "bitnet":
-			self.model = BitNetTransformer(
-				num_tokens=n_resp_tokens,
-				dim=d_model,
-				depth=n_layers,
-				heads=n_heads,
-				ff_mult=4,
-				gradient_checkpointing=self.gradient_checkpointing,
-			)
-		elif self.arch_type in ["mamba","mamba2"]:
-			self.model = MambaMixelModel(
-				vocab_size=n_resp_tokens,
-				d_model=d_model,
-				n_layer=n_layers*2,
-				d_intermediate=0, #d_model*2,
-				ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True} if self.arch_type == "mamba2" else {},
-				rms_norm=True,
-				fused_add_norm=True,
-				residual_in_fp32=True,
-				#attn_layer_idx=attn_layer_idx,
-				#attn_cfg=attn_cfg,
-				#initializer_cfg=initializer_cfg,
-			)
-			self.model.gradient_checkpointing = self.gradient_checkpointing
-		elif self.arch_type in ["mamba2-hf"]:
-			self.model = Mamba2Model_HF(Mamba2Config_HF(
-				vocab_size=n_resp_tokens,
-				hidden_size=d_model,
-				max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
-				expand=4,
-				num_hidden_layers=n_layers,
-				is_encoder_decoder=False,
-				is_decoder=True,
-				use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it)
-				residual_in_fp32=True, # breaks for AMP inference
-			))
-			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
-				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
-					use_reentrant=False
-				))
-		elif self.arch_type == "mmfreelm":
-			self.model = HGRNBitModel(HGRNBitConfig(
-				vocab_size=n_resp_tokens,
-				hidden_size=d_model,
-				max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
-				intermediate_size=d_model*4,
-				num_hidden_layers=n_layers,
-				num_heads=n_heads,
-				#hidden_act="gelu",
-				#is_encoder_decoder=False,
-				#is_decoder=True,
-				attn_mode=hf_attention,
-				#gradient_checkpointing=self.gradient_checkpointing,
-			))
-
-			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
-				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
-					use_reentrant=False
+		elif self.arch_type in ["mamba2"]:
+			self.model = Mamba2Model(Mamba2Config(
+				vocab_size=n_resp_tokens,
+				hidden_size=d_model,
+				expand=2,
+				num_hidden_layers=n_layers*2,
+				residual_in_fp32=True,
+			))
+			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
+				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+					use_reentrant=False
+				))
+		elif self.arch_type in ["mamba"]:
+			self.model = MambaModel(MambaConfig(
+				vocab_size=n_resp_tokens,
+				hidden_size=d_model,
+				expand=2,
+				num_hidden_layers=n_layers*2,
+				residual_in_fp32=True,
+			))
+
+			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
+				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+					use_reentrant=False
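# Both the removed and the surviving branches above enable non-reentrant activation checkpointing
# (use_reentrant=False), which recomputes activations during backward instead of storing them and,
# unlike the legacy reentrant path, tolerates keyword arguments and inputs that do not require
# grad. A minimal sketch of the same idiom outside the HF gradient_checkpointing_enable() helper
# (block/hidden names are illustrative):
#
#   from torch.utils.checkpoint import checkpoint
#   def run_checkpointed(block, hidden):
#       return checkpoint(block, hidden, use_reentrant=False)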
@@ -795,7 +734,6 @@ class Base(nn.Module):
 		if hasattr( self.model, "embeddings" ):
 			del self.model.embeddings
-
 
 		if not split_classifiers:
 			self.classifier = nn.Linear(d_model, n_resp_tokens)
 			self.classifiers = None
@@ -854,8 +792,8 @@ class Base(nn.Module):
 		# HF transformer derived model
 		if self.arch_type in ["llama", "mistral", "mixtral"]:
 			kwargs = dict(
-				#attention_mask=m,
 				inputs_embeds=x,
+				attention_mask=m,
 				past_key_values=state,
 				position_ids=position_ids,
 				use_cache=False, # not self.training,
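# The only change in this hunk is passing the padding mask through instead of leaving it commented
# out, so the HF-style decoder (llama/mistral/mixtral here) no longer attends to padded positions.
# Roughly, the call becomes (names as in the hunk above):
#
#   out = self.model(inputs_embeds=x, attention_mask=m, past_key_values=state,
#                    position_ids=position_ids, use_cache=False)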
@@ -901,46 +839,31 @@ class Base(nn.Module):
 			x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
 			if _ is not None and "l_aux" in _ and self.n_experts > 1:
 				aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
-		elif self.arch_type == "retnet-hf":
-			first = state is None or len(state) == 0
-
-			kwargs = dict(
-				attention_mask=m,
-				inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
-				past_key_values=None if first else state,
-				use_cache=True,
-				forward_impl='parallel' if first else 'recurrent',
-				return_dict=True,
-			)
-
-			out = self.model(**kwargs)
-			x = out.last_hidden_state
-			if state is not None:
-				state = out.past_key_values
 		elif self.arch_type in ["mamba","mamba2"]:
-			x = self.model( hidden_states=x )
-		elif self.arch_type == "mamba2-hf":
-			first = state is None or len(state) == 0
 
 			kwargs = dict(
+				#attention_mask=m,
 				inputs_embeds=x,
-				cache_params=state,
+				#cache_params=state,
+				use_cache=False, # not self.training,
+				#position_ids=position_ids,
+				#output_attentions=output_attentions,
+				output_hidden_states=output_hidden_states,
 				return_dict=True,
 			)
 
-			out = self.model(**kwargs)
-			x = out.last_hidden_state
+			output = self.model(**kwargs)
+			x = output["last_hidden_state"]
 
+			# to-do: figure out why KV caching doesn't work
+			#if not self.training:
 			if state is not None:
-				state = out.cache_params
-		elif self.arch_type == "bitnet":
-			x = self.model(x)
-		elif self.arch_type == "mmfreelm":
-			x = self.model(
-				attention_mask=m,
-				inputs_embeds=x,
-			)
-
-			x = x[0]
+				state = output["cache_params"]
+
+			if output_attentions:
+				attentions = output["attentions"]
+
+			if output_hidden_states:
+				hidden_states = output["hidden_states"]
 
 		# process it into a format that I like
 		if output_hidden_states:
@@ -1559,7 +1482,6 @@ class Base(nn.Module):
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 
 		x, mask = list_to_tensor(x_list)
-		m = mask.unsqueeze(dim=-1)
 
 		training = self.training
 		device = x.device
@@ -1584,8 +1506,10 @@
 
 			# pad mask
 			shape[2] = 1
-			padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
+			padding = torch.zeros(shape[:2], dtype=x.dtype, device=x.device)
 			mask = torch.cat([mask, padding], dim=1)
 
+		m = mask.unsqueeze(dim=-1)
+
 		# needs to be done here as we still have our raw inputs
 		position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
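# Why this hunk (together with the one above) changes two things: `mask` is presumably 2-D
# (batch, seq), so its padding block has to be built from shape[:2]; the old code built a 3-D
# zero block (shape with shape[2] = 1) that no longer matches a 2-D mask, and it derived the
# broadcastable mask `m` before padding, leaving `m` short of the padded length. Rebuilding
# `m = mask.unsqueeze(dim=-1)` after the concatenation keeps the (batch, seq, 1) broadcast mask
# in sync with the padded sequence.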
@@ -99,8 +99,7 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data
 				continue
 			configs.append( sft )
 
-	if is_windows:
-		configs = [ str(p) for p in configs ]
+	configs = [ str(p) for p in configs ]
 
 	return configs