From 5c89ffbeea3ba458a865a569f947bf82cca50090 Mon Sep 17 00:00:00 2001 From: sunyt32 Date: Thu, 28 Sep 2023 14:24:37 +0000 Subject: [PATCH] modify rms norm and value dim in retention --- examples/fairseq/models/retnet.py | 40 +++--- torchscale/architecture/config.py | 7 +- torchscale/architecture/retnet.py | 39 ++---- torchscale/component/gate_linear_unit.py | 132 +++++++++++++++++++ torchscale/component/multiscale_retention.py | 25 ++-- torchscale/component/rms_norm.py | 25 ++++ 6 files changed, 208 insertions(+), 60 deletions(-) create mode 100644 torchscale/component/gate_linear_unit.py create mode 100644 torchscale/component/rms_norm.py diff --git a/examples/fairseq/models/retnet.py b/examples/fairseq/models/retnet.py index 94ba8d3..26d3783 100644 --- a/examples/fairseq/models/retnet.py +++ b/examples/fairseq/models/retnet.py @@ -31,8 +31,8 @@ logger = logging.getLogger(__name__) @dataclass class LanguageConfig(FairseqDataclass): - activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( - default="relu", metadata={"help": "activation function to use"} + activation_fn: str = field( + default="swish", metadata={"help": "activation function to use"} ) dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) activation_dropout: float = field( @@ -44,6 +44,9 @@ class LanguageConfig(FairseqDataclass): decoder_embed_dim: int = field( default=512, metadata={"help": "decoder embedding dimension"} ) + decoder_value_embed_dim: int = field( + default=864, metadata={"help": "decoder embedding dimension"} + ) decoder_output_dim: int = field( default=512, metadata={"help": "decoder output dimension"} ) @@ -51,14 +54,14 @@ class LanguageConfig(FairseqDataclass): default=512, metadata={"help": "decoder input dimension"} ) decoder_ffn_embed_dim: int = field( - default=2048, metadata={"help": "decoder embedding dimension for FFN"} + default=864, metadata={"help": "decoder embedding dimension for FFN"} ) decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"}) decoder_retention_heads: int = field( default=2, metadata={"help": "num decoder retention heads"} ) decoder_normalize_before: bool = field( - default=False, metadata={"help": "apply layernorm before each decoder block"} + default=False, metadata={"help": "apply norm before each decoder block"} ) share_decoder_input_output_embed: bool = field( default=False, metadata={"help": "share decoder input and output embeddings"} @@ -67,8 +70,8 @@ class LanguageConfig(FairseqDataclass): default=False, metadata={"help": "use learned positional embeddings in the decoder"}, ) - layernorm_embedding: bool = field( - default=False, metadata={"help": "add layernorm to embedding"} + norm_embedding: bool = field( + default=False, metadata={"help": "add norm to embedding"} ) no_scale_embedding: bool = field( default=False, metadata={"help": "if True, dont scale embeddings"} @@ -276,14 +279,15 @@ def retnet_base_architecture(args): args.dropout = getattr(args, "dropout", 0.0) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) + args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 864) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 864) args.decoder_layers = getattr(args, "decoder_layers", 6) args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 2) args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) args.adaptive_softmax_dropout = 
getattr(args, "adaptive_softmax_dropout", 0) args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) - args.activation_fn = getattr(args, "activation_fn", "gelu") + args.activation_fn = getattr(args, "activation_fn", "swish") args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) @@ -321,7 +325,7 @@ def retnet_base_architecture(args): args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False) args.no_scale_embedding = getattr(args, "no_scale_embedding", False) - args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.norm_embedding = getattr(args, "norm_embedding", False) args.checkpoint_activations = getattr(args, "checkpoint_activations", False) args.offload_activations = getattr(args, "offload_activations", False) if args.offload_activations: @@ -330,7 +334,8 @@ def retnet_base_architecture(args): @register_model_architecture("retnet", "retnet_medium") def retnet_medium(args): args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) + args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 1728) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1728) args.decoder_layers = getattr(args, "decoder_layers", 16) args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 4) retnet_base_architecture(args) @@ -338,7 +343,8 @@ def retnet_medium(args): @register_model_architecture("retnet", "retnet_xl") def retnet_xl(args): args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 3456) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3456) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) args.decoder_layers = getattr(args, "decoder_layers", 24) retnet_base_architecture(args) @@ -346,7 +352,8 @@ def retnet_xl(args): @register_model_architecture("retnet", "retnet_3b") def retnet_3b(args): args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120) + args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 4280) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4280) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 10) args.decoder_layers = getattr(args, "decoder_layers", 32) retnet_base_architecture(args) @@ -354,7 +361,8 @@ def retnet_3b(args): @register_model_architecture("retnet", "retnet_7b") def retnet_7b(args): args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192) + args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 6912) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6912) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) args.decoder_layers = getattr(args, "decoder_layers", 32) retnet_base_architecture(args) @@ -362,7 +370,8 @@ def retnet_7b(args): @register_model_architecture("retnet", "retnet_13b") def retnet_13b(args): args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 10240) + 
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 8560) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8560) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20) args.decoder_layers = getattr(args, "decoder_layers", 40) retnet_base_architecture(args) @@ -370,7 +379,8 @@ def retnet_13b(args): @register_model_architecture("retnet", "retnet_65b") def retnet_65b(args): args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 16384) + args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 13824) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 13824) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) args.decoder_layers = getattr(args, "decoder_layers", 64) retnet_base_architecture(args) diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py index b961a98..d85a19b 100644 --- a/torchscale/architecture/config.py +++ b/torchscale/architecture/config.py @@ -212,8 +212,9 @@ class EncoderDecoderConfig(object): class RetNetConfig(object): def __init__(self, **kwargs): self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768) + self.decoder_value_embed_dim = kwargs.pop("decoder_value_embed_dim", 1280) self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3) - self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1536) + self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1280) self.decoder_layers = kwargs.pop("decoder_layers", 12) self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True) self.activation_fn = kwargs.pop("activation_fn", "gelu") @@ -221,7 +222,7 @@ class RetNetConfig(object): self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0) self.activation_dropout = kwargs.pop("activation_dropout", 0.0) self.no_scale_embedding = kwargs.pop("no_scale_embedding", True) - self.layernorm_embedding = kwargs.pop("layernorm_embedding", False) + self.norm_embedding = kwargs.pop("norm_embedding", False) self.moe_freq = kwargs.pop("moe_freq", 0) self.moe_top1_expert = kwargs.pop("moe_top1_expert", False) self.moe_expert_count = kwargs.pop("moe_expert_count", 0) @@ -244,7 +245,7 @@ class RetNetConfig(object): ) self.max_target_positions = kwargs.pop("max_target_positions", 1024) self.no_output_layer = kwargs.pop("no_output_layer", False) - self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5) + self.norm_eps = kwargs.pop("norm_eps", 1e-6) # Blockwise self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False) self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512) diff --git a/torchscale/architecture/retnet.py b/torchscale/architecture/retnet.py index 786fe87..52dddee 100644 --- a/torchscale/architecture/retnet.py +++ b/torchscale/architecture/retnet.py @@ -11,14 +11,11 @@ from fairscale.nn import checkpoint_wrapper, wrap from torchscale.architecture.utils import init_bert_params from torchscale.component.droppath import DropPath -from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts +from torchscale.component.gate_linear_unit import GLU, make_experts from torchscale.component.multiscale_retention import MultiScaleRetention from torchscale.component.xmoe.moe_layer import MOELayer from torchscale.component.xmoe.routing import Top1Gate, Top2Gate -try: - from apex.normalization import FusedLayerNorm as LayerNorm -except ModuleNotFoundError: - from torch.nn import 
LayerNorm +from torchscale.component.rms_norm import RMSNorm class RetNetRelPos(nn.Module): @@ -91,7 +88,7 @@ class DecoderLayer(nn.Module): self.normalize_before = args.decoder_normalize_before - self.retention_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps) + self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.norm_eps) self.is_moe_layer = is_moe_layer self.ffn_dim = args.decoder_ffn_embed_dim @@ -123,7 +120,7 @@ class DecoderLayer(nn.Module): experts = make_experts(args, self.embed_dim, self.ffn_dim) self.moe_layer = MOELayer(gate, experts, args) - self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps) + self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.norm_eps) if args.deepnorm: self.alpha = math.pow(2.0 * args.decoder_layers, 0.25) @@ -131,20 +128,19 @@ class DecoderLayer(nn.Module): self.alpha = 1.0 def build_ffn(self, embed_dim, args): - return FeedForwardNetwork( + return GLU( embed_dim, self.ffn_dim, args.activation_fn, args.dropout, args.activation_dropout, - args.layernorm_eps, - args.subln, ) def build_retention(self, embed_dim, args): return MultiScaleRetention( args, embed_dim, + args.decoder_value_embed_dim, args.decoder_retention_heads, ) @@ -224,10 +220,10 @@ class RetNetDecoder(nn.Module): else: self.output_projection = output_projection - if args.layernorm_embedding: - self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps) + if args.norm_embedding: + self.norm_embedding = RMSNorm(embed_dim, eps=args.norm_eps) else: - self.layernorm_embedding = None + self.norm_embedding = None self.layers = nn.ModuleList([]) @@ -245,7 +241,7 @@ class RetNetDecoder(nn.Module): self.num_layers = len(self.layers) if args.decoder_normalize_before: - self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps) + self.layer_norm = RMSNorm(embed_dim, eps=args.norm_eps) else: self.layer_norm = None @@ -265,17 +261,6 @@ class RetNetDecoder(nn.Module): ): p.data.div_(init_scale) - if args.subln: - init_scale = math.sqrt(math.log(args.decoder_layers * 2)) - for name, p in self.named_parameters(): - if ( - "fc1" in name - or "fc2" in name - or "out_proj" in name - or "v_proj" in name - ): - p.data.mul_(init_scale) - def build_output_projection( self, args, @@ -324,8 +309,8 @@ class RetNetDecoder(nn.Module): x = embed = self.embed_scale * token_embedding - if self.layernorm_embedding is not None: - x = self.layernorm_embedding(x) + if self.norm_embedding is not None: + x = self.norm_embedding(x) x = self.dropout_module(x) diff --git a/torchscale/component/gate_linear_unit.py b/torchscale/component/gate_linear_unit.py new file mode 100644 index 0000000..1c63a4e --- /dev/null +++ b/torchscale/component/gate_linear_unit.py @@ -0,0 +1,132 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .xmoe.global_groups import get_moe_group + + +class set_torch_seed(object): + def __init__(self, seed): + assert isinstance(seed, int) + self.rng_state = self.get_rng_state() + + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + def get_rng_state(self): + state = {"torch_rng_state": torch.get_rng_state()} + if torch.cuda.is_available(): + state["cuda_rng_state"] = torch.cuda.get_rng_state() + return state + + def set_rng_state(self, state): + torch.set_rng_state(state["torch_rng_state"]) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(state["cuda_rng_state"]) + + def 
__enter__(self): + return self + + def __exit__(self, *exc): + self.set_rng_state(self.rng_state) + + +def make_experts(args, embed_dim, expert_ffn_dim): + world_size = ( + 1 + if not torch.distributed.is_initialized() + else torch.distributed.get_world_size() + ) + expert_list = [] + ddp_rank = args.ddp_rank + start_seed = torch.randint(1000000, (1,)).item() + # at least as many experts than gpus + if args.moe_expert_count >= world_size: + assert ( + args.moe_expert_count % world_size == 0 + ), f"{args.moe_expert_count}, {world_size}" + local_moe_expert_count = args.moe_expert_count // world_size + for i in range(local_moe_expert_count): + with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i): + expert_list.append( + GLU( + embed_dim, + expert_ffn_dim, + args.activation_fn, + args.dropout, + args.activation_dropout, + args.layernorm_eps, + args.subln, + ) + ) + else: + assert ( + world_size % args.moe_expert_count == 0 + ), f"{world_size}, {args.moe_expert_count}" + + moe_idx, _ = get_moe_group(args.moe_expert_count) + + with set_torch_seed(start_seed + moe_idx): + expert_list.append( + GLU( + embed_dim, + expert_ffn_dim, + args.activation_fn, + args.dropout, + args.activation_dropout, + args.layernorm_eps, + args.subln, + ) + ) + experts = nn.ModuleList(expert_list) + return experts + + +def get_activation_fn(activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + elif activation == "swish": + return F.silu + else: + raise NotImplementedError + + +class GLU(nn.Module): + def __init__( + self, + embed_dim, + ffn_dim, + activation_fn, + dropout, + activation_dropout, + ): + super().__init__() + self.embed_dim = embed_dim + self.activation_fn = get_activation_fn(activation=str(activation_fn)) + self.activation_dropout_module = torch.nn.Dropout(activation_dropout) + self.dropout_module = torch.nn.Dropout(dropout) + self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False) + self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False) + self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False) + + def reset_parameters(self): + self.fc1.reset_parameters() + self.fc2.reset_parameters() + + def forward(self, x): + x_shape = x.shape + x = x.reshape(-1, x.size(-1)) + g = self.gate(x) + x = self.fc1(x) + x = self.activation_fn(x.float()).type_as(x) * g + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = x.view(x_shape) + x = self.dropout_module(x) + return x diff --git a/torchscale/component/multiscale_retention.py b/torchscale/component/multiscale_retention.py index f88e962..0e91235 100644 --- a/torchscale/component/multiscale_retention.py +++ b/torchscale/component/multiscale_retention.py @@ -1,15 +1,11 @@ # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] -import math import torch import torch.nn.functional as F from torch import nn -try: - from apex.normalization import FusedLayerNorm as LayerNorm -except ModuleNotFoundError: - from torch.nn import LayerNorm +from .rms_norm import RMSNorm from .multiway_network import MultiwayWrapper @@ -45,29 +41,29 @@ class MultiScaleRetention(nn.Module): self, args, embed_dim, + value_dim, num_heads, - value_factor=2, gate_fn="swish", ): super().__init__() self.args = args - self.factor = value_factor self.embed_dim = embed_dim + self.value_dim = value_dim self.num_heads = num_heads - self.head_dim = self.embed_dim * self.factor // num_heads + self.head_dim = self.value_dim // num_heads self.key_dim = self.embed_dim // num_heads self.scaling = 
self.key_dim ** -0.5 self.gate_fn = get_activation_fn(activation=str(gate_fn)) - self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) - self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) - self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim * self.factor, bias=True)) - self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim * self.factor, bias=True)) + self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False)) + self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False)) + self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False)) + self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False)) - self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim * self.factor, embed_dim, bias=True)) + self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False)) - self.group_norm = MultiwayWrapper(args, LayerNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False)) + self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.norm_eps, elementwise_affine=False)) self.reset_parameters() def reset_parameters(self): @@ -76,7 +72,6 @@ class MultiScaleRetention(nn.Module): nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5) nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5) nn.init.xavier_uniform_(self.out_proj.weight) - nn.init.constant_(self.out_proj.bias, 0.0) def parallel_forward(self, qr, kr, v, mask): bsz, tgt_len, embed_dim = v.size() diff --git a/torchscale/component/rms_norm.py b/torchscale/component/rms_norm.py new file mode 100644 index 0000000..465536c --- /dev/null +++ b/torchscale/component/rms_norm.py @@ -0,0 +1,25 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import torch +import torch.nn as nn + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True): + super().__init__() + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_parameter('weight', None) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + if self.weight is not None: + output = output * self.weight + return output + \ No newline at end of file
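
For reference, a minimal usage sketch of the new `RMSNorm` module created by this patch (importable from `torchscale.component.rms_norm` as added above). It scales each vector by the reciprocal root-mean-square of its last dimension; unlike `LayerNorm` there is no mean subtraction and no bias, and `elementwise_affine=False` (as used for the retention `group_norm`) drops the learned scale entirely. The tensor shapes below are illustrative, not taken from the patch.

```python
import torch

from torchscale.component.rms_norm import RMSNorm

# 2 tokens, model dim 8 (arbitrary illustrative sizes)
x = torch.randn(2, 8)

norm = RMSNorm(dim=8, eps=1e-6)                # learnable scale, initialized to ones
plain = RMSNorm(8, elementwise_affine=False)   # no parameters, as used for group_norm

# the same quantity computed by hand: x / sqrt(mean(x^2) + eps)
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)

print(torch.allclose(norm(x), expected))   # True while the weight is still all ones
print(plain(x).shape)                      # torch.Size([2, 8])
```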
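Likewise, a small sketch of the new `GLU` feed-forward block that replaces `FeedForwardNetwork` in the decoder layer: the input is projected twice, the activated branch (swish by default) is multiplied elementwise by the gate branch, and the result is projected back to the embedding dimension. The 512/864 widths follow the `retnet_base` defaults above; the batch and sequence sizes are arbitrary.

```python
import torch

from torchscale.component.gate_linear_unit import GLU

glu = GLU(
    embed_dim=512,
    ffn_dim=864,
    activation_fn="swish",   # resolved to F.silu by get_activation_fn
    dropout=0.0,
    activation_dropout=0.0,
)

x = torch.randn(4, 16, 512)  # (batch, seq_len, embed_dim)
y = glu(x)                   # swish(fc1(x)) * gate(x), then fc2 back to embed_dim
print(y.shape)               # torch.Size([4, 16, 512])
```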
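Finally, a sketch of how the renamed configuration knobs are consumed after this patch, assuming `RetNetConfig` is constructed directly with keyword arguments (the defaults shown in `config.py` above are embedding 768, value/FFN width 1280, `norm_eps` 1e-6); the sizes in the second call mirror the `retnet_base` fairseq preset.

```python
from torchscale.architecture.config import RetNetConfig

# defaults after this patch
cfg = RetNetConfig()
print(cfg.decoder_value_embed_dim, cfg.decoder_ffn_embed_dim, cfg.norm_eps)  # 1280 1280 1e-06

# the renamed knobs are plain kwargs, e.g. for a retnet_base-sized model
cfg = RetNetConfig(
    decoder_embed_dim=512,
    decoder_value_embed_dim=864,
    decoder_ffn_embed_dim=864,
    decoder_retention_heads=2,
    norm_embedding=False,
)
```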