added compat flags because I guess the maintainer assumed no one was actually using the retnet and thinks they can change things willy nilly

This commit is contained in:
mrq 2023-10-05 16:38:57 -05:00
parent ce77afe916
commit 008f1b6d18
3 changed files with 34 additions and 11 deletions

View File

@ -261,6 +261,11 @@ class RetNetConfig(object):
# RetNet's RelPos base # RetNet's RelPos base
self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000) self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
# Backwards compatibility flags
self.use_layernorm = kwargs.pop("use_layernorm", False)
self.use_biases = kwargs.pop("use_biases", False)
self.use_glu = kwargs.pop("use_glu", True)
if self.deepnorm: if self.deepnorm:
self.decoder_normalize_before = False self.decoder_normalize_before = False
self.subln = False self.subln = False

View File

@ -11,11 +11,15 @@ from fairscale.nn import checkpoint_wrapper, wrap
from torchscale.architecture.utils import init_bert_params from torchscale.architecture.utils import init_bert_params
from torchscale.component.droppath import DropPath from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import make_experts from torchscale.component.feedforward_network import make_experts, FeedForwardNetwork
from torchscale.component.gate_linear_unit import GLU from torchscale.component.gate_linear_unit import GLU
from torchscale.component.multiscale_retention import MultiScaleRetention from torchscale.component.multiscale_retention import MultiScaleRetention
from torchscale.component.xmoe.moe_layer import MOELayer from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from torchscale.component.rms_norm import RMSNorm from torchscale.component.rms_norm import RMSNorm
@ -92,7 +96,7 @@ class DecoderLayer(nn.Module):
self.normalize_before = args.decoder_normalize_before self.normalize_before = args.decoder_normalize_before
self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps) self.retention_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
self.is_moe_layer = is_moe_layer self.is_moe_layer = is_moe_layer
self.ffn_dim = args.decoder_ffn_embed_dim self.ffn_dim = args.decoder_ffn_embed_dim
@ -124,7 +128,7 @@ class DecoderLayer(nn.Module):
experts = make_experts(args, self.embed_dim, self.ffn_dim) experts = make_experts(args, self.embed_dim, self.ffn_dim)
self.moe_layer = MOELayer(gate, experts, args) self.moe_layer = MOELayer(gate, experts, args)
self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps) self.final_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
if args.deepnorm: if args.deepnorm:
self.alpha = math.pow(2.0 * args.decoder_layers, 0.25) self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
@ -138,6 +142,14 @@ class DecoderLayer(nn.Module):
args.activation_fn, args.activation_fn,
args.dropout, args.dropout,
args.activation_dropout, args.activation_dropout,
) if args.use_glu else FeedForwardNetwork(
embed_dim,
self.ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
args.layernorm_eps,
args.subln,
) )
def build_retention(self, embed_dim, args): def build_retention(self, embed_dim, args):
@ -225,7 +237,7 @@ class RetNetDecoder(nn.Module):
self.output_projection = output_projection self.output_projection = output_projection
if args.layernorm_embedding: if args.layernorm_embedding:
self.layernorm_embedding = RMSNorm(embed_dim, eps=args.layernorm_eps) self.layernorm_embedding = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
else: else:
self.layernorm_embedding = None self.layernorm_embedding = None
@ -245,7 +257,7 @@ class RetNetDecoder(nn.Module):
self.num_layers = len(self.layers) self.num_layers = len(self.layers)
if args.decoder_normalize_before: if args.decoder_normalize_before:
self.layer_norm = RMSNorm(embed_dim, eps=args.layernorm_eps) self.layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
else: else:
self.layer_norm = None self.layer_norm = None

View File

@ -5,6 +5,10 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from .rms_norm import RMSNorm from .rms_norm import RMSNorm
from .multiway_network import MultiwayWrapper from .multiway_network import MultiwayWrapper
@ -56,14 +60,14 @@ class MultiScaleRetention(nn.Module):
self.gate_fn = get_activation_fn(activation=str(gate_fn)) self.gate_fn = get_activation_fn(activation=str(gate_fn))
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False)) self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False)) self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False)) self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False)) self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False)) self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=args.use_biases))
self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False)) self.group_norm = MultiwayWrapper(args, (LayerNorm if args.use_layernorm else RMSNorm)(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
@ -72,6 +76,8 @@ class MultiScaleRetention(nn.Module):
nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5) nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5) nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.out_proj.weight) nn.init.xavier_uniform_(self.out_proj.weight)
if hasattr(self.out_proj, "bias"):
nn.init.constant_(self.out_proj.bias, 0.0)
def parallel_forward(self, qr, kr, v, mask): def parallel_forward(self, qr, kr, v, mask):
bsz, tgt_len, embed_dim = v.size() bsz, tgt_len, embed_dim = v.size()