Added compat flags because I guess the maintainer assumed no one was actually using the RetNet and thinks they can change things willy-nilly.
This commit is contained in:
parent ce77afe916
commit 008f1b6d18
@@ -261,6 +261,11 @@ class RetNetConfig(object):
         # RetNet's RelPos base
         self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
 
+        # Backwards compatibility flags
+        self.use_layernorm = kwargs.pop("use_layernorm", False)
+        self.use_biases = kwargs.pop("use_biases", False)
+        self.use_glu = kwargs.pop("use_glu", True)
+
         if self.deepnorm:
             self.decoder_normalize_before = False
             self.subln = False
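Note: for anyone pinning behaviour from an older checkpoint, a minimal sketch of how the new flags are meant to be passed in. It assumes RetNetConfig is constructed directly with keyword arguments (as the kwargs.pop defaults above suggest) and that it lives at torchscale.architecture.config as in upstream torchscale; neither detail is shown in this diff.

    from torchscale.architecture.config import RetNetConfig

    # Old behaviour: plain LayerNorm, biased projections, no GLU feed-forward.
    legacy_cfg = RetNetConfig(use_layernorm=True, use_biases=True, use_glu=False)

    # New behaviour: omit the flags and the defaults (RMSNorm, bias-free, GLU) apply.
    new_cfg = RetNetConfig()

    print(legacy_cfg.use_layernorm, new_cfg.use_layernorm)  # True False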
@@ -11,11 +11,15 @@ from fairscale.nn import checkpoint_wrapper, wrap
 
 from torchscale.architecture.utils import init_bert_params
 from torchscale.component.droppath import DropPath
-from torchscale.component.feedforward_network import make_experts
+from torchscale.component.feedforward_network import make_experts, FeedForwardNetwork
 from torchscale.component.gate_linear_unit import GLU
 from torchscale.component.multiscale_retention import MultiScaleRetention
 from torchscale.component.xmoe.moe_layer import MOELayer
 from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 from torchscale.component.rms_norm import RMSNorm
 
 
@@ -92,7 +96,7 @@ class DecoderLayer(nn.Module):
 
         self.normalize_before = args.decoder_normalize_before
 
-        self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
+        self.retention_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
 
         self.is_moe_layer = is_moe_layer
         self.ffn_dim = args.decoder_ffn_embed_dim
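Note: the (LayerNorm if args.use_layernorm else RMSNorm) expression picks a normalization class and instantiates it in one go; the same pattern repeats below for final_layer_norm, layernorm_embedding, the decoder's layer_norm, and the retention group_norm. A standalone sketch of the idea, with torch.nn.LayerNorm standing in for the apex fused version and the helper name being purely illustrative:

    import torch
    from torch import nn

    from torchscale.component.rms_norm import RMSNorm


    def build_norm(dim, use_layernorm, eps=1e-5):
        # Both classes share the (dim, eps=...) constructor, so selecting the
        # class first keeps the call site identical either way.
        norm_cls = nn.LayerNorm if use_layernorm else RMSNorm
        return norm_cls(dim, eps=eps)


    x = torch.randn(2, 4, 8)
    print(type(build_norm(8, use_layernorm=True)).__name__)   # LayerNorm
    print(type(build_norm(8, use_layernorm=False)).__name__)  # RMSNorm
    print(build_norm(8, use_layernorm=False)(x).shape)        # torch.Size([2, 4, 8])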
@@ -124,7 +128,7 @@ class DecoderLayer(nn.Module):
             experts = make_experts(args, self.embed_dim, self.ffn_dim)
             self.moe_layer = MOELayer(gate, experts, args)
 
-        self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
+        self.final_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
 
         if args.deepnorm:
             self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
@@ -138,6 +142,14 @@ class DecoderLayer(nn.Module):
             args.activation_fn,
             args.dropout,
             args.activation_dropout,
+        ) if args.use_glu else FeedForwardNetwork(
+            embed_dim,
+            self.ffn_dim,
+            args.activation_fn,
+            args.dropout,
+            args.activation_dropout,
+            args.layernorm_eps,
+            args.subln,
         )
 
     def build_retention(self, embed_dim, args):
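Note: pulled out of build_ffn, the use_glu switch is just a conditional over which feed-forward block gets constructed. The argument order follows the hunk above; the sizes and the activation string are placeholders, not values from the commit:

    from torchscale.component.gate_linear_unit import GLU
    from torchscale.component.feedforward_network import FeedForwardNetwork

    embed_dim, ffn_dim = 512, 2048   # placeholder sizes
    use_glu = False                  # False restores the pre-GLU FeedForwardNetwork block

    if use_glu:
        # GLU(embed_dim, ffn_dim, activation_fn, dropout, activation_dropout)
        ffn = GLU(embed_dim, ffn_dim, "gelu", 0.0, 0.0)
    else:
        # FeedForwardNetwork(embed_dim, ffn_dim, activation_fn, dropout,
        #                    activation_dropout, layernorm_eps, subln)
        ffn = FeedForwardNetwork(embed_dim, ffn_dim, "gelu", 0.0, 0.0, 1e-5, False)

    print(type(ffn).__name__)  # FeedForwardNetwork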
@@ -225,7 +237,7 @@ class RetNetDecoder(nn.Module):
         self.output_projection = output_projection
 
         if args.layernorm_embedding:
-            self.layernorm_embedding = RMSNorm(embed_dim, eps=args.layernorm_eps)
+            self.layernorm_embedding = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
         else:
             self.layernorm_embedding = None
 
@@ -245,7 +257,7 @@ class RetNetDecoder(nn.Module):
         self.num_layers = len(self.layers)
 
         if args.decoder_normalize_before:
-            self.layer_norm = RMSNorm(embed_dim, eps=args.layernorm_eps)
+            self.layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
         else:
             self.layer_norm = None
 
@@ -5,6 +5,10 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 from .rms_norm import RMSNorm
 
 from .multiway_network import MultiwayWrapper
@@ -56,14 +60,14 @@ class MultiScaleRetention(nn.Module):
 
         self.gate_fn = get_activation_fn(activation=str(gate_fn))
 
-        self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
-        self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
-        self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
-        self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
+        self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
+        self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
+        self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
+        self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
 
-        self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False))
+        self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=args.use_biases))
 
-        self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
+        self.group_norm = MultiwayWrapper(args, (LayerNorm if args.use_layernorm else RMSNorm)(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
 
         self.reset_parameters()
 
     def reset_parameters(self):
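Note: the bias flag flows straight into nn.Linear. A quick, generic illustration of what it changes in a projection's parameters (the layer here is a plain Linear of arbitrary size, not the retention module itself):

    from torch import nn

    bias_free = nn.Linear(512, 512, bias=False)   # the new default (use_biases=False)
    legacy = nn.Linear(512, 512, bias=True)       # use_biases=True matches older checkpoints

    print([n for n, _ in bias_free.named_parameters()])  # ['weight']
    print(bias_free.bias)                                # None
    print([n for n, _ in legacy.named_parameters()])     # ['weight', 'bias']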
@@ -72,6 +76,8 @@ class MultiScaleRetention(nn.Module):
         nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
         nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
         nn.init.xavier_uniform_(self.out_proj.weight)
+        if hasattr(self.out_proj, "bias"):
+            nn.init.constant_(self.out_proj.bias, 0.0)
 
     def parallel_forward(self, qr, kr, v, mask):
         bsz, tgt_len, embed_dim = v.size()
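Note: for completeness, nn.Linear still exposes a bias attribute when the bias is disabled (it is simply None), which is why init code that touches out_proj.bias conventionally checks for None rather than attribute existence. A generic sketch of that behaviour, not part of the commit:

    from torch import nn

    proj = nn.Linear(512, 512, bias=False)

    print(hasattr(proj, "bias"))  # True: the attribute is registered even with bias=False
    print(proj.bias)              # None

    # Conventional guard when the bias may have been disabled:
    if proj.bias is not None:
        nn.init.constant_(proj.bias, 0.0)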