# https://github.com/syncdoth/RetNet/
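"""
Monkey patches over the vendored retnet_ts implementation (syncdoth/RetNet,
torchscale-style). The overrides below replace __init__ (and forward where needed) on
FeedForwardNetwork, RetNetModel, RetNetDecoderLayer and MultiScaleRetention so that:
  * the FFN uses a compatible LayerNorm,
  * embed_tokens is optional and built from the config when omitted,
  * projection biases are restored,
  * the retention sub-layer is wrapped in torch.utils.checkpoint when
    config.checkpoint_activations is set.
"""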
from ....ext.retnet_ts.config import RetNetConfig
from ....ext.retnet_ts.retnet import RetNetModel as RetNetDecoder
# things we're overriding or required to override
from ....ext.retnet_ts.retnet import (
    RetNetDecoderLayer,
    MultiScaleRetention,
    theta_shift,
    RMSNorm,
    FeedForwardNetwork,
    get_activation_fn,
    LayerNorm,
    RetNetRelPos,
)
import torch
import math
import numpy as np  # np.linspace is used in the drop-path branch below
from typing import Dict, List, Optional, Tuple, Union
from torch.utils.checkpoint import checkpoint

# required to have a compatible LayerNorm
def FeedForwardNetwork_init(
    self,
    embed_dim,
    ffn_dim,
    activation_fn,
    dropout,
    activation_dropout,
    layernorm_eps,
    subln=True,
    use_rms_norm=False,  # accepted for signature compatibility; this override always uses LayerNorm
):
    super(FeedForwardNetwork, self).__init__()
    self.embed_dim = embed_dim
    self.activation_fn = get_activation_fn(activation=str(activation_fn))
    self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
    self.dropout_module = torch.nn.Dropout(dropout)
    self.fc1 = torch.nn.Linear(self.embed_dim, ffn_dim)
    self.fc2 = torch.nn.Linear(ffn_dim, self.embed_dim)
    self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None

FeedForwardNetwork.__init__ = FeedForwardNetwork_init

# removes the need to pass in embed_tokens (an embedding is built from the config when none is provided)
def RetNetModel_init(
    self, config, embed_tokens=None, output_projection=None, **kwargs
):
    super(RetNetDecoder, self).__init__(**kwargs)
    self.config = config

    self.dropout_module = torch.nn.Dropout(config.dropout)

    self.embed_dim = config.decoder_embed_dim
    self.embed_scale = (
        1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
    )

    if embed_tokens is None and config.vocab_size:
        embed_tokens = torch.nn.Embedding(
            config.vocab_size, config.decoder_embed_dim, config.pad_token_id
        )
    self.embed_tokens = embed_tokens

    if (
        output_projection is None
        and not config.no_output_layer
        and config.vocab_size > 0
    ):
        self.output_projection = self.build_output_projection(config)
    else:
        self.output_projection = output_projection

    if config.layernorm_embedding:
        self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
    else:
        self.layernorm_embedding = None

    self.layers = torch.nn.ModuleList([])

    for i in range(config.decoder_layers):
        layer = self.build_decoder_layer(
            config,
            depth=i,
        )
        # upstream's checkpoint_wrapper is left disabled; activation checkpointing is
        # handled per sub-layer in RetNetDecoderLayer_forward instead
        """
        if config.checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        """
        self.layers.append(layer)

    self.num_layers = len(self.layers)

    if config.decoder_normalize_before:
        self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
    else:
        self.layer_norm = None

    self.retnet_rel_pos = RetNetRelPos(config)
    self.chunkwise_recurrent = config.chunkwise_recurrent
    self.recurrent_chunk_size = config.recurrent_chunk_size

    if config.deepnorm:
        init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
        for name, p in self.named_parameters():
            if (
                "fc1" in name
                or "fc2" in name
                or "out_proj" in name
                or "v_proj" in name
            ):
                p.data.div_(init_scale)

    if config.subln and not config.use_glu:
        init_scale = math.sqrt(math.log(config.decoder_layers * 2))
        for name, p in self.named_parameters():
            if (
                "fc1" in name
                or "fc2" in name
                or "out_proj" in name
                or "v_proj" in name
            ):
                p.data.mul_(init_scale)

    self.gradient_checkpointing = True

RetNetDecoder.__init__ = RetNetModel_init

# restores bias in our FFNs / retention projections (use_bias is passed through to MultiScaleRetention)
def RetNetDecoderLayer_init(
    self,
    config,
    depth,
    use_bias=True
):
    super(RetNetDecoderLayer, self).__init__()
    self.config = config
    self.embed_dim = config.decoder_embed_dim
    self.dropout_module = torch.nn.Dropout(config.dropout)

    if config.drop_path_rate > 0:
        # NOTE: DropPath is not imported in this file; this branch is only reachable
        # when config.drop_path_rate > 0
        drop_path_prob = np.linspace(
            0, config.drop_path_rate, config.decoder_layers
        )[depth]
        self.drop_path = DropPath(drop_path_prob)
    else:
        self.drop_path = None

    self.retention = MultiScaleRetention(
        config,
        self.embed_dim,
        config.decoder_value_embed_dim,
        config.decoder_retention_heads,
        use_bias=use_bias
    )

    self.normalize_before = config.decoder_normalize_before

    self.retention_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm

    self.ffn_dim = config.decoder_ffn_embed_dim

    self.ffn = self.build_ffn()

    self.final_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm

    if config.deepnorm:
        self.alpha = math.pow(2.0 * config.decoder_layers, 0.25)
    else:
        self.alpha = 1.0

def RetNetDecoderLayer_forward(
    self,
    x,
    incremental_state=None,
    chunkwise_recurrent=False,
    retention_rel_pos=None,
):
    residual = x
    if self.normalize_before:
        x = self.retention_layer_norm(x)

    # checkpoint the retention sub-layer when activation checkpointing is enabled
    if x.requires_grad and self.config.checkpoint_activations:
        x = checkpoint(
            self.retention,
            x,
            use_reentrant=False,
            incremental_state=incremental_state,
            rel_pos=retention_rel_pos,
            chunkwise_recurrent=chunkwise_recurrent,
        )
    else:
        x = self.retention(
            x,
            incremental_state=incremental_state,
            rel_pos=retention_rel_pos,
            chunkwise_recurrent=chunkwise_recurrent,
        )
    x = self.dropout_module(x)

    if self.drop_path is not None:
        x = self.drop_path(x)

    x = self.residual_connection(x, residual)
    if not self.normalize_before:
        x = self.retention_layer_norm(x)

    residual = x
    if self.normalize_before:
        x = self.final_layer_norm(x)

    x = self.ffn(x)

    if self.drop_path is not None:
        x = self.drop_path(x)

    x = self.residual_connection(x, residual)
    if not self.normalize_before:
        x = self.final_layer_norm(x)

    return x

RetNetDecoderLayer.__init__ = RetNetDecoderLayer_init
RetNetDecoderLayer.forward = RetNetDecoderLayer_forward

# fixes the backward pass when using te's autocast
def MultiScaleRetention_init(
    self,
    config,
    embed_dim,
    value_dim,
    num_heads,
    gate_fn="swish",
    use_bias=True,
):
    super(MultiScaleRetention, self).__init__()
    self.config = config
    self.embed_dim = embed_dim
    self.value_dim = value_dim
    self.num_heads = num_heads
    self.head_dim = self.value_dim // num_heads
    self.key_dim = self.embed_dim // num_heads
    self.scaling = self.key_dim**-0.5

    self.gate_fn = get_activation_fn(activation=str(gate_fn))

    self.q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=use_bias)
    self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=use_bias)
    self.v_proj = torch.nn.Linear(embed_dim, value_dim, bias=use_bias)
    self.g_proj = torch.nn.Linear(embed_dim, value_dim, bias=use_bias)

    self.out_proj = torch.nn.Linear(value_dim, embed_dim, bias=use_bias)

    self.group_norm = RMSNorm(self.head_dim, eps=config.layernorm_eps, elementwise_affine=False)
    self.reset_parameters()

def MultiScaleRetention_forward(
    self, x, rel_pos, chunkwise_recurrent=False, incremental_state=None
) -> torch.Tensor:
    bsz, tgt_len, _ = x.size()
    (sin, cos), inner_mask = rel_pos

    q = self.q_proj(x)
    k = self.k_proj(x) * self.scaling
    v = self.v_proj(x)
    g = self.g_proj(x)

    q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
    k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)

    qr = theta_shift(q, sin, cos)
    kr = theta_shift(k, sin, cos)

    # pick the retention mode: stepwise recurrence, chunkwise recurrence, or fully parallel
    if incremental_state is not None:
        output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
    elif chunkwise_recurrent:
        output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
    else:
        output = self.parallel_forward(qr, kr, v, inner_mask)

    output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)

    output = self.gate_fn(g) * output

    output = self.out_proj(output)

    return output

MultiScaleRetention.__init__ = MultiScaleRetention_init
MultiScaleRetention.forward = MultiScaleRetention_forward
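
# The block below is a minimal usage sketch, kept as a comment so nothing executes on
# import. It is hypothetical: it assumes RetNetConfig accepts the fields read by the
# overrides above as keyword arguments and provides sane defaults for the rest.
#
#   config = RetNetConfig(
#       vocab_size=256,
#       decoder_embed_dim=64,
#       decoder_value_embed_dim=128,
#       decoder_retention_heads=2,
#       decoder_ffn_embed_dim=128,
#       decoder_layers=2,
#   )
#   model = RetNetDecoder(config)                  # runs the patched RetNetModel_init
#   assert model.num_layers == config.decoder_layers
#   assert model.embed_tokens is not None          # built from the config above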