Added compat flags because I guess the maintainer assumed no one was actually using the RetNet and thinks they can change things willy-nilly.
This commit is contained in:
parent
ce77afe916
commit
008f1b6d18
|
@ -261,6 +261,11 @@ class RetNetConfig(object):
|
|||
# RetNet's RelPos base
|
||||
self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
|
||||
|
||||
# Backwards compatibility flags
|
||||
self.use_layernorm = kwargs.pop("use_layernorm", False)
|
||||
self.use_biases = kwargs.pop("use_biases", False)
|
||||
self.use_glu = kwargs.pop("use_glu", True)
|
||||
|
||||
if self.deepnorm:
|
||||
self.decoder_normalize_before = False
|
||||
self.subln = False
|
||||
|
|
|
@ -11,11 +11,15 @@ from fairscale.nn import checkpoint_wrapper, wrap
|
|||
|
||||
from torchscale.architecture.utils import init_bert_params
|
||||
from torchscale.component.droppath import DropPath
|
||||
from torchscale.component.feedforward_network import make_experts
|
||||
from torchscale.component.feedforward_network import make_experts, FeedForwardNetwork
|
||||
from torchscale.component.gate_linear_unit import GLU
|
||||
from torchscale.component.multiscale_retention import MultiScaleRetention
|
||||
from torchscale.component.xmoe.moe_layer import MOELayer
|
||||
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||
except ModuleNotFoundError:
|
||||
from torch.nn import LayerNorm
|
||||
from torchscale.component.rms_norm import RMSNorm
|
||||
|
||||
|
||||
|
@ -92,7 +96,7 @@ class DecoderLayer(nn.Module):
|
|||
|
||||
self.normalize_before = args.decoder_normalize_before
|
||||
|
||||
self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
|
||||
self.retention_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
|
||||
|
||||
self.is_moe_layer = is_moe_layer
|
||||
self.ffn_dim = args.decoder_ffn_embed_dim
|
||||
|
@ -124,7 +128,7 @@ class DecoderLayer(nn.Module):
|
|||
experts = make_experts(args, self.embed_dim, self.ffn_dim)
|
||||
self.moe_layer = MOELayer(gate, experts, args)
|
||||
|
||||
self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
|
||||
self.final_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
|
||||
|
||||
if args.deepnorm:
|
||||
self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
|
||||
|
@ -138,6 +142,14 @@ class DecoderLayer(nn.Module):
|
|||
args.activation_fn,
|
||||
args.dropout,
|
||||
args.activation_dropout,
|
||||
) if args.use_glu else FeedForwardNetwork(
|
||||
embed_dim,
|
||||
self.ffn_dim,
|
||||
args.activation_fn,
|
||||
args.dropout,
|
||||
args.activation_dropout,
|
||||
args.layernorm_eps,
|
||||
args.subln,
|
||||
)
|
||||
|
||||
def build_retention(self, embed_dim, args):
|
||||
|
@ -225,7 +237,7 @@ class RetNetDecoder(nn.Module):
|
|||
self.output_projection = output_projection
|
||||
|
||||
if args.layernorm_embedding:
|
||||
self.layernorm_embedding = RMSNorm(embed_dim, eps=args.layernorm_eps)
|
||||
self.layernorm_embedding = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
|
||||
else:
|
||||
self.layernorm_embedding = None
|
||||
|
||||
|
@ -245,7 +257,7 @@ class RetNetDecoder(nn.Module):
|
|||
self.num_layers = len(self.layers)
|
||||
|
||||
if args.decoder_normalize_before:
|
||||
self.layer_norm = RMSNorm(embed_dim, eps=args.layernorm_eps)
|
||||
self.layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
|
||||
else:
|
||||
self.layer_norm = None
|
||||
|
||||
|
|
|
@ -5,6 +5,10 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||
except ModuleNotFoundError:
|
||||
from torch.nn import LayerNorm
|
||||
from .rms_norm import RMSNorm
|
||||
|
||||
from .multiway_network import MultiwayWrapper
|
||||
|
@ -56,14 +60,14 @@ class MultiScaleRetention(nn.Module):
|
|||
|
||||
self.gate_fn = get_activation_fn(activation=str(gate_fn))
|
||||
|
||||
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
|
||||
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
|
||||
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
|
||||
self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
|
||||
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
|
||||
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
|
||||
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
|
||||
self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
|
||||
|
||||
self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False))
|
||||
self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=args.use_biases))
|
||||
|
||||
self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
|
||||
self.group_norm = MultiwayWrapper(args, (LayerNorm if args.use_layernorm else RMSNorm)(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
|
@ -72,6 +76,8 @@ class MultiScaleRetention(nn.Module):
|
|||
nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
|
||||
nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
|
||||
nn.init.xavier_uniform_(self.out_proj.weight)
|
||||
if hasattr(self.out_proj, "bias"):
|
||||
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||
|
||||
def parallel_forward(self, qr, kr, v, mask):
|
||||
bsz, tgt_len, embed_dim = v.size()
|
||||
|
|
Loading…
Reference in New Issue
Block a user