@@ -11,11 +11,15 @@ from fairscale.nn import checkpoint_wrapper, wrap

 from torchscale.architecture.utils import init_bert_params
 from torchscale.component.droppath import DropPath
-from torchscale.component.feedforward_network import make_experts
+from torchscale.component.feedforward_network import make_experts, FeedForwardNetwork
 from torchscale.component.gate_linear_unit import GLU
 from torchscale.component.multiscale_retention import MultiScaleRetention
 from torchscale.component.xmoe.moe_layer import MOELayer
 from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 from torchscale.component.rms_norm import RMSNorm
@@ -92,7 +96,7 @@ class DecoderLayer(nn.Module):

         self.normalize_before = args.decoder_normalize_before

-        self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
+        self.retention_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)

         self.is_moe_layer = is_moe_layer
         self.ffn_dim = args.decoder_ffn_embed_dim
@@ -124,7 +128,7 @@ class DecoderLayer(nn.Module):
             experts = make_experts(args, self.embed_dim, self.ffn_dim)
             self.moe_layer = MOELayer(gate, experts, args)

-        self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
+        self.final_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)

         if args.deepnorm:
             self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
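
Note (not part of the diff): the parenthesized expression picks a norm class and then constructs it, so both branches must share the (dim, eps) constructor that LayerNorm and torchscale's RMSNorm both expose. A minimal standalone sketch, with a hypothetical Args stand-in for the real decoder config:

import torch
from torch.nn import LayerNorm
from torchscale.component.rms_norm import RMSNorm

class Args:                      # hypothetical stand-in for the decoder config
    use_layernorm = False        # mirrors the args.use_layernorm flag read above
    layernorm_eps = 1e-6

args = Args()
# Same pattern as the diff: choose the class, then call it once with (dim, eps).
norm = (LayerNorm if args.use_layernorm else RMSNorm)(512, eps=args.layernorm_eps)
print(type(norm).__name__)                   # -> RMSNorm
print(norm(torch.randn(2, 8, 512)).shape)    # -> torch.Size([2, 8, 512])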
@@ -138,6 +142,14 @@ class DecoderLayer(nn.Module):
             args.activation_fn,
             args.dropout,
             args.activation_dropout,
-        )
+        ) if args.use_glu else FeedForwardNetwork(
+            embed_dim,
+            self.ffn_dim,
+            args.activation_fn,
+            args.dropout,
+            args.activation_dropout,
+            args.layernorm_eps,
+            args.subln,
+        )

     def build_retention(self, embed_dim, args):
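
Note (not part of the diff): only the FeedForwardNetwork branch receives layernorm_eps and subln; GLU takes the shorter five-argument list. A standalone sketch of the two variants, with placeholder sizes that are not taken from the diff:

import torch
from torchscale.component.feedforward_network import FeedForwardNetwork
from torchscale.component.gate_linear_unit import GLU

embed_dim, ffn_dim = 512, 1024   # placeholder sizes
glu_ffn = GLU(embed_dim, ffn_dim, "gelu", 0.0, 0.0)
plain_ffn = FeedForwardNetwork(embed_dim, ffn_dim, "gelu", 0.0, 0.0, 1e-6, False)

x = torch.randn(2, 8, embed_dim)
# Both variants are shape-preserving, so either can serve as the layer's FFN.
print(glu_ffn(x).shape, plain_ffn(x).shape)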
@@ -225,7 +237,7 @@ class RetNetDecoder(nn.Module):
         self.output_projection = output_projection

         if args.layernorm_embedding:
-            self.layernorm_embedding = RMSNorm(embed_dim, eps=args.layernorm_eps)
+            self.layernorm_embedding = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
         else:
             self.layernorm_embedding = None
@@ -245,7 +257,7 @@ class RetNetDecoder(nn.Module):
         self.num_layers = len(self.layers)

         if args.decoder_normalize_before:
-            self.layer_norm = RMSNorm(embed_dim, eps=args.layernorm_eps)
+            self.layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
         else:
             self.layer_norm = None
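
Note (not part of the diff): a hypothetical end-to-end sketch of the new flags. It assumes the companion config change, not shown in these hunks, registers use_layernorm and use_glu on RetNetConfig under the same names the layers read (args.use_layernorm, args.use_glu); all other hyperparameters are placeholders.

import torch.nn as nn

from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder

config = RetNetConfig(
    vocab_size=32000,
    decoder_embed_dim=512,
    decoder_value_embed_dim=1024,
    decoder_retention_heads=4,
    decoder_ffn_embed_dim=1024,
    decoder_layers=6,
    use_layernorm=True,   # assumed flag: LayerNorm / FusedLayerNorm instead of RMSNorm
    use_glu=False,        # assumed flag: FeedForwardNetwork instead of GLU
)
embed_tokens = nn.Embedding(config.vocab_size, config.decoder_embed_dim)
decoder = RetNetDecoder(config, embed_tokens=embed_tokens)

# The selected norm class should show up on every DecoderLayer.
print(type(decoder.layers[0].retention_layer_norm).__name__)   # LayerNorm (FusedLayerNorm if apex is installed)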