added compat flags because I guess the maintainer assumed no one was actually using the retnet and thinks they can change things willy nilly

This commit is contained in:
mrq 2023-10-05 16:38:57 -05:00
parent ce77afe916
commit 008f1b6d18
3 changed files with 34 additions and 11 deletions

View File

@ -261,6 +261,11 @@ class RetNetConfig(object):
# RetNet's RelPos base
self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
# Backwards compatibility flags
self.use_layernorm = kwargs.pop("use_layernorm", False)
self.use_biases = kwargs.pop("use_biases", False)
self.use_glu = kwargs.pop("use_glu", True)
if self.deepnorm:
self.decoder_normalize_before = False
self.subln = False

View File

@ -11,11 +11,15 @@ from fairscale.nn import checkpoint_wrapper, wrap
from torchscale.architecture.utils import init_bert_params
from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import make_experts
from torchscale.component.feedforward_network import make_experts, FeedForwardNetwork
from torchscale.component.gate_linear_unit import GLU
from torchscale.component.multiscale_retention import MultiScaleRetention
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from torchscale.component.rms_norm import RMSNorm
@ -92,7 +96,7 @@ class DecoderLayer(nn.Module):
self.normalize_before = args.decoder_normalize_before
self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
self.retention_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
self.is_moe_layer = is_moe_layer
self.ffn_dim = args.decoder_ffn_embed_dim
@ -124,7 +128,7 @@ class DecoderLayer(nn.Module):
experts = make_experts(args, self.embed_dim, self.ffn_dim)
self.moe_layer = MOELayer(gate, experts, args)
self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
self.final_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
if args.deepnorm:
self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
@ -138,6 +142,14 @@ class DecoderLayer(nn.Module):
args.activation_fn,
args.dropout,
args.activation_dropout,
) if args.use_glu else FeedForwardNetwork(
embed_dim,
self.ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
args.layernorm_eps,
args.subln,
)
def build_retention(self, embed_dim, args):
@ -225,7 +237,7 @@ class RetNetDecoder(nn.Module):
self.output_projection = output_projection
if args.layernorm_embedding:
self.layernorm_embedding = RMSNorm(embed_dim, eps=args.layernorm_eps)
self.layernorm_embedding = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
else:
self.layernorm_embedding = None
@ -245,7 +257,7 @@ class RetNetDecoder(nn.Module):
self.num_layers = len(self.layers)
if args.decoder_normalize_before:
self.layer_norm = RMSNorm(embed_dim, eps=args.layernorm_eps)
self.layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
else:
self.layer_norm = None

View File

@ -5,6 +5,10 @@
import torch
import torch.nn.functional as F
from torch import nn
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from .rms_norm import RMSNorm
from .multiway_network import MultiwayWrapper
@ -56,14 +60,14 @@ class MultiScaleRetention(nn.Module):
self.gate_fn = get_activation_fn(activation=str(gate_fn))
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False))
self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=args.use_biases))
self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
self.group_norm = MultiwayWrapper(args, (LayerNorm if args.use_layernorm else RMSNorm)(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
self.reset_parameters()
def reset_parameters(self):
@ -72,6 +76,8 @@ class MultiScaleRetention(nn.Module):
nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.out_proj.weight)
if hasattr(self.out_proj, "bias"):
nn.init.constant_(self.out_proj.bias, 0.0)
def parallel_forward(self, qr, kr, v, mask):
bsz, tgt_len, embed_dim = v.size()