vall-e/vall_e/models/retnet_ts.py

279 lines
7.5 KiB
Python

# https://github.com/syncdoth/RetNet/
from ..ext.retnet_ts.config import RetNetConfig
from ..ext.retnet_ts.retnet import RetNetModel as RetNetDecoder
# things we're overriding or required to override
from ..ext.retnet_ts.retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos
import torch
import math
from typing import Dict, List, Optional, Tuple, Union
from torch.utils.checkpoint import checkpoint
# required to have compatibile LayerNorm
def FeedForwardNetwork_init(
self,
embed_dim,
ffn_dim,
activation_fn,
dropout,
activation_dropout,
layernorm_eps,
subln=True,
use_rms_norm=False,
):
super(FeedForwardNetwork, self).__init__()
self.embed_dim = embed_dim
self.activation_fn = get_activation_fn(activation=str(activation_fn))
self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
self.dropout_module = torch.nn.Dropout(dropout)
self.fc1 = torch.nn.Linear(self.embed_dim, ffn_dim)
self.fc2 = torch.nn.Linear(ffn_dim, self.embed_dim)
self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
FeedForwardNetwork.__init__ = FeedForwardNetwork_init
# removes embed_tokens
def RetNetModel_init(
self, config, embed_tokens=None, output_projection=None, **kwargs
):
super(RetNetDecoder, self).__init__(**kwargs)
self.config = config
self.dropout_module = torch.nn.Dropout(config.dropout)
self.embed_dim = config.decoder_embed_dim
self.embed_scale = (
1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
)
"""
if embed_tokens is None:
embed_tokens = torch.nn.Embedding(
config.vocab_size, config.decoder_embed_dim, config.pad_token_id
)
"""
self.embed_tokens = None
if (output_projection is None and not config.no_output_layer and config.vocab_size > 0):
self.output_projection = self.build_output_projection(config)
else:
self.output_projection = output_projection
if config.layernorm_embedding:
self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
else:
self.layernorm_embedding = None
self.layers = torch.nn.ModuleList([])
for i in range(config.decoder_layers):
layer = self.build_decoder_layer(
config,
depth=i,
)
"""
if config.checkpoint_activations:
layer = checkpoint_wrapper(layer)
"""
self.layers.append(layer)
self.num_layers = len(self.layers)
if config.decoder_normalize_before:
self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
else:
self.layer_norm = None
self.retnet_rel_pos = RetNetRelPos(config)
self.chunkwise_recurrent = config.chunkwise_recurrent
self.recurrent_chunk_size = config.recurrent_chunk_size
if config.deepnorm:
init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale)
if config.subln and not config.use_glu:
init_scale = math.sqrt(math.log(config.decoder_layers * 2))
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.mul_(init_scale)
self.gradient_checkpointing = True
RetNetDecoder.__init__ = RetNetModel_init
# restores bias in our FFNs
def RetNetDecoderLayer_init(
self,
config,
depth,
use_bias=True
):
super(RetNetDecoderLayer, self).__init__()
self.config = config
self.embed_dim = config.decoder_embed_dim
self.dropout_module = torch.nn.Dropout(config.dropout)
if config.drop_path_rate > 0:
drop_path_prob = np.linspace(
0, config.drop_path_rate, config.decoder_layers
)[depth]
self.drop_path = DropPath(drop_path_prob)
else:
self.drop_path = None
self.retention = MultiScaleRetention(
config,
self.embed_dim,
config.decoder_value_embed_dim,
config.decoder_retention_heads,
use_bias=use_bias
)
self.normalize_before = config.decoder_normalize_before
self.retention_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
self.ffn_dim = config.decoder_ffn_embed_dim
self.ffn = self.build_ffn()
self.final_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
if config.deepnorm:
self.alpha = math.pow(2.0 * config.decoder_layers, 0.25)
else:
self.alpha = 1.0
def RetNetDecoderLayer_forward(
self,
x,
incremental_state=None,
chunkwise_recurrent=False,
retention_rel_pos=None,
):
residual = x
if self.normalize_before:
x = self.retention_layer_norm(x)
if x.requires_grad and self.config.checkpoint_activations:
x = checkpoint(
self.retention,
x,
use_reentrant=False,
incremental_state=incremental_state,
rel_pos=retention_rel_pos,
chunkwise_recurrent=chunkwise_recurrent,
)
else:
x = self.retention(
x,
incremental_state=incremental_state,
rel_pos=retention_rel_pos,
chunkwise_recurrent=chunkwise_recurrent,
)
x = self.dropout_module(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.retention_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.ffn(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x
RetNetDecoderLayer.__init__ = RetNetDecoderLayer_init
RetNetDecoderLayer.forward = RetNetDecoderLayer_forward
# fixes backwards when using te's autocast
def MultiScaleRetention_init(
self,
config,
embed_dim,
value_dim,
num_heads,
gate_fn="swish",
use_bias=True,
):
super(MultiScaleRetention, self).__init__()
self.config = config
self.embed_dim = embed_dim
self.value_dim = value_dim
self.num_heads = num_heads
self.head_dim = self.value_dim // num_heads
self.key_dim = self.embed_dim // num_heads
self.scaling = self.key_dim**-0.5
self.gate_fn = get_activation_fn(activation=str(gate_fn))
self.q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=use_bias)
self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=use_bias)
self.v_proj = torch.nn.Linear(embed_dim, value_dim, bias=use_bias)
self.g_proj = torch.nn.Linear(embed_dim, value_dim, bias=use_bias)
self.out_proj = torch.nn.Linear(value_dim, embed_dim, bias=use_bias)
self.group_norm = RMSNorm(self.head_dim, eps=config.layernorm_eps, elementwise_affine=False)
self.reset_parameters()
def MultiScaleRetention_forward(
self, x, rel_pos, chunkwise_recurrent=False, incremental_state=None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
bsz, tgt_len, _ = x.size()
(sin, cos), inner_mask = rel_pos
q = self.q_proj(x)
k = self.k_proj(x) * self.scaling
v = self.v_proj(x)
g = self.g_proj(x)
q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
qr = theta_shift(q, sin, cos)
kr = theta_shift(k, sin, cos)
if incremental_state is not None:
output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
elif chunkwise_recurrent:
output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
else:
output = self.parallel_forward(qr, kr, v, inner_mask)
output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
output = self.gate_fn(g) * output
output = self.out_proj(output)
return output
MultiScaleRetention.__init__ = MultiScaleRetention_init
MultiScaleRetention.forward = MultiScaleRetention_forward