diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py
index b267dd8..3a088bc 100644
--- a/torchscale/architecture/config.py
+++ b/torchscale/architecture/config.py
@@ -261,6 +261,11 @@ class RetNetConfig(object):
         # RetNet's RelPos base
         self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
 
+        # Backwards compatibility flags
+        self.use_layernorm = kwargs.pop("use_layernorm", False)
+        self.use_biases = kwargs.pop("use_biases", False)
+        self.use_glu = kwargs.pop("use_glu", True)
+
         if self.deepnorm:
             self.decoder_normalize_before = False
             self.subln = False
diff --git a/torchscale/architecture/retnet.py b/torchscale/architecture/retnet.py
index 83a368b..79e8d93 100644
--- a/torchscale/architecture/retnet.py
+++ b/torchscale/architecture/retnet.py
@@ -11,11 +11,15 @@ from fairscale.nn import checkpoint_wrapper, wrap
 
 from torchscale.architecture.utils import init_bert_params
 from torchscale.component.droppath import DropPath
-from torchscale.component.feedforward_network import make_experts
+from torchscale.component.feedforward_network import make_experts, FeedForwardNetwork
 from torchscale.component.gate_linear_unit import GLU
 from torchscale.component.multiscale_retention import MultiScaleRetention
 from torchscale.component.xmoe.moe_layer import MOELayer
 from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 from torchscale.component.rms_norm import RMSNorm
 
 
@@ -92,7 +96,7 @@ class DecoderLayer(nn.Module):
 
         self.normalize_before = args.decoder_normalize_before
 
-        self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
+        self.retention_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
 
         self.is_moe_layer = is_moe_layer
         self.ffn_dim = args.decoder_ffn_embed_dim
@@ -124,7 +128,7 @@ class DecoderLayer(nn.Module):
                 experts = make_experts(args, self.embed_dim, self.ffn_dim)
             self.moe_layer = MOELayer(gate, experts, args)
 
-        self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
+        self.final_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
 
         if args.deepnorm:
             self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
@@ -138,6 +142,14 @@ class DecoderLayer(nn.Module):
             args.activation_fn,
             args.dropout,
             args.activation_dropout,
+        ) if args.use_glu else FeedForwardNetwork(
+            embed_dim,
+            self.ffn_dim,
+            args.activation_fn,
+            args.dropout,
+            args.activation_dropout,
+            args.layernorm_eps,
+            args.subln,
         )
 
     def build_retention(self, embed_dim, args):
@@ -225,7 +237,7 @@ class RetNetDecoder(nn.Module):
             self.output_projection = output_projection
 
         if args.layernorm_embedding:
-            self.layernorm_embedding = RMSNorm(embed_dim, eps=args.layernorm_eps)
+            self.layernorm_embedding = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
         else:
             self.layernorm_embedding = None
 
@@ -245,7 +257,7 @@ class RetNetDecoder(nn.Module):
         self.num_layers = len(self.layers)
 
         if args.decoder_normalize_before:
-            self.layer_norm = RMSNorm(embed_dim, eps=args.layernorm_eps)
+            self.layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
         else:
             self.layer_norm = None
 
diff --git a/torchscale/component/multiscale_retention.py b/torchscale/component/multiscale_retention.py
index c481b23..0c14226 100644
--- a/torchscale/component/multiscale_retention.py
+++ b/torchscale/component/multiscale_retention.py
@@ -5,6 +5,10 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 from .rms_norm import RMSNorm
 
 from .multiway_network import MultiwayWrapper
@@ -56,14 +60,14 @@ class MultiScaleRetention(nn.Module):
 
         self.gate_fn = get_activation_fn(activation=str(gate_fn))
 
-        self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
-        self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
-        self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
-        self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
+        self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
+        self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
+        self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
+        self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
 
-        self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False))
+        self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=args.use_biases))
 
-        self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
+        self.group_norm = MultiwayWrapper(args, (LayerNorm if args.use_layernorm else RMSNorm)(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -72,6 +76,8 @@ class MultiScaleRetention(nn.Module):
         nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
         nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
         nn.init.xavier_uniform_(self.out_proj.weight)
+        if self.out_proj.bias is not None:
+            nn.init.constant_(self.out_proj.bias, 0.0)
 
     def parallel_forward(self, qr, kr, v, mask):
         bsz, tgt_len, embed_dim = v.size()
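
Usage note (a minimal sketch, not part of the patch; vocab_size=64000 is just an example value): leaving the new flags at their defaults keeps the current behaviour (RMSNorm, bias-free projections, GLU feed-forward), while setting them explicitly restores the pre-change LayerNorm/bias/FeedForwardNetwork layout.

from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder

# Defaults introduced by this change: use_layernorm=False, use_biases=False,
# use_glu=True -> RMSNorm everywhere, no projection biases, GLU feed-forward.
config_new = RetNetConfig(vocab_size=64000)

# Backwards-compatible settings: LayerNorm (apex FusedLayerNorm when installed),
# biased q/k/v/g/out projections, and the plain FeedForwardNetwork.
config_old = RetNetConfig(
    vocab_size=64000,
    use_layernorm=True,
    use_biases=True,
    use_glu=False,
)

decoder = RetNetDecoder(config_old)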