Merge pull request #69 from sunyt32/retnet-official

Update new RetNet settings

commit ab1d9d677a
@@ -19,6 +19,7 @@ Fundamental research to develop new architectures for foundation models and A(G)
 
 ## News
 
+- October, 2023: Update RMSNorm and SwiGLU as the default module in RetNet
 - November, 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)]
 
 ## Installation
@@ -31,8 +31,8 @@ logger = logging.getLogger(__name__)
 
 @dataclass
 class LanguageConfig(FairseqDataclass):
-    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
-        default="relu", metadata={"help": "activation function to use"}
+    activation_fn: str = field(
+        default="swish", metadata={"help": "activation function to use"}
     )
     dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
     activation_dropout: float = field(
@@ -44,6 +44,9 @@ class LanguageConfig(FairseqDataclass):
     decoder_embed_dim: int = field(
         default=512, metadata={"help": "decoder embedding dimension"}
     )
+    decoder_value_embed_dim: int = field(
+        default=864, metadata={"help": "decoder embedding dimension"}
+    )
     decoder_output_dim: int = field(
         default=512, metadata={"help": "decoder output dimension"}
     )
@@ -51,14 +54,14 @@ class LanguageConfig(FairseqDataclass):
         default=512, metadata={"help": "decoder input dimension"}
     )
     decoder_ffn_embed_dim: int = field(
-        default=2048, metadata={"help": "decoder embedding dimension for FFN"}
+        default=864, metadata={"help": "decoder embedding dimension for FFN"}
     )
     decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
     decoder_retention_heads: int = field(
         default=2, metadata={"help": "num decoder retention heads"}
     )
     decoder_normalize_before: bool = field(
-        default=False, metadata={"help": "apply layernorm before each decoder block"}
+        default=False, metadata={"help": "apply norm before each decoder block"}
     )
     share_decoder_input_output_embed: bool = field(
         default=False, metadata={"help": "share decoder input and output embeddings"}
@@ -68,7 +71,7 @@ class LanguageConfig(FairseqDataclass):
         metadata={"help": "use learned positional embeddings in the decoder"},
     )
     layernorm_embedding: bool = field(
-        default=False, metadata={"help": "add layernorm to embedding"}
+        default=False, metadata={"help": "add norm to embedding"}
     )
     no_scale_embedding: bool = field(
         default=False, metadata={"help": "if True, dont scale embeddings"}
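Aside (not part of the diff): the fields above use the standard dataclasses pattern, where the default value and the help string live on the field itself. A minimal self-contained sketch of that pattern, with a toy config class standing in for the real LanguageConfig:

    from dataclasses import dataclass, field, fields

    # Toy stand-in for the FairseqDataclass pattern above (hypothetical class).
    @dataclass
    class TinyRetNetConfig:
        activation_fn: str = field(
            default="swish", metadata={"help": "activation function to use"}
        )
        decoder_ffn_embed_dim: int = field(
            default=864, metadata={"help": "decoder embedding dimension for FFN"}
        )

    cfg = TinyRetNetConfig()
    print(cfg.activation_fn)                   # swish
    for f in fields(TinyRetNetConfig):
        print(f.name, f.default, f.metadata["help"])  # help text kept in field metadata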
@@ -276,14 +279,15 @@ def retnet_base_architecture(args):
     args.dropout = getattr(args, "dropout", 0.0)
 
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 864)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 864)
     args.decoder_layers = getattr(args, "decoder_layers", 6)
     args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 2)
     args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
     args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
     args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
     args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
-    args.activation_fn = getattr(args, "activation_fn", "gelu")
+    args.activation_fn = getattr(args, "activation_fn", "swish")
 
     args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
     args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
@@ -330,7 +334,8 @@ def retnet_base_architecture(args):
 @register_model_architecture("retnet", "retnet_medium")
 def retnet_medium(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 1728)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1728)
     args.decoder_layers = getattr(args, "decoder_layers", 16)
     args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 4)
     retnet_base_architecture(args)
@@ -338,7 +343,8 @@ def retnet_medium(args):
 @register_model_architecture("retnet", "retnet_xl")
 def retnet_xl(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 3456)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3456)
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
     args.decoder_layers = getattr(args, "decoder_layers", 24)
     retnet_base_architecture(args)
@@ -346,7 +352,8 @@ def retnet_xl(args):
 @register_model_architecture("retnet", "retnet_3b")
 def retnet_3b(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 4280)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4280)
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 10)
     args.decoder_layers = getattr(args, "decoder_layers", 32)
     retnet_base_architecture(args)
@@ -354,7 +361,8 @@ def retnet_3b(args):
 @register_model_architecture("retnet", "retnet_7b")
 def retnet_7b(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 6912)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6912)
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
     args.decoder_layers = getattr(args, "decoder_layers", 32)
     retnet_base_architecture(args)
@@ -362,7 +370,8 @@ def retnet_7b(args):
 @register_model_architecture("retnet", "retnet_13b")
 def retnet_13b(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 10240)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 8560)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8560)
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20)
     args.decoder_layers = getattr(args, "decoder_layers", 40)
     retnet_base_architecture(args)
@@ -370,7 +379,8 @@ def retnet_13b(args):
 @register_model_architecture("retnet", "retnet_65b")
 def retnet_65b(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 16384)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 13824)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 13824)
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
     args.decoder_layers = getattr(args, "decoder_layers", 64)
     retnet_base_architecture(args)
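A quick note on how these size presets behave (my illustration, not part of the diff): each architecture function only fills in attributes that are still missing, so a value passed on the command line survives the new defaults. Sketch with a bare Namespace standing in for the parsed args:

    from argparse import Namespace

    args = Namespace(decoder_ffn_embed_dim=1024)  # pretend the user overrode the FFN size

    # same pattern as retnet_base_architecture above
    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 864)  # missing -> 864
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 864)      # present -> stays 1024

    print(args.decoder_value_embed_dim, args.decoder_ffn_embed_dim)  # 864 1024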
@@ -212,8 +212,9 @@ class EncoderDecoderConfig(object):
 class RetNetConfig(object):
     def __init__(self, **kwargs):
         self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
+        self.decoder_value_embed_dim = kwargs.pop("decoder_value_embed_dim", 1280)
         self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3)
-        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1536)
+        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1280)
         self.decoder_layers = kwargs.pop("decoder_layers", 12)
         self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
         self.activation_fn = kwargs.pop("activation_fn", "gelu")
@@ -244,7 +245,7 @@ class RetNetConfig(object):
         )
         self.max_target_positions = kwargs.pop("max_target_positions", 1024)
         self.no_output_layer = kwargs.pop("no_output_layer", False)
-        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-6)
         # Blockwise
         self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False)
         self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512)
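For illustration (assuming the config class lives at torchscale.architecture.config, as elsewhere in the repo), the new key is consumed from kwargs like any other, and anything not passed falls back to the defaults shown above, including the tighter layernorm_eps of 1e-6:

    from torchscale.architecture.config import RetNetConfig

    cfg = RetNetConfig(
        decoder_embed_dim=768,
        decoder_value_embed_dim=1280,
        decoder_ffn_embed_dim=1280,
        decoder_retention_heads=3,
    )
    print(cfg.decoder_value_embed_dim, cfg.layernorm_eps)  # 1280 1e-06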
@@ -11,14 +11,12 @@ from fairscale.nn import checkpoint_wrapper, wrap
 
 from torchscale.architecture.utils import init_bert_params
 from torchscale.component.droppath import DropPath
-from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
+from torchscale.component.feedforward_network import make_experts
+from torchscale.component.gate_linear_unit import GLU
 from torchscale.component.multiscale_retention import MultiScaleRetention
 from torchscale.component.xmoe.moe_layer import MOELayer
 from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
-try:
-    from apex.normalization import FusedLayerNorm as LayerNorm
-except ModuleNotFoundError:
-    from torch.nn import LayerNorm
+from torchscale.component.rms_norm import RMSNorm
 
 
 class RetNetRelPos(nn.Module):
@@ -46,14 +44,17 @@ class RetNetRelPos(nn.Module):
             mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf"))
             mask = torch.exp(mask * self.decay[:, None, None])
             mask = torch.nan_to_num(mask)
+
+            value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)
+            value_inner_decay = value_inner_decay.unsqueeze(-1)
             scale = mask.sum(dim=-1, keepdim=True).sqrt()
-            mask = mask / scale
+            inner_mask = mask / scale
 
             cross_decay = torch.exp(self.decay * self.recurrent_chunk_size)
-            inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
+            query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
+            query_inner_decay = query_inner_decay[:, :, None] / (scale / mask[:, -1].sum(dim=-1)[:, None, None])
             cross_decay = cross_decay[:, None, None]
-            inner_decay = inner_decay[:, :, None] / (scale / scale[:, -1, None])
-            retention_rel_pos = ((sin, cos), (mask, cross_decay, inner_decay))
+            retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay, value_inner_decay))
         else:
             index = torch.arange(slen).to(self.decay)
             sin = torch.sin(index[:, None] * self.angle[None, :])
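To make the new chunkwise quantities easier to follow, here is a small standalone sketch of how the per-head decay mask is built (toy sizes; the decay initialization gamma = 1 - 2^(-5-h) is my assumption about RetNetRelPos, so treat the whole thing as illustrative rather than the library's exact code):

    import torch

    num_heads, chunk_len = 2, 4
    # assumed per-head decay, stored in log space
    decay = torch.log(1 - 2 ** (-5 - torch.arange(num_heads, dtype=torch.float)))

    block_index = torch.arange(chunk_len, dtype=torch.float)
    causal = torch.tril(torch.ones(chunk_len, chunk_len, dtype=torch.bool))

    # same masked_fill / exp / nan_to_num pattern as the hunk above
    rel = torch.masked_fill(block_index[:, None] - block_index[None, :], ~causal, float("inf"))
    mask = torch.exp(rel * decay[:, None, None])   # gamma^(n-m) on and below the diagonal, 0 above
    mask = torch.nan_to_num(mask)

    value_inner_decay = (mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)).unsqueeze(-1)
    scale = mask.sum(dim=-1, keepdim=True).sqrt()
    inner_mask = mask / scale                      # what the new code hands to the retention layer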
@@ -91,7 +92,7 @@ class DecoderLayer(nn.Module):
 
         self.normalize_before = args.decoder_normalize_before
 
-        self.retention_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
+        self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
 
         self.is_moe_layer = is_moe_layer
         self.ffn_dim = args.decoder_ffn_embed_dim
@@ -123,7 +124,7 @@ class DecoderLayer(nn.Module):
             experts = make_experts(args, self.embed_dim, self.ffn_dim)
             self.moe_layer = MOELayer(gate, experts, args)
 
-        self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
+        self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
 
         if args.deepnorm:
             self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
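Context for the LayerNorm-to-RMSNorm swaps in this file and below (my summary, not text from the diff): with per-dimension gain $g$, bias $b$, and the $\varepsilon$ given by layernorm_eps,

$$\mathrm{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} \odot g + b, \qquad \mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^2 + \varepsilon}} \odot g,$$

so RMSNorm drops the mean-centering and the bias term; the rms_norm.py file added at the bottom of this diff implements the right-hand form.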
@@ -131,20 +132,19 @@ class DecoderLayer(nn.Module):
             self.alpha = 1.0
 
     def build_ffn(self, embed_dim, args):
-        return FeedForwardNetwork(
+        return GLU(
             embed_dim,
             self.ffn_dim,
             args.activation_fn,
             args.dropout,
             args.activation_dropout,
-            args.layernorm_eps,
-            args.subln,
         )
 
     def build_retention(self, embed_dim, args):
         return MultiScaleRetention(
             args,
             embed_dim,
+            args.decoder_value_embed_dim,
             args.decoder_retention_heads,
         )
 
@@ -225,7 +225,7 @@ class RetNetDecoder(nn.Module):
         self.output_projection = output_projection
 
         if args.layernorm_embedding:
-            self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
+            self.layernorm_embedding = RMSNorm(embed_dim, eps=args.layernorm_eps)
         else:
             self.layernorm_embedding = None
 
@@ -245,7 +245,7 @@ class RetNetDecoder(nn.Module):
         self.num_layers = len(self.layers)
 
         if args.decoder_normalize_before:
-            self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
+            self.layer_norm = RMSNorm(embed_dim, eps=args.layernorm_eps)
         else:
             self.layer_norm = None
 
@@ -265,17 +265,6 @@ class RetNetDecoder(nn.Module):
             ):
                 p.data.div_(init_scale)
 
-        if args.subln:
-            init_scale = math.sqrt(math.log(args.decoder_layers * 2))
-            for name, p in self.named_parameters():
-                if (
-                    "fc1" in name
-                    or "fc2" in name
-                    or "out_proj" in name
-                    or "v_proj" in name
-                ):
-                    p.data.mul_(init_scale)
-
     def build_output_projection(
         self,
         args,
@@ -360,7 +349,6 @@ class RetNetDecoder(nn.Module):
         slen = prev_output_tokens.size(1)
         # relative position
         retention_rel_pos = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=self.chunkwise_recurrent)
 
         # decoder layers
         inner_states = [x]
 
@@ -374,7 +362,7 @@ class RetNetDecoder(nn.Module):
             else:
                 if idx not in incremental_state:
                     incremental_state[idx] = {}
 
             x, l_aux_i = layer(
                 x,
                 incremental_state[idx] if incremental_state is not None else None,
@@ -96,6 +96,8 @@ def get_activation_fn(activation):
         return F.relu
     elif activation == "gelu":
         return F.gelu
+    elif activation == "swish":
+        return F.silu
     else:
         raise NotImplementedError
 
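Since "swish" is new here: F.silu is the same function, x * sigmoid(x). A two-line sanity check (illustrative only):

    import torch
    import torch.nn.functional as F

    x = torch.randn(8)
    assert torch.allclose(F.silu(x), x * torch.sigmoid(x), atol=1e-6)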
torchscale/component/gate_linear_unit.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .feedforward_network import get_activation_fn
+
+
+class GLU(nn.Module):
+    def __init__(
+        self,
+        embed_dim,
+        ffn_dim,
+        activation_fn,
+        dropout,
+        activation_dropout,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.activation_fn = get_activation_fn(activation=str(activation_fn))
+        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
+        self.dropout_module = torch.nn.Dropout(dropout)
+        self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
+        self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
+        self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)
+
+    def reset_parameters(self):
+        self.fc1.reset_parameters()
+        self.fc2.reset_parameters()
+        self.gate.reset_parameters()
+
+    def forward(self, x):
+        x_shape = x.shape
+        x = x.reshape(-1, x.size(-1))
+        g = self.gate(x)
+        x = self.fc1(x)
+        x = self.activation_fn(x.float()).type_as(x) * g
+        x = self.activation_dropout_module(x)
+        x = self.fc2(x)
+        x = x.view(x_shape)
+        x = self.dropout_module(x)
+        return x
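A minimal usage sketch of the new GLU block (a SwiGLU-style gated FFN) under the defaults introduced above; this only checks shapes and claims nothing about training behaviour:

    import torch
    from torchscale.component.gate_linear_unit import GLU

    glu = GLU(embed_dim=512, ffn_dim=864, activation_fn="swish",
              dropout=0.0, activation_dropout=0.0)
    x = torch.randn(2, 16, 512)   # (batch, seq_len, embed_dim)
    y = glu(x)                    # fc1 and gate project to 864, fc2 projects back to 512
    print(y.shape)                # torch.Size([2, 16, 512])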
@@ -1,15 +1,11 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]
 
-import math
 
 import torch
 import torch.nn.functional as F
 from torch import nn
-try:
-    from apex.normalization import FusedLayerNorm as LayerNorm
-except ModuleNotFoundError:
-    from torch.nn import LayerNorm
+from .rms_norm import RMSNorm
 
 from .multiway_network import MultiwayWrapper
 
@@ -45,29 +41,29 @@ class MultiScaleRetention(nn.Module):
         self,
         args,
         embed_dim,
+        value_dim,
         num_heads,
-        value_factor=2,
         gate_fn="swish",
     ):
         super().__init__()
         self.args = args
-        self.factor = value_factor
         self.embed_dim = embed_dim
+        self.value_dim = value_dim
         self.num_heads = num_heads
-        self.head_dim = self.embed_dim * self.factor // num_heads
+        self.head_dim = self.value_dim // num_heads
         self.key_dim = self.embed_dim // num_heads
         self.scaling = self.key_dim ** -0.5
 
         self.gate_fn = get_activation_fn(activation=str(gate_fn))
 
-        self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
-        self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
-        self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim * self.factor, bias=True))
-        self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim * self.factor, bias=True))
+        self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
+        self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
+        self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
+        self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
 
-        self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim * self.factor, embed_dim, bias=True))
+        self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False))
 
-        self.group_norm = MultiwayWrapper(args, LayerNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
+        self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
         self.reset_parameters()
 
     def reset_parameters(self):
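Rough shape bookkeeping for the new constructor signature (plain nn.Linear stand-ins rather than the real MultiwayWrapper-wrapped modules), using the retnet_base defaults from the fairseq hunks earlier in this diff:

    import torch.nn as nn

    embed_dim, value_dim, num_heads = 512, 864, 2   # retnet_base: embed 512, value/FFN 864, 2 heads
    q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
    k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
    v_proj = nn.Linear(embed_dim, value_dim, bias=False)   # was embed_dim * value_factor before
    out_proj = nn.Linear(value_dim, embed_dim, bias=False)

    head_dim = value_dim // num_heads   # 432, normalized per head by the RMSNorm group_norm
    key_dim = embed_dim // num_heads    # 256, so scaling = key_dim ** -0.5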
@@ -76,7 +72,6 @@ class MultiScaleRetention(nn.Module):
         nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
         nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
         nn.init.xavier_uniform_(self.out_proj.weight)
-        nn.init.constant_(self.out_proj.bias, 0.0)
 
     def parallel_forward(self, qr, kr, v, mask):
         bsz, tgt_len, embed_dim = v.size()
@@ -121,7 +116,7 @@ class MultiScaleRetention(nn.Module):
         qr, kr, v,
         inner_mask
     ):
-        mask, cross_decay, inner_decay = inner_mask
+        mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
         bsz, tgt_len, embed_dim = v.size()
         chunk_len = mask.size(1)
         num_chunks = tgt_len // chunk_len
@@ -141,8 +136,7 @@ class MultiScaleRetention(nn.Module):
         inner_output = torch.matmul(qk_mat, v) # bsz * num_heads * num_value_heads * chunk_len * head_dim
 
         # reduce kv in one chunk
-        kv = kr_t @ (v * mask[:, -1, :, None])
-        kv = kv.view(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)
+        kv = kr_t @ (v * value_inner_decay)
 
         kv_recurrent = []
         cross_scale = []
@@ -163,7 +157,7 @@ class MultiScaleRetention(nn.Module):
         align_inner_scale = all_scale / inner_scale
         align_cross_scale = all_scale / cross_scale
 
-        cross_output = (qr * inner_decay) @ kv_recurrent
+        cross_output = (qr * query_inner_decay) @ kv_recurrent
         output = inner_output / align_inner_scale + cross_output / align_cross_scale
         # output = inner_output / cross_scale + cross_output / inner_scale
 
torchscale/component/rms_norm.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import torch
+import torch.nn as nn
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
+        super().__init__()
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.register_parameter('weight', None)
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        if self.weight is not None:
+            output = output * self.weight
+        return output
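And a short usage sketch for the new RMSNorm (illustrative): each position is divided by the root-mean-square of its features, then scaled by the learned weight; there is no mean subtraction and no bias.

    import torch
    from torchscale.component.rms_norm import RMSNorm

    norm = RMSNorm(dim=512, eps=1e-6)
    x = torch.randn(2, 16, 512)
    y = norm(x)
    # with the weight initialized to ones, each position comes out with roughly unit RMS
    print(y.pow(2).mean(-1).sqrt())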