diff --git a/README.md b/README.md
index e757cbb..a7bd11a 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,7 @@ Fundamental research to develop new architectures for foundation models and A(G)
 
 ## News
 
+- October, 2023: Use RMSNorm and SwiGLU as the default modules in RetNet
 - November, 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)]
 
 ## Installation
diff --git a/examples/fairseq/models/retnet.py b/examples/fairseq/models/retnet.py
index 94ba8d3..1a5b329 100644
--- a/examples/fairseq/models/retnet.py
+++ b/examples/fairseq/models/retnet.py
@@ -31,8 +31,8 @@ logger = logging.getLogger(__name__)
 
 @dataclass
 class LanguageConfig(FairseqDataclass):
-    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
-        default="relu", metadata={"help": "activation function to use"}
+    activation_fn: str = field(
+        default="swish", metadata={"help": "activation function to use"}
     )
     dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
     activation_dropout: float = field(
@@ -44,6 +44,9 @@ class LanguageConfig(FairseqDataclass):
     decoder_embed_dim: int = field(
         default=512, metadata={"help": "decoder embedding dimension"}
     )
+    decoder_value_embed_dim: int = field(
+        default=864, metadata={"help": "decoder value embedding dimension"}
+    )
     decoder_output_dim: int = field(
         default=512, metadata={"help": "decoder output dimension"}
     )
@@ -51,14 +54,14 @@
         default=512, metadata={"help": "decoder input dimension"}
     )
     decoder_ffn_embed_dim: int = field(
-        default=2048, metadata={"help": "decoder embedding dimension for FFN"}
+        default=864, metadata={"help": "decoder embedding dimension for FFN"}
     )
     decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
     decoder_retention_heads: int = field(
         default=2, metadata={"help": "num decoder retention heads"}
     )
     decoder_normalize_before: bool = field(
-        default=False, metadata={"help": "apply layernorm before each decoder block"}
+        default=False, metadata={"help": "apply norm before each decoder block"}
     )
     share_decoder_input_output_embed: bool = field(
         default=False, metadata={"help": "share decoder input and output embeddings"}
@@ -68,7 +71,7 @@
         metadata={"help": "use learned positional embeddings in the decoder"},
     )
     layernorm_embedding: bool = field(
-        default=False, metadata={"help": "add layernorm to embedding"}
+        default=False, metadata={"help": "add norm to embedding"}
     )
     no_scale_embedding: bool = field(
         default=False, metadata={"help": "if True, dont scale embeddings"}
@@ -276,14 +279,15 @@ def retnet_base_architecture(args):
     args.dropout = getattr(args, "dropout", 0.0)
 
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 864)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 864)
     args.decoder_layers = getattr(args, "decoder_layers", 6)
     args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 2)
     args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
     args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
     args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
     args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
-    args.activation_fn = getattr(args, "activation_fn", "gelu")
+    args.activation_fn = getattr(args, "activation_fn", "swish")
 
     args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
     args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
@@ -330,7 +334,8 @@ def retnet_base_architecture(args):
 @register_model_architecture("retnet", "retnet_medium")
 def retnet_medium(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 1728)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1728)
     args.decoder_layers = getattr(args, "decoder_layers", 16)
     args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 4)
     retnet_base_architecture(args)
@@ -338,7 +343,8 @@ def retnet_xl(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 3456)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3456)
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
     args.decoder_layers = getattr(args, "decoder_layers", 24)
     retnet_base_architecture(args)
@@ -346,7 +352,8 @@ def retnet_3b(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 4280)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4280)
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 10)
     args.decoder_layers = getattr(args, "decoder_layers", 32)
     retnet_base_architecture(args)
@@ -354,7 +361,8 @@ def retnet_7b(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 6912)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6912)
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
     args.decoder_layers = getattr(args, "decoder_layers", 32)
     retnet_base_architecture(args)
@@ -362,7 +370,8 @@ def retnet_13b(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 10240)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 8560)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8560)
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20)
     args.decoder_layers = getattr(args, "decoder_layers", 40)
     retnet_base_architecture(args)
@@ -370,7 +379,8 @@ def retnet_65b(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 16384)
+    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 13824)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 13824)
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
     args.decoder_layers = getattr(args, "decoder_layers", 64)
     retnet_base_architecture(args)
diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py
index b961a98..5898f39 100644
--- a/torchscale/architecture/config.py
+++ b/torchscale/architecture/config.py
@@ -212,8 +212,9 @@ class EncoderDecoderConfig(object):
 class RetNetConfig(object):
     def __init__(self, **kwargs):
         self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
+        self.decoder_value_embed_dim = kwargs.pop("decoder_value_embed_dim", 1280)
         self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3)
-        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1536)
+        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1280)
         self.decoder_layers = kwargs.pop("decoder_layers", 12)
         self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
         self.activation_fn = kwargs.pop("activation_fn", "gelu")
@@ -244,7 +245,7 @@
         )
         self.max_target_positions = kwargs.pop("max_target_positions", 1024)
         self.no_output_layer = kwargs.pop("no_output_layer", False)
-        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-6)
         # Blockwise
         self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False)
         self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512)
diff --git a/torchscale/architecture/retnet.py b/torchscale/architecture/retnet.py
index 786fe87..b29928c 100644
--- a/torchscale/architecture/retnet.py
+++ b/torchscale/architecture/retnet.py
@@ -11,14 +11,12 @@ from fairscale.nn import checkpoint_wrapper, wrap
 
 from torchscale.architecture.utils import init_bert_params
 from torchscale.component.droppath import DropPath
-from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
+from torchscale.component.feedforward_network import make_experts
+from torchscale.component.gate_linear_unit import GLU
 from torchscale.component.multiscale_retention import MultiScaleRetention
 from torchscale.component.xmoe.moe_layer import MOELayer
 from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
-try:
-    from apex.normalization import FusedLayerNorm as LayerNorm
-except ModuleNotFoundError:
-    from torch.nn import LayerNorm
+from torchscale.component.rms_norm import RMSNorm
 
 
 class RetNetRelPos(nn.Module):
@@ -46,14 +44,17 @@
             mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf"))
             mask = torch.exp(mask * self.decay[:, None, None])
             mask = torch.nan_to_num(mask)
+
+            value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)
+            value_inner_decay = value_inner_decay.unsqueeze(-1)
             scale = mask.sum(dim=-1, keepdim=True).sqrt()
-            mask = mask / scale
+            inner_mask = mask / scale
 
             cross_decay = torch.exp(self.decay * self.recurrent_chunk_size)
-            inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
+            query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
+            query_inner_decay = query_inner_decay[:, :, None] / (scale / mask[:, -1].sum(dim=-1)[:, None, None])
             cross_decay = cross_decay[:, None, None]
-            inner_decay = inner_decay[:, :, None] / (scale / scale[:, -1, None])
-            retention_rel_pos = ((sin, cos), (mask, cross_decay, inner_decay))
+            retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay, value_inner_decay))
         else:
             index = torch.arange(slen).to(self.decay)
             sin = torch.sin(index[:, None] * self.angle[None, :])
@@ -91,7 +92,7 @@ class DecoderLayer(nn.Module):
 
         self.normalize_before = args.decoder_normalize_before
 
-        self.retention_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
+        self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
 
         self.is_moe_layer = is_moe_layer
         self.ffn_dim = args.decoder_ffn_embed_dim
@@ -123,7 +124,7 @@ class DecoderLayer(nn.Module):
             experts = make_experts(args, self.embed_dim, self.ffn_dim)
             self.moe_layer = MOELayer(gate, experts, args)
 
-        self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
+        self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
 
         if args.deepnorm:
             self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
@@ -131,20 +132,19 @@ class DecoderLayer(nn.Module):
             self.alpha = 1.0
 
     def build_ffn(self, embed_dim, args):
-        return FeedForwardNetwork(
+        return GLU(
             embed_dim,
             self.ffn_dim,
             args.activation_fn,
             args.dropout,
             args.activation_dropout,
-            args.layernorm_eps,
-            args.subln,
         )
 
     def build_retention(self, embed_dim, args):
         return MultiScaleRetention(
             args,
             embed_dim,
+            args.decoder_value_embed_dim,
             args.decoder_retention_heads,
         )
 
@@ -225,7 +225,7 @@ class RetNetDecoder(nn.Module):
         self.output_projection = output_projection
 
         if args.layernorm_embedding:
-            self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
+            self.layernorm_embedding = RMSNorm(embed_dim, eps=args.layernorm_eps)
         else:
             self.layernorm_embedding = None
 
@@ -245,7 +245,7 @@ class RetNetDecoder(nn.Module):
         self.num_layers = len(self.layers)
 
         if args.decoder_normalize_before:
-            self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
+            self.layer_norm = RMSNorm(embed_dim, eps=args.layernorm_eps)
         else:
             self.layer_norm = None
 
@@ -265,17 +265,6 @@ class RetNetDecoder(nn.Module):
                 ):
                     p.data.div_(init_scale)
 
-        if args.subln:
-            init_scale = math.sqrt(math.log(args.decoder_layers * 2))
-            for name, p in self.named_parameters():
-                if (
-                    "fc1" in name
-                    or "fc2" in name
-                    or "out_proj" in name
-                    or "v_proj" in name
-                ):
-                    p.data.mul_(init_scale)
-
     def build_output_projection(
         self,
         args,
@@ -360,7 +349,6 @@ class RetNetDecoder(nn.Module):
         slen = prev_output_tokens.size(1)
         # relative position
         retention_rel_pos = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=self.chunkwise_recurrent)
-
         # decoder layers
         inner_states = [x]
 
@@ -374,7 +362,7 @@ class RetNetDecoder(nn.Module):
             else:
                 if idx not in incremental_state:
                     incremental_state[idx] = {}
-
+
             x, l_aux_i = layer(
                 x,
                 incremental_state[idx] if incremental_state is not None else None,
diff --git a/torchscale/component/feedforward_network.py b/torchscale/component/feedforward_network.py
index cc187a8..9d0295d 100644
--- a/torchscale/component/feedforward_network.py
+++ b/torchscale/component/feedforward_network.py
@@ -96,6 +96,8 @@ def get_activation_fn(activation):
         return F.relu
     elif activation == "gelu":
         return F.gelu
+    elif activation == "swish":
+        return F.silu
     else:
         raise NotImplementedError
 
diff --git a/torchscale/component/gate_linear_unit.py b/torchscale/component/gate_linear_unit.py
new file mode 100644
index 0000000..ecc9b34
--- /dev/null
+++ b/torchscale/component/gate_linear_unit.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .feedforward_network import get_activation_fn
+
+
+class GLU(nn.Module):
+    def __init__(
+        self,
+        embed_dim,
+        ffn_dim,
+        activation_fn,
+        dropout,
+        activation_dropout,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.activation_fn = get_activation_fn(activation=str(activation_fn))
+        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
+        self.dropout_module = torch.nn.Dropout(dropout)
+        self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
+        self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
+        self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)
+
+    def reset_parameters(self):
+        self.fc1.reset_parameters()
+        self.fc2.reset_parameters()
+        self.gate.reset_parameters()
+
+    def forward(self, x):
+        x_shape = x.shape
+        x = x.reshape(-1, x.size(-1))
+        g = self.gate(x)
+        x = self.fc1(x)
+        x = self.activation_fn(x.float()).type_as(x) * g
+        x = self.activation_dropout_module(x)
+        x = self.fc2(x)
+        x = x.view(x_shape)
+        x = self.dropout_module(x)
+        return x
diff --git a/torchscale/component/multiscale_retention.py b/torchscale/component/multiscale_retention.py
index f88e962..c481b23 100644
--- a/torchscale/component/multiscale_retention.py
+++ b/torchscale/component/multiscale_retention.py
@@ -1,15 +1,11 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]
 
-import math
 import torch
 import torch.nn.functional as F
 from torch import nn
 
-try:
-    from apex.normalization import FusedLayerNorm as LayerNorm
-except ModuleNotFoundError:
-    from torch.nn import LayerNorm
+from .rms_norm import RMSNorm
 
 from .multiway_network import MultiwayWrapper
 
@@ -45,29 +41,29 @@ class MultiScaleRetention(nn.Module):
         self,
         args,
         embed_dim,
+        value_dim,
         num_heads,
-        value_factor=2,
         gate_fn="swish",
     ):
         super().__init__()
         self.args = args
-        self.factor = value_factor
         self.embed_dim = embed_dim
+        self.value_dim = value_dim
         self.num_heads = num_heads
-        self.head_dim = self.embed_dim * self.factor // num_heads
+        self.head_dim = self.value_dim // num_heads
         self.key_dim = self.embed_dim // num_heads
         self.scaling = self.key_dim ** -0.5
 
         self.gate_fn = get_activation_fn(activation=str(gate_fn))
 
-        self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
-        self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
-        self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim * self.factor, bias=True))
-        self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim * self.factor, bias=True))
+        self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
+        self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
+        self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
+        self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
 
-        self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim * self.factor, embed_dim, bias=True))
+        self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False))
 
-        self.group_norm = MultiwayWrapper(args, LayerNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
+        self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -76,7 +72,6 @@ class MultiScaleRetention(nn.Module):
         nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
         nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
         nn.init.xavier_uniform_(self.out_proj.weight)
-        nn.init.constant_(self.out_proj.bias, 0.0)
 
     def parallel_forward(self, qr, kr, v, mask):
         bsz, tgt_len, embed_dim = v.size()
@@ -121,7 +116,7 @@ class MultiScaleRetention(nn.Module):
         qr, kr, v,
         inner_mask
     ):
-        mask, cross_decay, inner_decay = inner_mask
+        mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
         bsz, tgt_len, embed_dim = v.size()
         chunk_len = mask.size(1)
         num_chunks = tgt_len // chunk_len
@@ -141,8 +136,7 @@ class MultiScaleRetention(nn.Module):
         inner_output = torch.matmul(qk_mat, v) # bsz * num_heads * num_value_heads * chunk_len * head_dim
 
         # reduce kv in one chunk
-        kv = kr_t @ (v * mask[:, -1, :, None])
-        kv = kv.view(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)
+        kv = kr_t @ (v * value_inner_decay)
 
         kv_recurrent = []
         cross_scale = []
@@ -163,7 +157,7 @@ class MultiScaleRetention(nn.Module):
         align_inner_scale = all_scale / inner_scale
         align_cross_scale = all_scale / cross_scale
 
-        cross_output = (qr * inner_decay) @ kv_recurrent
+        cross_output = (qr * query_inner_decay) @ kv_recurrent
         output = inner_output / align_inner_scale + cross_output / align_cross_scale
         # output = inner_output / cross_scale + cross_output / inner_scale
 
diff --git a/torchscale/component/rms_norm.py b/torchscale/component/rms_norm.py
new file mode 100644
index 0000000..465536c
--- /dev/null
+++ b/torchscale/component/rms_norm.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import torch
+import torch.nn as nn
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
+        super().__init__()
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.register_parameter('weight', None)
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        if self.weight is not None:
+            output = output * self.weight
+        return output
+
\ No newline at end of file
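A minimal usage sketch (not part of the patch) for the two modules this diff introduces. It assumes the patch is applied, so that `GLU`, `RMSNorm`, and the `"swish"` activation are importable, and it uses the new `RetNetConfig` defaults (embed dim 768, FFN dim 1280, `layernorm_eps=1e-6`):

```python
import torch

from torchscale.component.gate_linear_unit import GLU
from torchscale.component.rms_norm import RMSNorm

# Defaults from the updated RetNetConfig: decoder_embed_dim=768,
# decoder_ffn_embed_dim=1280, activation "swish" (mapped to F.silu), eps=1e-6.
embed_dim, ffn_dim = 768, 1280

norm = RMSNorm(embed_dim, eps=1e-6)          # scale-only normalization, no bias
ffn = GLU(embed_dim, ffn_dim, "swish", dropout=0.0, activation_dropout=0.0)

x = torch.randn(2, 16, embed_dim)            # (batch, sequence, embed_dim)
y = ffn(norm(x))                             # pre-norm SwiGLU feed-forward
assert y.shape == x.shape
```

Note that both modules are bias-free (the `GLU` projections use `bias=False` and `RMSNorm` carries only a scale weight), matching the switch of the `MultiScaleRetention` projections to `bias=False` in this diff.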